package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.planner.Partitioning;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolsExtractor;
import com.facebook.presto.sql.planner.SystemPartitioningHandle;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.class */
/**
 * Iterative optimizer rule that inserts INTERMEDIATE aggregation steps between an un-grouped
 * FINAL aggregation and the PARTIAL aggregation(s) feeding it.
 *
 * <p>The rule only fires when all of the following hold:
 * <ul>
 *   <li>the session enables it ({@link SystemSessionProperties#isEnableIntermediateAggregations});</li>
 *   <li>the matched node is a FINAL aggregation with no grouping keys;</li>
 *   <li>every path below it passes only through {@link ExchangeNode}s and {@link ProjectNode}s
 *       until it reaches a PARTIAL aggregation.</li>
 * </ul>
 *
 * <p>When it fires, each PARTIAL aggregation is wrapped as
 * {@code INTERMEDIATE agg <- LOCAL gathering exchange <- PARTIAL agg}. If the session's task
 * concurrency is greater than 1, the rule also places
 * {@code LOCAL gathering exchange <- INTERMEDIATE agg <- LOCAL partitioned exchange
 * (FIXED_ARBITRARY_DISTRIBUTION)} directly under the FINAL aggregation.
 *
 * <p>NOTE(review): the intent is presumably to combine partial states in parallel within a
 * task before the single FINAL step — confirm against the session property's documentation.
 */
public class AddIntermediateAggregations implements Rule {
    private static final Pattern PATTERN = Pattern.typeOf(AggregationNode.class);

