/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Splits a {@code SINGLE} aggregation that sits directly on top of an exchange into a
 * {@code FINAL} aggregation above the exchange and a {@code PARTIAL} aggregation below it.
 * The exchange then carries the partial aggregation's output (grouping keys plus intermediate
 * states) instead of the aggregation's raw input.
 *
 * <p>The split is only performed when every aggregate function is decomposable and the
 * exchange is neither a {@code REPLICATE} exchange nor one that replicates nulls. When all
 * sources of an exchange are themselves such exchanges, the partial aggregation is pushed
 * further down through them as well.
 *
 * <p>This optimizer must run before HashGenerationOptimizer: an aggregation being pushed down
 * may not carry a hash symbol.
 */
public class PartialAggregationPushDown
implements PlanOptimizer {
    private final FunctionRegistry functionRegistry;

    /**
     * @param metadata provides the function registry used to resolve aggregate implementations
     * @throws NullPointerException if {@code metadata} is null
     */
    public PartialAggregationPushDown(Metadata metadata) {
        Objects.requireNonNull(metadata, "metadata is null");
        this.functionRegistry = metadata.getFunctionRegistry();
    }

    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        // The rewrite context carries the PARTIAL aggregation currently being pushed down;
        // at the root nothing is being pushed, hence the null context.
        return SimplePlanRewriter.rewriteWith(new Rewriter(symbolAllocator, idAllocator), plan, null);
    }

    /**
     * A PARTIAL aggregation planned for one exchange source, together with that aggregation's
     * symbols listed in the same order as the output symbols of the partial aggregation it was
     * derived from (used as the exchange input layout for that source).
     */
    private static class AggregationWithLayout {
        private final AggregationNode aggregationNode;
        private final List<Symbol> layout;

        public AggregationWithLayout(AggregationNode aggregationNode, List<Symbol> layout) {
            this.aggregationNode = Objects.requireNonNull(aggregationNode, "aggregationNode is null");
            this.layout = ImmutableList.copyOf(Objects.requireNonNull(layout, "layout is null"));
        }

        public AggregationNode getAggregationNode() {
            return this.aggregationNode;
        }

        public List<Symbol> getLayout() {
            return this.layout;
        }
    }

    /**
     * Plan rewriter whose context is the PARTIAL aggregation that still has to be placed below
     * the exchange being visited, or {@code null} when nothing is being pushed.
     */
    private class Rewriter
    extends SimplePlanRewriter<AggregationNode> {
        private final SymbolAllocator allocator;
        private final PlanNodeIdAllocator idAllocator;

        public Rewriter(SymbolAllocator allocator, PlanNodeIdAllocator idAllocator) {
            this.allocator = Objects.requireNonNull(allocator, "allocator is null");
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
        }

        /**
         * Replaces an eligible SINGLE aggregation with a FINAL aggregation over the rewritten
         * source, handing the matching PARTIAL aggregation down as the rewrite context.
         * Ineligible aggregations are rewritten with the default behavior.
         *
         * @throws IllegalStateException if the aggregation is not SINGLE or a partial
         *         aggregation is already pending
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<AggregationNode> context) {
            Preconditions.checkState(node.getStep() == AggregationNode.Step.SINGLE, "aggregation should be SINGLE, but it is %s", node.getStep());
            Preconditions.checkState(context.get() == null, "context is not null: %s", context.get());

            boolean decomposable = node.getFunctions().values().stream()
                    .map(functionRegistry::getAggregateFunctionImplementation)
                    .allMatch(InternalAggregationFunction::isDecomposable);
            if (!decomposable || !allowPushThrough(node.getSource())) {
                return context.defaultRewrite(node);
            }

            Map<Symbol, Symbol> masks = node.getMasks();
            Map<Symbol, FunctionCall> finalCalls = new HashMap<>();
            Map<Symbol, FunctionCall> intermediateCalls = new HashMap<>();
            Map<Symbol, Signature> intermediateFunctions = new HashMap<>();
            Map<Symbol, Symbol> intermediateMask = new HashMap<>();
            for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) {
                Symbol outputSymbol = entry.getKey();
                Signature signature = node.getFunctions().get(outputSymbol);
                Symbol intermediateSymbol = generateIntermediateSymbol(signature);
                intermediateCalls.put(intermediateSymbol, entry.getValue());
                intermediateFunctions.put(intermediateSymbol, signature);
                if (masks.containsKey(outputSymbol)) {
                    intermediateMask.put(intermediateSymbol, masks.get(outputSymbol));
                }
                // The FINAL call consumes the intermediate value produced by the PARTIAL call.
                finalCalls.put(outputSymbol, new FunctionCall(QualifiedName.of(signature.getName()), ImmutableList.<Expression>of(intermediateSymbol.toSymbolReference())));
            }
            AggregationNode partial = new AggregationNode(idAllocator.getNextId(), node.getSource(), intermediateCalls, intermediateFunctions, intermediateMask, node.getGroupingSets(), AggregationNode.Step.PARTIAL, node.getSampleWeight(), node.getConfidence(), node.getHashSymbol(), node.getGroupIdSymbol());
            // Masks and the sample weight are handed to the PARTIAL step only; the FINAL step gets none.
            return new AggregationNode(node.getId(), context.rewrite(node.getSource(), partial), finalCalls, node.getFunctions(), ImmutableMap.<Symbol, Symbol>of(), node.getGroupingSets(), AggregationNode.Step.FINAL, Optional.empty(), node.getConfidence(), node.getHashSymbol(), node.getGroupIdSymbol());
        }

        /**
         * Places the PARTIAL aggregation carried in the context beneath this exchange, creating
         * one PARTIAL aggregation per exchange source with that source's symbols substituted.
         * Exchanges reached without a pending aggregation are rewritten with the default behavior.
         */
        @Override
        public PlanNode visitExchange(ExchangeNode node, SimplePlanRewriter.RewriteContext<AggregationNode> context) {
            AggregationNode partial = context.get();
            if (partial == null) {
                return context.defaultRewrite(node);
            }
            List<PlanNode> newChildren = new ArrayList<>();
            List<List<Symbol>> inputs = new ArrayList<>();
            // Keep pushing deeper only if every source is itself an exchange we may pass through.
            boolean allowPushThroughChildren = node.getSources().stream().allMatch(this::allowPushThrough);
            for (int i = 0; i < node.getSources().size(); ++i) {
                PlanNode currentSource = node.getSources().get(i);
                Map<Symbol, Symbol> exchangeMap = buildExchangeMap(node.getOutputSymbols(), node.getInputs().get(i));
                AggregationWithLayout childPartial = generateNewPartial(partial, currentSource, exchangeMap);
                inputs.add(childPartial.getLayout());
                PlanNode child;
                if (allowPushThroughChildren) {
                    child = context.rewrite(currentSource, childPartial.getAggregationNode());
                } else {
                    child = context.defaultRewrite(childPartial.getAggregationNode());
                }
                newChildren.add(child);
            }
            // The exchange now outputs the partial aggregation's symbols instead of its original ones.
            // NOTE(review): assumes the partitioning columns are among those symbols (grouping keys) — confirm.
            PartitioningScheme partitioningScheme = new PartitioningScheme(node.getPartitioningScheme().getPartitioning(), partial.getOutputSymbols(), partial.getHashSymbol());
            return new ExchangeNode(node.getId(), node.getType(), node.getScope(), partitioningScheme, newChildren, inputs);
        }

        /**
         * Returns true when {@code node} is an exchange a partial aggregation may be pushed
         * through: it must not be a REPLICATE exchange and must not replicate nulls.
         */
        private boolean allowPushThrough(PlanNode node) {
            if (!(node instanceof ExchangeNode)) {
                return false;
            }
            ExchangeNode exchangeNode = (ExchangeNode) node;
            return exchangeNode.getType() != ExchangeNode.Type.REPLICATE && !exchangeNode.getPartitioningScheme().isReplicateNulls();
        }

        /** Allocates a fresh symbol typed with the intermediate type of the given aggregate function. */
        private Symbol generateIntermediateSymbol(Signature signature) {
            InternalAggregationFunction function = functionRegistry.getAggregateFunctionImplementation(signature);
            return allocator.newSymbol(signature.getName(), function.getIntermediateType());
        }

        /**
         * Pairs exchange output symbols with one source's input symbols by position.
         *
         * @throws IllegalStateException if the two lists differ in length
         */
        private Map<Symbol, Symbol> buildExchangeMap(List<Symbol> exchangeOutput, List<Symbol> sourceOutput) {
            Preconditions.checkState(exchangeOutput.size() == sourceOutput.size(), "exchange output length doesn't match source output length");
            ImmutableMap.Builder<Symbol, Symbol> builder = ImmutableMap.builder();
            for (int i = 0; i < exchangeOutput.size(); ++i) {
                builder.put(exchangeOutput.get(i), sourceOutput.get(i));
            }
            return builder.build();
        }

        /**
         * Replaces every argument that is exactly a reference to an exchange output symbol with a
         * reference to the corresponding source symbol; other arguments are returned unchanged.
         * NOTE(review): nested expressions are not rewritten — this assumes aggregation arguments
         * are plain symbol references at this point in planning.
         */
        private List<Expression> replaceArguments(List<Expression> arguments, Map<Symbol, Symbol> exchangeMap) {
            Map<Expression, Expression> replacements = new HashMap<>();
            exchangeMap.forEach((exchangeSymbol, sourceSymbol) -> replacements.put(exchangeSymbol.toSymbolReference(), sourceSymbol.toSymbolReference()));
            return arguments.stream()
                    .map(argument -> replacements.getOrDefault(argument, argument))
                    .collect(Collectors.toList());
        }

        /**
         * Re-targets a PARTIAL aggregation defined over an exchange's outputs onto one of the
         * exchange's sources. Every symbol the aggregation reads is translated with
         * {@code exchangeMap}: arguments, masks, grouping sets, the sample weight and the
         * group-id symbol.
         *
         * @throws IllegalStateException if the aggregation already has a hash symbol
         */
        private AggregationWithLayout generateNewPartial(AggregationNode node, PlanNode source, Map<Symbol, Symbol> exchangeMap) {
            Preconditions.checkState(!node.getHashSymbol().isPresent(), "PartialAggregationPushDown optimizer must run before HashGenerationOptimizer");
            Map<Symbol, Symbol> layoutMap = new HashMap<>();
            Map<Symbol, FunctionCall> functionCallMap = new HashMap<>();
            Map<Symbol, Signature> signatureMap = new HashMap<>();
            Map<Symbol, Symbol> mask = new HashMap<>();
            for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) {
                Symbol aggregationSymbol = entry.getKey();
                Signature signature = node.getFunctions().get(aggregationSymbol);
                Symbol symbol = generateIntermediateSymbol(signature);
                signatureMap.put(symbol, signature);
                List<Expression> arguments = replaceArguments(entry.getValue().getArguments(), exchangeMap);
                functionCallMap.put(symbol, new FunctionCall(entry.getValue().getName(), arguments));
                if (node.getMasks().containsKey(aggregationSymbol)) {
                    mask.put(symbol, exchangeMap.get(node.getMasks().get(aggregationSymbol)));
                }
                layoutMap.put(aggregationSymbol, symbol);
            }
            for (Symbol groupingKey : node.getGroupingKeys()) {
                layoutMap.put(groupingKey, exchangeMap.get(groupingKey));
            }
            ImmutableList.Builder<List<Symbol>> groupingSets = ImmutableList.builder();
            for (List<Symbol> groupingSet : node.getGroupingSets()) {
                ImmutableList.Builder<Symbol> translatedSet = ImmutableList.builder();
                for (Symbol symbol : groupingSet) {
                    translatedSet.add(exchangeMap.get(symbol));
                }
                groupingSets.add(translatedSet.build());
            }
            // The sample weight, like every other input symbol, must refer to the new source's
            // symbols rather than to the exchange's outputs.
            AggregationNode aggregationNode = new AggregationNode(idAllocator.getNextId(), source, functionCallMap, signatureMap, mask, groupingSets.build(), AggregationNode.Step.PARTIAL, node.getSampleWeight().map(exchangeMap::get), node.getConfidence(), node.getHashSymbol(), node.getGroupIdSymbol().map(exchangeMap::get));
            List<Symbol> layout = node.getOutputSymbols().stream().map(layoutMap::get).collect(Collectors.toList());
            return new AggregationWithLayout(aggregationNode, layout);
        }
    }
}

