/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.planner.Partitioning;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolsExtractor;
import com.facebook.presto.sql.planner.SystemPartitioningHandle;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class AddIntermediateAggregations
implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().with(Patterns.Aggregation.step().equalTo((Object)AggregationNode.Step.FINAL)).with(Pattern.empty(Patterns.Aggregation.groupingKeys())).matching(node -> !node.hasOrderings());

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isEnableIntermediateAggregations(session);
    }

    @Override
    public Rule.Result apply(AggregationNode aggregation, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();
        PlanNodeIdAllocator idAllocator = context.getIdAllocator();
        Session session = context.getSession();
        Optional<PlanNode> rewrittenSource = this.recurseToPartial(lookup.resolve(aggregation.getSource()), lookup, idAllocator);
        if (!rewrittenSource.isPresent()) {
            return Rule.Result.empty();
        }
        PlanNode source = rewrittenSource.get();
        if (SystemSessionProperties.getTaskConcurrency(session) > 1) {
            source = ExchangeNode.partitionedExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, source, new PartitioningScheme(Partitioning.create(SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION, (List<Symbol>)ImmutableList.of()), source.getOutputSymbols()));
            source = new AggregationNode(idAllocator.getNextId(), source, AddIntermediateAggregations.inputsAsOutputs(aggregation.getAggregations()), aggregation.getGroupingSets(), AggregationNode.Step.INTERMEDIATE, aggregation.getHashSymbol(), aggregation.getGroupIdSymbol());
            source = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, source);
        }
        return Rule.Result.ofPlanNode(aggregation.replaceChildren((List<PlanNode>)ImmutableList.of((Object)source)));
    }

    private Optional<PlanNode> recurseToPartial(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator) {
        if (node instanceof AggregationNode && ((AggregationNode)node).getStep() == AggregationNode.Step.PARTIAL) {
            return Optional.of(this.addGatheringIntermediate((AggregationNode)node, idAllocator));
        }
        if (!(node instanceof ExchangeNode) && !(node instanceof ProjectNode)) {
            return Optional.empty();
        }
        ImmutableList.Builder builder = ImmutableList.builder();
        for (PlanNode source : node.getSources()) {
            Optional<PlanNode> planNode = this.recurseToPartial(lookup.resolve(source), lookup, idAllocator);
            if (!planNode.isPresent()) {
                return Optional.empty();
            }
            builder.add((Object)planNode.get());
        }
        return Optional.of(node.replaceChildren((List<PlanNode>)builder.build()));
    }

    private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeIdAllocator idAllocator) {
        Verify.verify((boolean)aggregation.getGroupingKeys().isEmpty(), (String)"Should be an un-grouped aggregation", (Object[])new Object[0]);
        ExchangeNode gatheringExchange = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, aggregation);
        return new AggregationNode(idAllocator.getNextId(), gatheringExchange, AddIntermediateAggregations.outputsAsInputs(aggregation.getAggregations()), aggregation.getGroupingSets(), AggregationNode.Step.INTERMEDIATE, aggregation.getHashSymbol(), aggregation.getGroupIdSymbol());
    }

    private static Map<Symbol, AggregationNode.Aggregation> outputsAsInputs(Map<Symbol, AggregationNode.Aggregation> assignments) {
        ImmutableMap.Builder builder = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : assignments.entrySet()) {
            Symbol output = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            Preconditions.checkState((!aggregation.getCall().getOrderBy().isPresent() ? 1 : 0) != 0, (Object)"Intermediate aggregation does not support ORDER BY");
            builder.put((Object)output, (Object)new AggregationNode.Aggregation(new FunctionCall(QualifiedName.of((String)aggregation.getSignature().getName()), (List)ImmutableList.of((Object)output.toSymbolReference())), aggregation.getSignature(), Optional.empty()));
        }
        return builder.build();
    }

    private static Map<Symbol, AggregationNode.Aggregation> inputsAsOutputs(Map<Symbol, AggregationNode.Aggregation> assignments) {
        ImmutableMap.Builder builder = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : assignments.entrySet()) {
            Symbol input = (Symbol)Iterables.getOnlyElement(SymbolsExtractor.extractAll((Expression)entry.getValue().getCall()));
            builder.put((Object)input, (Object)entry.getValue());
        }
        return builder.build();
    }
}