    @Override
    public Pattern getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isEnableIntermediateAggregations(session);
    }

    /**
     * Rewrites an un-grouped FINAL aggregation so that intermediate aggregation steps sit
     * beneath it.
     *
     * @param planNode the matched node; anything other than an {@link AggregationNode} is ignored
     * @param context rule context supplying the lookup, id allocator and session
     * @return the FINAL aggregation with a rewritten source, or empty when the rule does not apply
     */
    @Override
    public Optional<PlanNode> apply(PlanNode planNode, Rule.Context context) {
        if (!(planNode instanceof AggregationNode)) {
            return Optional.empty();
        }
        AggregationNode aggregation = (AggregationNode) planNode;
        if (aggregation.getStep() != AggregationNode.Step.FINAL || !aggregation.getGroupingKeys().isEmpty()) {
            return Optional.empty();
        }

        Lookup lookup = context.getLookup();
        PlanNodeIdAllocator idAllocator = context.getIdAllocator();
        Session session = context.getSession();

        Optional<PlanNode> rewrittenSource = recurseToPartial(lookup.resolve(aggregation.getSource()), lookup, idAllocator);
        if (!rewrittenSource.isPresent()) {
            // Some path below the FINAL step does not lead to a PARTIAL aggregation.
            return Optional.empty();
        }

        PlanNode source = rewrittenSource.get();
        if (SystemSessionProperties.getTaskConcurrency(session) > 1) {
            source = addLocalParallelIntermediate(aggregation, source, idAllocator);
        }
        return Optional.of(aggregation.replaceChildren(ImmutableList.of(source)));
    }

    /**
     * Wraps {@code source} in a LOCAL partitioned exchange (FIXED_ARBITRARY_DISTRIBUTION), then an
     * INTERMEDIATE aggregation, then a LOCAL gathering exchange.
     *
     * @param finalAggregation the FINAL aggregation whose aggregations, grouping sets, hash symbol
     *        and group-id symbol are reused for the INTERMEDIATE step
     * @param source the already-rewritten source of the FINAL aggregation
     * @param idAllocator allocator for the three new plan node ids
     * @return the gathering exchange at the top of the new sub-plan
     */
    private static PlanNode addLocalParallelIntermediate(AggregationNode finalAggregation, PlanNode source, PlanNodeIdAllocator idAllocator) {
        PlanNode partitioned = ExchangeNode.partitionedExchange(
                idAllocator.getNextId(),
                ExchangeNode.Scope.LOCAL,
                source,
                new PartitioningScheme(
                        Partitioning.create(SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()),
                        source.getOutputSymbols()));
        PlanNode intermediate = new AggregationNode(
                idAllocator.getNextId(),
                partitioned,
                // The FINAL step's inputs become this step's outputs, so FINAL keeps consuming the same symbols.
                inputsAsOutputs(finalAggregation.getAggregations()),
                finalAggregation.getGroupingSets(),
                AggregationNode.Step.INTERMEDIATE,
                finalAggregation.getHashSymbol(),
                finalAggregation.getGroupIdSymbol());
        return ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, intermediate);
    }

    /**
     * Walks down through {@link ExchangeNode}s and {@link ProjectNode}s to the PARTIAL
     * aggregation(s). Each PARTIAL aggregation found is replaced by the output of
     * {@link #addGatheringIntermediate}.
     *
     * @return the rewritten subtree, or empty if any path reaches a node that is neither an
     *         exchange, a project, nor a PARTIAL aggregation
     */
    private Optional<PlanNode> recurseToPartial(PlanNode planNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator) {
        if (planNode instanceof AggregationNode && ((AggregationNode) planNode).getStep() == AggregationNode.Step.PARTIAL) {
            return Optional.of(addGatheringIntermediate((AggregationNode) planNode, planNodeIdAllocator));
        }
        if (!(planNode instanceof ExchangeNode) && !(planNode instanceof ProjectNode)) {
            return Optional.empty();
        }

        ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
        for (PlanNode source : planNode.getSources()) {
            Optional<PlanNode> rewritten = recurseToPartial(lookup.resolve(source), lookup, planNodeIdAllocator);
            if (!rewritten.isPresent()) {
                // Every source must lead to a PARTIAL aggregation, otherwise the rule does not apply.
                return Optional.empty();
            }
            rewrittenSources.add(rewritten.get());
        }
        return Optional.of(planNode.replaceChildren(rewrittenSources.build()));
    }

    /**
     * Places a LOCAL gathering exchange above the given PARTIAL aggregation, then an INTERMEDIATE
     * aggregation above that. The INTERMEDIATE step consumes the partial's output symbols.
     *
     * @throws com.google.common.base.VerifyException if the aggregation has grouping keys
     */
    private PlanNode addGatheringIntermediate(AggregationNode aggregationNode, PlanNodeIdAllocator planNodeIdAllocator) {
        Verify.verify(aggregationNode.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation");
        PlanNode gathered = ExchangeNode.gatheringExchange(planNodeIdAllocator.getNextId(), ExchangeNode.Scope.LOCAL, aggregationNode);
        return new AggregationNode(
                planNodeIdAllocator.getNextId(),
                gathered,
                outputsAsInputs(aggregationNode.getAggregations()),
                aggregationNode.getGroupingSets(),
                AggregationNode.Step.INTERMEDIATE,
                aggregationNode.getHashSymbol(),
                aggregationNode.getGroupIdSymbol());
    }

    /**
     * Rewrites each assignment so that its single argument is its own output symbol. Each new call
     * invokes the same function name with the same signature and has no mask.
     */
    private static Map<Symbol, AggregationNode.Aggregation> outputsAsInputs(Map<Symbol, AggregationNode.Aggregation> map) {
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> builder = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : map.entrySet()) {
            Symbol output = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            FunctionCall call = new FunctionCall(
                    QualifiedName.of(aggregation.getSignature().getName()),
                    ImmutableList.of(output.toSymbolReference()));
            builder.put(output, new AggregationNode.Aggregation(call, aggregation.getSignature(), Optional.empty()));
        }
        return builder.build();
    }

    /**
     * Re-keys each assignment by the single symbol its call references; the aggregation itself is
     * kept unchanged.
     *
     * <p>Fails via {@link Iterables#getOnlyElement} unless each call references exactly one symbol.
     * {@link ImmutableMap.Builder#build()} also rejects duplicate keys, so two calls sharing an
     * input symbol would fail as well.
     */
    private static Map<Symbol, AggregationNode.Aggregation> inputsAsOutputs(Map<Symbol, AggregationNode.Aggregation> map) {
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> builder = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : map.entrySet()) {
            Symbol input = Iterables.getOnlyElement(SymbolsExtractor.extractAll(entry.getValue().getCall()));
            builder.put(input, entry.getValue());
        }
        return builder.build();
    }
}
