/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.Partitioning;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.optimizations.SymbolMapper;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Plan optimizer that pushes partial aggregations below GATHER and REPARTITION exchanges.
 *
 * <p>An aggregation whose direct source is an eligible exchange is handled by step:
 * <ul>
 *   <li>SINGLE: split into a FINAL aggregation on top of a PARTIAL aggregation, then
 *       re-processed so the new PARTIAL node can be pushed down;</li>
 *   <li>PARTIAL: copied beneath every source branch of the exchange;</li>
 *   <li>any other step: left unchanged.</li>
 * </ul>
 */
public class PartialAggregationPushDown
        implements PlanOptimizer {
    private final FunctionRegistry functionRegistry;

    /**
     * Creates the optimizer.
     *
     * @param registry registry used to resolve aggregate function implementations
     *        (decomposability and intermediate types)
     * @throws NullPointerException if {@code registry} is null
     */
    public PartialAggregationPushDown(FunctionRegistry registry) {
        this.functionRegistry = Objects.requireNonNull(registry, "registry is null");
    }

    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        return SimplePlanRewriter.rewriteWith(new Rewriter(symbolAllocator, idAllocator), plan, null);
    }

    /**
     * Rewriter that splits SINGLE aggregations and pushes PARTIAL aggregations through
     * an eligible exchange directly beneath them.
     */
    private class Rewriter
            extends SimplePlanRewriter<Void> {
        private final SymbolAllocator allocator;
        private final PlanNodeIdAllocator idAllocator;

        public Rewriter(SymbolAllocator allocator, PlanNodeIdAllocator idAllocator) {
            this.allocator = Objects.requireNonNull(allocator, "allocator is null");
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
        }

        /**
         * Splits or pushes down {@code node} when its source is an eligible exchange.
         * Otherwise it delegates to the default rewrite.
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            PlanNode child = node.getSource();
            if (!(child instanceof ExchangeNode)) {
                return context.defaultRewrite(node);
            }
            ExchangeNode exchange = (ExchangeNode) child;

            // Only GATHER and REPARTITION exchanges that do not replicate nulls are handled.
            boolean supportedExchange = exchange.getType() == ExchangeNode.Type.GATHER
                    || exchange.getType() == ExchangeNode.Type.REPARTITION;
            if (!supportedExchange || exchange.getPartitioningScheme().isReplicateNulls()) {
                return context.defaultRewrite(node);
            }

            if (exchange.getType() == ExchangeNode.Type.REPARTITION) {
                // The pushed-down exchange (see pushPartial) reuses this partitioning
                // over the partial aggregation's outputs. Every partitioning column
                // must therefore survive the aggregation, i.e. be a grouping key.
                List<Symbol> partitioningColumns = exchange.getPartitioningScheme()
                        .getPartitioning()
                        .getArguments()
                        .stream()
                        .filter(Partitioning.ArgumentBinding::isVariable)
                        .map(Partitioning.ArgumentBinding::getColumn)
                        .collect(Collectors.toList());
                if (!node.getGroupingKeys().containsAll(partitioningColumns)) {
                    return context.defaultRewrite(node);
                }
            }

            // Plans using pre-computed hash symbols are not handled by this rewrite.
            if (node.getHashSymbol().isPresent() || exchange.getPartitioningScheme().getHashColumn().isPresent()) {
                return context.defaultRewrite(node);
            }

            // A PARTIAL/FINAL split requires every aggregate function to be decomposable.
            boolean decomposable = node.getFunctions().values().stream()
                    .map(functionRegistry::getAggregateFunctionImplementation)
                    .allMatch(InternalAggregationFunction::isDecomposable);
            if (!decomposable) {
                return context.defaultRewrite(node);
            }

            switch (node.getStep()) {
                case SINGLE:
                    // Split into FINAL over PARTIAL and re-run the rewriter on the result.
                    // The new PARTIAL node sits directly on the exchange and is pushed next.
                    return context.rewrite(split(node));
                case PARTIAL:
                    // Push beneath each exchange branch and re-run the rewriter, so a
                    // pushed copy can move further down if its new source is also an
                    // eligible exchange.
                    return context.rewrite(pushPartial(node, exchange));
                default:
                    return context.defaultRewrite(node);
            }
        }

        /**
         * Copies {@code partial} beneath every source branch of {@code exchange}.
         *
         * <p>Each copy is rewritten in terms of that branch's input symbols. It is then
         * wrapped in a projection that restores the original partial's output symbols,
         * so all branches expose identical outputs.
         *
         * @return a new exchange over the per-branch partial aggregations
         */
        private PlanNode pushPartial(AggregationNode partial, ExchangeNode exchange) {
            List<PlanNode> partials = new ArrayList<>();
            for (int i = 0; i < exchange.getSources().size(); ++i) {
                PlanNode source = exchange.getSources().get(i);

                // Map each exchange output symbol to the symbol this branch feeds into it.
                SymbolMapper.Builder mappingsBuilder = SymbolMapper.builder();
                for (int outputIndex = 0; outputIndex < exchange.getOutputSymbols().size(); ++outputIndex) {
                    Symbol output = exchange.getOutputSymbols().get(outputIndex);
                    Symbol input = exchange.getInputs().get(i).get(outputIndex);
                    if (!output.equals(input)) {
                        mappingsBuilder.put(output, input);
                    }
                }
                SymbolMapper symbolMapper = mappingsBuilder.build();
                AggregationNode mappedPartial = symbolMapper.map(partial, source, idAllocator);

                // Project the branch's outputs back to the original partial's output names.
                Assignments.Builder assignments = Assignments.builder();
                for (Symbol output : partial.getOutputSymbols()) {
                    Symbol input = symbolMapper.map(output);
                    assignments.put(output, input.toSymbolReference());
                }
                partials.add(new ProjectNode(idAllocator.getNextId(), mappedPartial, assignments.build()));
            }

            for (PlanNode pushed : partials) {
                Verify.verify(partial.getOutputSymbols().equals(pushed.getOutputSymbols()),
                        "pushed partial aggregation outputs %s do not match original outputs %s",
                        pushed.getOutputSymbols(), partial.getOutputSymbols());
            }

            // Every branch now produces the partial's output symbols, so the original
            // partitioning can be reused over those symbols.
            PartitioningScheme original = exchange.getPartitioningScheme();
            PartitioningScheme partitioning = new PartitioningScheme(
                    original.getPartitioning(),
                    partial.getOutputSymbols(),
                    original.getHashColumn(),
                    original.isReplicateNulls(),
                    original.getBucketToPartition());
            List<List<Symbol>> inputs = ImmutableList.copyOf(Collections.nCopies(partials.size(), partial.getOutputSymbols()));
            return new ExchangeNode(idAllocator.getNextId(), exchange.getType(), exchange.getScope(), partitioning, partials, inputs);
        }

        /**
         * Splits a SINGLE aggregation into a FINAL aggregation over a PARTIAL aggregation.
         *
         * <p>For each aggregate, the PARTIAL node computes the original call into a freshly
         * allocated symbol of the function's intermediate type, and it also carries any
         * masks. The FINAL node keeps the original id and output symbols. It invokes the
         * same function name on the intermediate symbol and uses no masks.
         */
        private PlanNode split(AggregationNode node) {
            Map<Symbol, Symbol> masks = node.getMasks();
            Map<Symbol, FunctionCall> finalCalls = new HashMap<>();
            Map<Symbol, FunctionCall> intermediateCalls = new HashMap<>();
            Map<Symbol, Signature> intermediateFunctions = new HashMap<>();
            Map<Symbol, Symbol> intermediateMask = new HashMap<>();
            for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) {
                Symbol aggregationSymbol = entry.getKey();
                Signature signature = node.getFunctions().get(aggregationSymbol);
                InternalAggregationFunction function = functionRegistry.getAggregateFunctionImplementation(signature);

                Symbol intermediateSymbol = allocator.newSymbol(signature.getName(), function.getIntermediateType());
                intermediateCalls.put(intermediateSymbol, entry.getValue());
                intermediateFunctions.put(intermediateSymbol, signature);
                if (masks.containsKey(aggregationSymbol)) {
                    intermediateMask.put(intermediateSymbol, masks.get(aggregationSymbol));
                }

                // The final step consumes the intermediate state produced by the partial step.
                finalCalls.put(aggregationSymbol, new FunctionCall(
                        QualifiedName.of(signature.getName()),
                        ImmutableList.<Expression>of(intermediateSymbol.toSymbolReference())));
            }

            AggregationNode partial = new AggregationNode(
                    idAllocator.getNextId(),
                    node.getSource(),
                    intermediateCalls,
                    intermediateFunctions,
                    intermediateMask,
                    node.getGroupingSets(),
                    AggregationNode.Step.PARTIAL,
                    node.getHashSymbol(),
                    node.getGroupIdSymbol());
            return new AggregationNode(
                    node.getId(),
                    partial,
                    finalCalls,
                    node.getFunctions(),
                    ImmutableMap.<Symbol, Symbol>of(),
                    node.getGroupingSets(),
                    AggregationNode.Step.FINAL,
                    node.getHashSymbol(),
                    node.getGroupIdSymbol());
        }
    }
}

