package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.planner.ExpressionSymbolInliner;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.DistinctOutputQueryUtil;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.NullLiteral;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.class */
/**
 * Pushes an aggregation below a LEFT or RIGHT outer join.
 *
 * <p>The rule fires when all of the following hold:
 * <ul>
 *   <li>the join has no filter;</li>
 *   <li>the aggregation groups on exactly the output columns of the outer side;</li>
 *   <li>the outer side is known to produce distinct rows;</li>
 *   <li>every aggregation reads only symbols produced by the inner side.</li>
 * </ul>
 *
 * <p>The aggregation is re-planned on top of the inner side, grouped by the inner side's
 * equi-join keys, and the outer join is rebuilt on top of it. Outer rows without a match get
 * NULL aggregation outputs from the outer join. The original plan would instead have aggregated
 * one row of NULL inner columns for such rows. To restore that value, the result is cross-joined
 * with the same aggregations evaluated over a single all-NULL row, and each aggregation output
 * becomes {@code COALESCE(pushed, overNullRow)}.
 */
public class PushAggregationThroughOuterJoin implements Rule {
    private static final Pattern PATTERN = Pattern.typeOf(AggregationNode.class);

    /**
     * Pairs the aggregation built over a single all-NULL row with a mapping from each original
     * aggregation output symbol to the corresponding output symbol of that aggregation.
     */
    public static class MappedAggregationInfo {
        private final AggregationNode aggregationNode;
        private final Map<Symbol, Symbol> symbolMapping;

        public MappedAggregationInfo(AggregationNode aggregationNode, Map<Symbol, Symbol> symbolMapping) {
            this.aggregationNode = aggregationNode;
            this.symbolMapping = symbolMapping;
        }

        /** Original aggregation output symbol -> output symbol of the aggregation over the NULL row. */
        public Map<Symbol, Symbol> getSymbolMapping() {
            return this.symbolMapping;
        }

        /** The aggregation whose source is the single all-NULL row. */
        public AggregationNode getAggregation() {
            return this.aggregationNode;
        }
    }

