/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
import com.facebook.presto.sql.planner.Partitioning;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.SymbolMapper;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

public class PushPartialAggregationThroughExchange
implements Rule<AggregationNode> {
    private final FunctionRegistry functionRegistry;
    private static final Capture<ExchangeNode> EXCHANGE_NODE = Capture.newCapture();
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().with(Patterns.source().matching(Patterns.exchange().matching(node -> !node.getOrderingScheme().isPresent()).capturedAs(EXCHANGE_NODE)));

    public PushPartialAggregationThroughExchange(FunctionRegistry functionRegistry) {
        this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry is null");
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        ExchangeNode exchangeNode = (ExchangeNode)captures.get(EXCHANGE_NODE);
        boolean decomposable = aggregationNode.isDecomposable(this.functionRegistry);
        if (aggregationNode.getStep().equals((Object)AggregationNode.Step.SINGLE) && aggregationNode.hasEmptyGroupingSet() && aggregationNode.hasNonEmptyGroupingSet() && exchangeNode.getType() == ExchangeNode.Type.REPARTITION) {
            Preconditions.checkState((boolean)decomposable, (Object)"Distributed aggregation with empty grouping set requires partial but functions are not decomposable");
            return Rule.Result.ofPlanNode(this.split(aggregationNode, context));
        }
        if (!decomposable || !SystemSessionProperties.preferPartialAggregation(context.getSession())) {
            return Rule.Result.empty();
        }
        if (exchangeNode.getType() != ExchangeNode.Type.GATHER && exchangeNode.getType() != ExchangeNode.Type.REPARTITION || exchangeNode.getPartitioningScheme().isReplicateNullsAndAny()) {
            return Rule.Result.empty();
        }
        if (exchangeNode.getType() == ExchangeNode.Type.REPARTITION) {
            List partitioningColumns = exchangeNode.getPartitioningScheme().getPartitioning().getArguments().stream().filter(Partitioning.ArgumentBinding::isVariable).map(Partitioning.ArgumentBinding::getColumn).collect(Collectors.toList());
            if (!aggregationNode.getGroupingKeys().containsAll(partitioningColumns)) {
                return Rule.Result.empty();
            }
        }
        if (aggregationNode.getHashSymbol().isPresent() || exchangeNode.getPartitioningScheme().getHashColumn().isPresent()) {
            return Rule.Result.empty();
        }
        switch (aggregationNode.getStep()) {
            case SINGLE: {
                return Rule.Result.ofPlanNode(this.split(aggregationNode, context));
            }
            case PARTIAL: {
                return Rule.Result.ofPlanNode(this.pushPartial(aggregationNode, exchangeNode, context));
            }
        }
        return Rule.Result.empty();
    }

    private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, Rule.Context context) {
        ArrayList<PlanNode> partials = new ArrayList<PlanNode>();
        for (int i = 0; i < exchange.getSources().size(); ++i) {
            PlanNode source = exchange.getSources().get(i);
            SymbolMapper.Builder mappingsBuilder = SymbolMapper.builder();
            for (int outputIndex = 0; outputIndex < exchange.getOutputSymbols().size(); ++outputIndex) {
                Symbol input;
                Symbol output = exchange.getOutputSymbols().get(outputIndex);
                if (output.equals(input = exchange.getInputs().get(i).get(outputIndex))) continue;
                mappingsBuilder.put(output, input);
            }
            SymbolMapper symbolMapper = mappingsBuilder.build();
            AggregationNode mappedPartial = symbolMapper.map(aggregation, source, context.getIdAllocator());
            Assignments.Builder assignments = Assignments.builder();
            for (Symbol output : aggregation.getOutputSymbols()) {
                Symbol input = symbolMapper.map(output);
                assignments.put(output, (Expression)input.toSymbolReference());
            }
            partials.add(new ProjectNode(context.getIdAllocator().getNextId(), mappedPartial, assignments.build()));
        }
        for (PlanNode node : partials) {
            Verify.verify((boolean)aggregation.getOutputSymbols().equals(node.getOutputSymbols()));
        }
        PartitioningScheme partitioning = new PartitioningScheme(exchange.getPartitioningScheme().getPartitioning(), aggregation.getOutputSymbols(), exchange.getPartitioningScheme().getHashColumn(), exchange.getPartitioningScheme().isReplicateNullsAndAny(), exchange.getPartitioningScheme().getBucketToPartition());
        return new ExchangeNode(context.getIdAllocator().getNextId(), exchange.getType(), exchange.getScope(), partitioning, partials, (List<List<Symbol>>)ImmutableList.copyOf(Collections.nCopies(partials.size(), aggregation.getOutputSymbols())), Optional.empty());
    }

    private PlanNode split(AggregationNode node, Rule.Context context) {
        HashMap<Symbol, AggregationNode.Aggregation> intermediateAggregation = new HashMap<Symbol, AggregationNode.Aggregation>();
        HashMap<Symbol, AggregationNode.Aggregation> finalAggregation = new HashMap<Symbol, AggregationNode.Aggregation>();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
            AggregationNode.Aggregation originalAggregation = entry.getValue();
            Signature signature = originalAggregation.getSignature();
            InternalAggregationFunction function = this.functionRegistry.getAggregateFunctionImplementation(signature);
            Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(signature.getName(), function.getIntermediateType());
            Preconditions.checkState((!originalAggregation.getCall().getOrderBy().isPresent() ? 1 : 0) != 0, (Object)"Aggregate with ORDER BY does not support partial aggregation");
            intermediateAggregation.put(intermediateSymbol, new AggregationNode.Aggregation(originalAggregation.getCall(), signature, originalAggregation.getMask()));
            finalAggregation.put(entry.getKey(), new AggregationNode.Aggregation(new FunctionCall(QualifiedName.of((String)signature.getName()), (List)ImmutableList.builder().add((Object)intermediateSymbol.toSymbolReference()).addAll((Iterable)originalAggregation.getCall().getArguments().stream().filter(LambdaExpression.class::isInstance).collect(ImmutableList.toImmutableList())).build()), signature, Optional.empty()));
        }
        AggregationNode partial = new AggregationNode(context.getIdAllocator().getNextId(), node.getSource(), intermediateAggregation, node.getGroupingSets(), (List<Symbol>)ImmutableList.of(), AggregationNode.Step.PARTIAL, node.getHashSymbol(), node.getGroupIdSymbol());
        return new AggregationNode(node.getId(), partial, finalAggregation, node.getGroupingSets(), (List<Symbol>)ImmutableList.of(), AggregationNode.Step.FINAL, node.getHashSymbol(), node.getGroupIdSymbol());
    }
}

