/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.Util;
import com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Iterative optimizer rule that moves a {@code PARTIAL} aggregation below an {@code INNER} join
 * when every variable read by the aggregation functions is produced by a single side of the join.
 *
 * <p>The pushed-down aggregation groups by the original grouping keys that are available on that
 * side, plus every variable from that side that the join itself needs (equi-join criteria, join
 * filter and hash variables). The rebuilt join's outputs are then restricted to the original
 * aggregation's output variables.
 *
 * <p>The rule only runs when {@link SystemSessionProperties#isPushAggregationThroughJoin(Session)}
 * returns {@code true}.
 */
public class PushPartialAggregationThroughJoin
        implements Rule<AggregationNode> {
    private static final Capture<JoinNode> JOIN_NODE = Capture.newCapture();

    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .matching(PushPartialAggregationThroughJoin::isSupportedAggregationNode)
            .with(Patterns.source().matching(Patterns.join().capturedAs(JOIN_NODE)));

    /**
     * Returns whether the aggregation is a candidate for this rule: a single-grouping-set
     * {@code PARTIAL} aggregation that is neither streamable nor eligible for segmented
     * aggregation, and that has no hash variable.
     */
    private static boolean isSupportedAggregationNode(AggregationNode aggregationNode) {
        if (aggregationNode.isStreamable() || aggregationNode.isSegmentedAggregationEligible()) {
            return false;
        }
        // Aggregations that already carry a hash variable are not handled by this rule.
        if (aggregationNode.getHashVariable().isPresent()) {
            return false;
        }
        return aggregationNode.getStep() == AggregationNode.Step.PARTIAL && aggregationNode.getGroupingSetCount() == 1;
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isPushAggregationThroughJoin(session);
    }

    /**
     * Pushes the partial aggregation into the left child if all aggregation inputs come from the
     * left side; otherwise into the right child if they all come from the right side. Left is
     * tried first. Non-inner joins are left untouched.
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        JoinNode joinNode = captures.get(JOIN_NODE);
        if (joinNode.getType() != JoinType.INNER) {
            return Rule.Result.empty();
        }
        TypeProvider types = TypeProvider.viewOf(context.getVariableAllocator().getVariables());
        if (allAggregationsOn(aggregationNode.getAggregations(), joinNode.getLeft().getOutputVariables(), types)) {
            return Rule.Result.ofPlanNode(pushPartialToLeftChild(aggregationNode, joinNode, context));
        }
        if (allAggregationsOn(aggregationNode.getAggregations(), joinNode.getRight().getOutputVariables(), types)) {
            return Rule.Result.ofPlanNode(pushPartialToRightChild(aggregationNode, joinNode, context));
        }
        return Rule.Result.empty();
    }

    /**
     * Returns {@code true} when every unique variable referenced by the given aggregations is
     * contained in {@code variables}. Trivially {@code true} when there are no aggregations.
     */
    private boolean allAggregationsOn(Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations, List<VariableReferenceExpression> variables, TypeProvider types) {
        // Set lookup instead of List.containsAll, which is O(|inputs| * |variables|).
        Set<VariableReferenceExpression> available = ImmutableSet.copyOf(variables);
        return aggregations.values().stream()
                .map(aggregation -> AggregationNodeUtils.extractAggregationUniqueVariables(aggregation, types))
                .flatMap(Collection::stream)
                .allMatch(available::contains);
    }

    /** Places the partial aggregation on top of the join's left child and rebuilds the join. */
    private PlanNode pushPartialToLeftChild(AggregationNode node, JoinNode child, Rule.Context context) {
        AggregationNode pushedAggregation = aggregateJoinSide(node, child, child.getLeft());
        return pushPartialToJoin(node, child, pushedAggregation, child.getRight(), context);
    }

    /** Places the partial aggregation on top of the join's right child and rebuilds the join. */
    private PlanNode pushPartialToRightChild(AggregationNode node, JoinNode child, Rule.Context context) {
        AggregationNode pushedAggregation = aggregateJoinSide(node, child, child.getRight());
        return pushPartialToJoin(node, child, child.getLeft(), pushedAggregation, context);
    }

    /**
     * Builds the copy of {@code node} that aggregates directly over {@code side} (one of
     * {@code join}'s children), grouping by the keys and join-required variables available there.
     */
    private AggregationNode aggregateJoinSide(AggregationNode node, JoinNode join, PlanNode side) {
        Set<VariableReferenceExpression> sideVariables = ImmutableSet.copyOf(side.getOutputVariables());
        Set<VariableReferenceExpression> requiredJoinVariables = Sets.intersection(getJoinRequiredVariables(join), sideVariables);
        List<VariableReferenceExpression> groupingSet = getPushedDownGroupingSet(node, sideVariables, requiredJoinVariables);
        return replaceAggregationSource(node, side, groupingSet);
    }

    /**
     * Collects every variable the join references: both sides of each equi-join clause, the
     * variables of the optional join filter, and the optional left/right hash variables.
     */
    private Set<VariableReferenceExpression> getJoinRequiredVariables(JoinNode node) {
        return Streams.concat(
                        node.getCriteria().stream().map(EquiJoinClause::getLeft),
                        node.getCriteria().stream().map(EquiJoinClause::getRight),
                        node.getFilter().map(expression -> VariablesExtractor.extractUnique(expression)).orElse(ImmutableSet.of()).stream(),
                        node.getLeftHashVariable().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream(),
                        node.getRightHashVariable().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream())
                .collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Computes the grouping keys for the pushed-down aggregation: the original grouping keys that
     * are in {@code availableVariables} (original order preserved), followed by any
     * {@code requiredJoinVariables} not already present.
     */
    private List<VariableReferenceExpression> getPushedDownGroupingSet(AggregationNode aggregation, Set<VariableReferenceExpression> availableVariables, Set<VariableReferenceExpression> requiredJoinVariables) {
        List<VariableReferenceExpression> retainedKeys = aggregation.getGroupingKeys().stream()
                .filter(availableVariables::contains)
                .collect(Collectors.toList());

        // The rebuilt join consumes the pushed aggregation's output, so the variables the join
        // reads from this side must survive the aggregation as additional grouping keys.
        Set<VariableReferenceExpression> existingVariables = new HashSet<>(retainedKeys);
        ImmutableList.Builder<VariableReferenceExpression> pushedDownGroupingSet = ImmutableList.<VariableReferenceExpression>builder()
                .addAll(retainedKeys);
        for (VariableReferenceExpression variable : requiredJoinVariables) {
            if (existingVariables.add(variable)) {
                pushedDownGroupingSet.add(variable);
            }
        }
        return pushedDownGroupingSet.build();
    }

    /**
     * Copies {@code aggregation} (same id, aggregations, step, hash/group-id variables and
     * aggregation id) onto a new source with a single grouping set of {@code groupingKeys} and
     * no pre-grouped variables.
     */
    private AggregationNode replaceAggregationSource(AggregationNode aggregation, PlanNode source, List<VariableReferenceExpression> groupingKeys) {
        return new AggregationNode(
                aggregation.getSourceLocation(),
                aggregation.getId(),
                source,
                aggregation.getAggregations(),
                AggregationNode.singleGroupingSet(groupingKeys),
                ImmutableList.of(),
                aggregation.getStep(),
                aggregation.getHashVariable(),
                aggregation.getGroupIdVariable(),
                aggregation.getAggregationId());
    }

    /**
     * Rebuilds {@code child} over the given (possibly aggregated) inputs, outputting all of both
     * inputs' variables, then restricts the result to the aggregation's output variables via
     * {@code Util.restrictOutputs}; if that returns empty, the rebuilt join is returned as is.
     */
    private PlanNode pushPartialToJoin(AggregationNode aggregation, JoinNode child, PlanNode leftChild, PlanNode rightChild, Rule.Context context) {
        List<VariableReferenceExpression> joinOutputs = ImmutableList.<VariableReferenceExpression>builder()
                .addAll(leftChild.getOutputVariables())
                .addAll(rightChild.getOutputVariables())
                .build();
        JoinNode joinNode = new JoinNode(
                child.getSourceLocation(),
                child.getId(),
                child.getType(),
                leftChild,
                rightChild,
                child.getCriteria(),
                joinOutputs,
                child.getFilter(),
                child.getLeftHashVariable(),
                child.getRightHashVariable(),
                child.getDistributionType(),
                child.getDynamicFilters());
        Set<VariableReferenceExpression> aggregationOutputs = ImmutableSet.copyOf(aggregation.getOutputVariables());
        return Util.restrictOutputs(context.getIdAllocator(), joinNode, aggregationOutputs).orElse(joinNode);
    }
}