    @Override
    public Pattern getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.shouldPushAggregationThroughJoin(session);
    }

    /**
     * Attempts the rewrite described on the class.
     *
     * @param node candidate plan node; anything other than an {@link AggregationNode} is ignored
     * @param context rule context supplying the lookup and the symbol/id allocators
     * @return the rewritten plan, or {@link Optional#empty()} when the rule does not apply
     */
    @Override
    public Optional<PlanNode> apply(PlanNode node, Rule.Context context) {
        if (!(node instanceof AggregationNode)) {
            return Optional.empty();
        }
        AggregationNode aggregation = (AggregationNode) node;
        Lookup lookup = context.getLookup();

        PlanNode source = lookup.resolve(aggregation.getSource());
        if (!(source instanceof JoinNode)) {
            return Optional.empty();
        }
        JoinNode join = (JoinNode) source;

        // Only plain equi outer joins are handled; checked before getOuterTable/getInnerTable,
        // which require a LEFT or RIGHT join.
        if (join.getFilter().isPresent()
                || !(join.getType() == JoinNode.Type.LEFT || join.getType() == JoinNode.Type.RIGHT)) {
            return Optional.empty();
        }

        PlanNode outerTable = lookup.resolve(getOuterTable(join));
        if (!groupsOnAllOuterTableColumns(aggregation, outerTable)
                // The aggregation is moved onto the inner side, so it must not read outer columns.
                || !isAggregationOnSymbols(aggregation, lookup.resolve(getInnerTable(join)))
                || !DistinctOutputQueryUtil.isDistinct(outerTable, lookup::resolve)) {
            return Optional.empty();
        }

        // Group the pushed-down aggregation by the inner side's join keys.
        boolean rightJoin = join.getType() == JoinNode.Type.RIGHT;
        List<Symbol> groupingKeys = join.getCriteria().stream()
                .map(clause -> rightJoin ? clause.getLeft() : clause.getRight())
                .collect(ImmutableList.toImmutableList());

        AggregationNode rewrittenAggregation = new AggregationNode(
                node.getId(),
                getInnerTable(join),
                aggregation.getAggregations(),
                ImmutableList.of(groupingKeys),
                aggregation.getStep(),
                aggregation.getHashSymbol(),
                aggregation.getGroupIdSymbol());

        // Rebuild the outer join with the aggregation in place of the inner side. The join
        // outputs the outer columns plus the aggregation results (not the inner grouping keys).
        JoinNode rewrittenJoin;
        if (join.getType() == JoinNode.Type.LEFT) {
            rewrittenJoin = new JoinNode(
                    join.getId(),
                    join.getType(),
                    join.getLeft(),
                    rewrittenAggregation,
                    join.getCriteria(),
                    ImmutableList.<Symbol>builder()
                            .addAll(join.getLeft().getOutputSymbols())
                            .addAll(rewrittenAggregation.getAggregations().keySet())
                            .build(),
                    join.getFilter(),
                    join.getLeftHashSymbol(),
                    join.getRightHashSymbol(),
                    join.getDistributionType());
        } else {
            rewrittenJoin = new JoinNode(
                    join.getId(),
                    join.getType(),
                    rewrittenAggregation,
                    join.getRight(),
                    join.getCriteria(),
                    ImmutableList.<Symbol>builder()
                            .addAll(rewrittenAggregation.getAggregations().keySet())
                            .addAll(join.getRight().getOutputSymbols())
                            .build(),
                    join.getFilter(),
                    join.getLeftHashSymbol(),
                    join.getRightHashSymbol(),
                    join.getDistributionType());
        }

        return Optional.of(coalesceWithNullAggregation(
                rewrittenAggregation,
                rewrittenJoin,
                context.getSymbolAllocator(),
                context.getIdAllocator(),
                lookup));
    }

    /** Returns the non-preserved side of a LEFT/RIGHT join (right for LEFT, left for RIGHT). */
    private static PlanNode getInnerTable(JoinNode joinNode) {
        Preconditions.checkState(joinNode.getType() == JoinNode.Type.LEFT || joinNode.getType() == JoinNode.Type.RIGHT, "expected LEFT or RIGHT JOIN");
        return joinNode.getType() == JoinNode.Type.LEFT ? joinNode.getRight() : joinNode.getLeft();
    }

    /** Returns the preserved side of a LEFT/RIGHT join (left for LEFT, right for RIGHT). */
    private static PlanNode getOuterTable(JoinNode joinNode) {
        Preconditions.checkState(joinNode.getType() == JoinNode.Type.LEFT || joinNode.getType() == JoinNode.Type.RIGHT, "expected LEFT or RIGHT JOIN");
        return joinNode.getType() == JoinNode.Type.LEFT ? joinNode.getLeft() : joinNode.getRight();
    }

    /** True when the aggregation's grouping keys are exactly (as a set) the outputs of {@code outerTable}. */
    private static boolean groupsOnAllOuterTableColumns(AggregationNode aggregationNode, PlanNode outerTable) {
        return new HashSet<>(aggregationNode.getGroupingKeys()).equals(new HashSet<>(outerTable.getOutputSymbols()));
    }

    /**
     * Returns whether every aggregation in {@code aggregationNode} can be evaluated using only the
     * outputs of {@code source}.
     *
     * <p>Every call argument must be a reference to one of {@code source}'s output symbols, and
     * every mask, if present, must be one of those symbols. The check is conservative: an argument
     * that is not such a reference (for example a literal) makes it return false.
     */
    private static boolean isAggregationOnSymbols(AggregationNode aggregationNode, PlanNode source) {
        Set<Symbol> sourceSymbols = new HashSet<>(source.getOutputSymbols());
        // Arguments are AST expressions, so compare them against the symbols' reference form.
        Set<Expression> sourceReferences = new HashSet<>();
        for (Symbol symbol : sourceSymbols) {
            sourceReferences.add(symbol.toSymbolReference());
        }
        for (AggregationNode.Aggregation aggregation : aggregationNode.getAggregations().values()) {
            if (!sourceReferences.containsAll(aggregation.getCall().getArguments())) {
                return false;
            }
            if (aggregation.getMask().isPresent() && !sourceSymbols.contains(aggregation.getMask().get())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Cross-joins {@code outerJoin} with {@code aggregationNode}'s aggregations evaluated over a
     * single all-NULL row. It then projects every output of {@code outerJoin}; each aggregation
     * output is replaced by {@code COALESCE(aggregationOutput, aggregationOverNullOutput)}.
     *
     * @param aggregationNode the pushed-down aggregation whose outputs need a NULL-row fallback
     * @param outerJoin the rewritten outer join
     * @return a project node over the cross join, exposing the same symbols as {@code outerJoin}
     */
    private PlanNode coalesceWithNullAggregation(AggregationNode aggregationNode, PlanNode outerJoin, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
        MappedAggregationInfo aggregationOverNullInfo = createAggregationOverNull(aggregationNode, symbolAllocator, idAllocator, lookup);
        AggregationNode aggregationOverNull = aggregationOverNullInfo.getAggregation();
        Map<Symbol, Symbol> overNullMapping = aggregationOverNullInfo.getSymbolMapping();

        // Cross join (INNER with no criteria) against the single-row aggregation over NULLs.
        JoinNode crossJoin = new JoinNode(
                idAllocator.getNextId(),
                JoinNode.Type.INNER,
                outerJoin,
                aggregationOverNull,
                ImmutableList.of(),
                ImmutableList.<Symbol>builder()
                        .addAll(outerJoin.getOutputSymbols())
                        .addAll(aggregationOverNull.getOutputSymbols())
                        .build(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty());

        Assignments.Builder assignments = Assignments.builder();
        for (Symbol symbol : outerJoin.getOutputSymbols()) {
            if (aggregationNode.getAggregations().containsKey(symbol)) {
                assignments.put(symbol, new CoalesceExpression(symbol.toSymbolReference(), overNullMapping.get(symbol).toSymbolReference()));
            } else {
                assignments.put(symbol, symbol.toSymbolReference());
            }
        }
        return new ProjectNode(idAllocator.getNextId(), crossJoin, assignments.build());
    }

    /**
     * Builds a copy of {@code referenceAggregation}'s aggregations over a single row of NULLs.
     *
     * <p>A one-row {@link ValuesNode} is created with one NULL column per output of the reference
     * aggregation's source. Each column uses the source symbol's type. Every aggregation call (and
     * mask) is rewritten to read those NULL columns. The resulting aggregation has a single empty
     * grouping set and step {@code SINGLE}.
     *
     * @return the new aggregation, plus a mapping from each reference aggregation output symbol to
     *     the matching output symbol of the new aggregation
     */
    private MappedAggregationInfo createAggregationOverNull(AggregationNode referenceAggregation, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
        NullLiteral nullLiteral = new NullLiteral();
        ImmutableList.Builder<Symbol> nullSymbols = ImmutableList.builder();
        ImmutableList.Builder<Expression> nullLiterals = ImmutableList.builder();
        ImmutableMap.Builder<Symbol, Expression> sourceToNullBuilder = ImmutableMap.builder();
        for (Symbol sourceSymbol : lookup.resolve(referenceAggregation.getSource()).getOutputSymbols()) {
            nullLiterals.add(nullLiteral);
            Symbol nullSymbol = symbolAllocator.newSymbol(nullLiteral, symbolAllocator.getTypes().get(sourceSymbol));
            nullSymbols.add(nullSymbol);
            sourceToNullBuilder.put(sourceSymbol, nullSymbol.toSymbolReference());
        }
        ValuesNode nullRow = new ValuesNode(idAllocator.getNextId(), nullSymbols.build(), ImmutableList.of(nullLiterals.build()));
        Map<Symbol, Expression> sourceToNull = sourceToNullBuilder.build();

        ImmutableMap.Builder<Symbol, Symbol> aggregationMapping = ImmutableMap.builder();
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsOverNull = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : referenceAggregation.getAggregations().entrySet()) {
            Symbol aggregationSymbol = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            AggregationNode.Aggregation overNullAggregation = new AggregationNode.Aggregation(
                    (FunctionCall) new ExpressionSymbolInliner(sourceToNull).rewrite(aggregation.getCall()),
                    aggregation.getSignature(),
                    aggregation.getMask().map(mask -> Symbol.from(sourceToNull.get(mask))));
            Symbol overNullSymbol = symbolAllocator.newSymbol(overNullAggregation.getCall(), symbolAllocator.getTypes().get(aggregationSymbol));
            aggregationsOverNull.put(overNullSymbol, overNullAggregation);
            aggregationMapping.put(aggregationSymbol, overNullSymbol);
        }

        AggregationNode aggregationOverNullRow = new AggregationNode(
                idAllocator.getNextId(),
                nullRow,
                aggregationsOverNull.build(),
                ImmutableList.of(ImmutableList.of()),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
        return new MappedAggregationInfo(aggregationOverNullRow, aggregationMapping.build());
    }
}
