/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolsExtractor;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.Util;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Optimizer rule that moves a partial aggregation from above an inner join
 * onto one of the join's inputs.
 *
 * <p>The rule fires only when the aggregation step is {@code PARTIAL}, there is
 * exactly one grouping set, the aggregation carries no hash symbol, and the
 * resolved source is an {@code INNER} {@link JoinNode}. The aggregation is
 * pushed to the left input when every symbol referenced by the aggregation
 * calls is produced by the left input; otherwise the right input is tried
 * under the same condition. If neither side qualifies the plan is left as is.
 */
public class PushPartialAggregationThroughJoin
        implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation();

    /**
     * @return the pattern matching any aggregation node; the finer-grained
     *         applicability checks are done in {@link #apply}
     */
    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * @param session the session whose system properties are consulted
     * @return the value of the "push aggregation through join" session property
     */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isPushAggregationThroughJoin(session);
    }

    /**
     * Attempts to push the given aggregation below its join source.
     *
     * @param aggregationNode the matched aggregation
     * @param captures pattern captures (unused by this rule)
     * @param context rule context providing the lookup and id allocator
     * @return the rewritten plan, or an empty result when the rule does not apply
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        if (aggregationNode.getStep() != AggregationNode.Step.PARTIAL || aggregationNode.getGroupingSets().size() != 1) {
            return Rule.Result.empty();
        }
        // Aggregations that carry a hash symbol are not handled by this rule.
        if (aggregationNode.getHashSymbol().isPresent()) {
            return Rule.Result.empty();
        }

        PlanNode childNode = context.getLookup().resolve(aggregationNode.getSource());
        if (!(childNode instanceof JoinNode)) {
            return Rule.Result.empty();
        }
        JoinNode joinNode = (JoinNode) childNode;
        if (joinNode.getType() != JoinNode.Type.INNER) {
            return Rule.Result.empty();
        }

        // The left side is preferred when both sides could host the aggregation.
        if (allAggregationsOn(aggregationNode.getAggregations(), joinNode.getLeft().getOutputSymbols())) {
            return Rule.Result.ofPlanNode(pushPartialToLeftChild(aggregationNode, joinNode, context));
        }
        if (allAggregationsOn(aggregationNode.getAggregations(), joinNode.getRight().getOutputSymbols())) {
            return Rule.Result.ofPlanNode(pushPartialToRightChild(aggregationNode, joinNode, context));
        }
        return Rule.Result.empty();
    }

    /**
     * Checks whether every symbol referenced by the aggregation calls is
     * contained in {@code symbols}.
     *
     * @param aggregations the aggregations whose calls are inspected
     * @param symbols the candidate symbols (output symbols of one join input)
     * @return {@code true} if all referenced symbols are in {@code symbols}
     */
    private boolean allAggregationsOn(Map<Symbol, AggregationNode.Aggregation> aggregations, List<Symbol> symbols) {
        Set<Symbol> inputs = SymbolsExtractor.extractUnique(
                aggregations.values().stream()
                        .map(AggregationNode.Aggregation::getCall)
                        .collect(ImmutableList.toImmutableList()));
        return symbols.containsAll(inputs);
    }

    /**
     * Rebuilds the join with the aggregation placed on top of its left input.
     */
    private PlanNode pushPartialToLeftChild(AggregationNode node, JoinNode child, Rule.Context context) {
        Set<Symbol> joinLeftChildSymbols = ImmutableSet.copyOf(child.getLeft().getOutputSymbols());
        List<Symbol> groupingSet = getPushedDownGroupingSet(
                node,
                joinLeftChildSymbols,
                Sets.intersection(getJoinRequiredSymbols(child), joinLeftChildSymbols));
        AggregationNode pushedAggregation = replaceAggregationSource(node, child.getLeft(), groupingSet);
        return pushPartialToJoin(node, child, pushedAggregation, child.getRight(), context);
    }

    /**
     * Rebuilds the join with the aggregation placed on top of its right input.
     */
    private PlanNode pushPartialToRightChild(AggregationNode node, JoinNode child, Rule.Context context) {
        Set<Symbol> joinRightChildSymbols = ImmutableSet.copyOf(child.getRight().getOutputSymbols());
        List<Symbol> groupingSet = getPushedDownGroupingSet(
                node,
                joinRightChildSymbols,
                Sets.intersection(getJoinRequiredSymbols(child), joinRightChildSymbols));
        AggregationNode pushedAggregation = replaceAggregationSource(node, child.getRight(), groupingSet);
        return pushPartialToJoin(node, child, child.getLeft(), pushedAggregation, context);
    }

    /**
     * Collects the symbols the join itself reads: both sides of every equi-join
     * clause, the symbols extracted from the optional filter, and the optional
     * left/right hash symbols. Encounter order is: all clause left symbols, all
     * clause right symbols, filter symbols, left hash, right hash.
     */
    private Set<Symbol> getJoinRequiredSymbols(JoinNode node) {
        return Streams.concat(
                node.getCriteria().stream().map(JoinNode.EquiJoinClause::getLeft),
                node.getCriteria().stream().map(JoinNode.EquiJoinClause::getRight),
                node.getFilter().map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of()).stream(),
                node.getLeftHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream(),
                node.getRightHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream())
                .collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Computes the grouping set for the pushed-down aggregation.
     *
     * @param aggregation the aggregation being pushed; must have exactly one grouping set
     * @param availableSymbols output symbols of the join input receiving the aggregation
     * @param requiredJoinSymbols symbols of that input which the join still needs
     * @return the original grouping symbols that are available on the input, in
     *         their original order, followed by any required join symbols not
     *         already present
     */
    private List<Symbol> getPushedDownGroupingSet(AggregationNode aggregation, Set<Symbol> availableSymbols, Set<Symbol> requiredJoinSymbols) {
        List<Symbol> groupingSet = Iterables.getOnlyElement(aggregation.getGroupingSets());

        // Keep only the grouping symbols that the chosen join input produces.
        List<Symbol> pushedDownGroupingSet = groupingSet.stream()
                .filter(availableSymbols::contains)
                .collect(Collectors.toList());

        // Append the join-required symbols that are not grouped on yet, so they
        // remain available to the join above the pushed aggregation.
        Set<Symbol> existingSymbols = new HashSet<>(pushedDownGroupingSet);
        for (Symbol required : requiredJoinSymbols) {
            if (existingSymbols.add(required)) {
                pushedDownGroupingSet.add(required);
            }
        }
        return pushedDownGroupingSet;
    }

    /**
     * Creates a copy of {@code aggregation} (same id, aggregations, step, hash
     * symbol and group-id symbol) reading from {@code source} and grouping on
     * the single given grouping set.
     */
    private AggregationNode replaceAggregationSource(AggregationNode aggregation, PlanNode source, List<Symbol> groupingSet) {
        return new AggregationNode(
                aggregation.getId(),
                source,
                aggregation.getAggregations(),
                ImmutableList.of(groupingSet),
                aggregation.getStep(),
                aggregation.getHashSymbol(),
                aggregation.getGroupIdSymbol());
    }

    /**
     * Builds the replacement join over the given children, keeping the original
     * join's id, type, criteria, filter, hash symbols and distribution type.
     * The new join outputs all symbols of both children; the result is then
     * passed through {@code Util.restrictOutputs} with the original
     * aggregation's output symbols, falling back to the bare join when that
     * helper returns an empty result.
     */
    private PlanNode pushPartialToJoin(AggregationNode aggregation, JoinNode child, PlanNode leftChild, PlanNode rightChild, Rule.Context context) {
        List<Symbol> joinOutputs = ImmutableList.<Symbol>builder()
                .addAll(leftChild.getOutputSymbols())
                .addAll(rightChild.getOutputSymbols())
                .build();
        JoinNode joinNode = new JoinNode(
                child.getId(),
                child.getType(),
                leftChild,
                rightChild,
                child.getCriteria(),
                joinOutputs,
                child.getFilter(),
                child.getLeftHashSymbol(),
                child.getRightHashSymbol(),
                child.getDistributionType());
        return Util.restrictOutputs(context.getIdAllocator(), joinNode, ImmutableSet.copyOf(aggregation.getOutputSymbols())).orElse(joinNode);
    }
}

