/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.planner.ExpressionSymbolInliner;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.SymbolsExtractor;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.DistinctOutputQueryUtil;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.NullLiteral;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

public class PushAggregationThroughOuterJoin
implements Rule<AggregationNode> {
    private static final Capture<JoinNode> JOIN = Capture.newCapture();
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().with(Patterns.source().matching(Patterns.join().capturedAs(JOIN)));

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.shouldPushAggregationThroughJoin(session);
    }

    @Override
    public Rule.Result apply(AggregationNode aggregation, Captures captures, Rule.Context context) {
        JoinNode join;
        block3: {
            block2: {
                join = (JoinNode)captures.get(JOIN);
                if (join.getFilter().isPresent() || join.getType() != JoinNode.Type.LEFT && join.getType() != JoinNode.Type.RIGHT || !PushAggregationThroughOuterJoin.groupsOnAllOuterTableColumns(aggregation, context.getLookup().resolve(PushAggregationThroughOuterJoin.getOuterTable(join)))) break block2;
                if (DistinctOutputQueryUtil.isDistinct(context.getLookup().resolve(PushAggregationThroughOuterJoin.getOuterTable(join)), context.getLookup()::resolve)) break block3;
            }
            return Rule.Result.empty();
        }
        List groupingKeys = (List)join.getCriteria().stream().map(join.getType() == JoinNode.Type.RIGHT ? JoinNode.EquiJoinClause::getLeft : JoinNode.EquiJoinClause::getRight).collect(ImmutableList.toImmutableList());
        AggregationNode rewrittenAggregation = new AggregationNode(aggregation.getId(), PushAggregationThroughOuterJoin.getInnerTable(join), aggregation.getAggregations(), (List<List<Symbol>>)ImmutableList.of((Object)groupingKeys), aggregation.getStep(), aggregation.getHashSymbol(), aggregation.getGroupIdSymbol());
        JoinNode rewrittenJoin = join.getType() == JoinNode.Type.LEFT ? new JoinNode(join.getId(), join.getType(), join.getLeft(), rewrittenAggregation, join.getCriteria(), (List<Symbol>)ImmutableList.builder().addAll(join.getLeft().getOutputSymbols()).addAll(rewrittenAggregation.getAggregations().keySet()).build(), join.getFilter(), join.getLeftHashSymbol(), join.getRightHashSymbol(), join.getDistributionType()) : new JoinNode(join.getId(), join.getType(), rewrittenAggregation, join.getRight(), join.getCriteria(), (List<Symbol>)ImmutableList.builder().addAll(rewrittenAggregation.getAggregations().keySet()).addAll(join.getRight().getOutputSymbols()).build(), join.getFilter(), join.getLeftHashSymbol(), join.getRightHashSymbol(), join.getDistributionType());
        return Rule.Result.ofPlanNode(this.coalesceWithNullAggregation(rewrittenAggregation, rewrittenJoin, context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup()));
    }

    private static PlanNode getInnerTable(JoinNode join) {
        Preconditions.checkState((join.getType() == JoinNode.Type.LEFT || join.getType() == JoinNode.Type.RIGHT ? 1 : 0) != 0, (Object)"expected LEFT or RIGHT JOIN");
        PlanNode innerNode = join.getType().equals((Object)JoinNode.Type.LEFT) ? join.getRight() : join.getLeft();
        return innerNode;
    }

    private static PlanNode getOuterTable(JoinNode join) {
        Preconditions.checkState((join.getType() == JoinNode.Type.LEFT || join.getType() == JoinNode.Type.RIGHT ? 1 : 0) != 0, (Object)"expected LEFT or RIGHT JOIN");
        PlanNode outerNode = join.getType().equals((Object)JoinNode.Type.LEFT) ? join.getLeft() : join.getRight();
        return outerNode;
    }

    private static boolean groupsOnAllOuterTableColumns(AggregationNode node, PlanNode outerTable) {
        return new HashSet<Symbol>(node.getGroupingKeys()).equals(new HashSet<Symbol>(outerTable.getOutputSymbols()));
    }

    private PlanNode coalesceWithNullAggregation(AggregationNode aggregationNode, PlanNode outerJoin, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
        MappedAggregationInfo aggregationOverNullInfo = this.createAggregationOverNull(aggregationNode, symbolAllocator, idAllocator, lookup);
        AggregationNode aggregationOverNull = aggregationOverNullInfo.getAggregation();
        Map<Symbol, Symbol> sourceAggregationToOverNullMapping = aggregationOverNullInfo.getSymbolMapping();
        JoinNode crossJoin = new JoinNode(idAllocator.getNextId(), JoinNode.Type.INNER, outerJoin, aggregationOverNull, (List<JoinNode.EquiJoinClause>)ImmutableList.of(), (List<Symbol>)ImmutableList.builder().addAll(outerJoin.getOutputSymbols()).addAll(aggregationOverNull.getOutputSymbols()).build(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
        Assignments.Builder assignmentsBuilder = Assignments.builder();
        for (Symbol symbol : outerJoin.getOutputSymbols()) {
            if (aggregationNode.getAggregations().containsKey(symbol)) {
                assignmentsBuilder.put(symbol, (Expression)new CoalesceExpression((Expression)symbol.toSymbolReference(), (Expression)sourceAggregationToOverNullMapping.get(symbol).toSymbolReference(), new Expression[0]));
                continue;
            }
            assignmentsBuilder.put(symbol, (Expression)symbol.toSymbolReference());
        }
        return new ProjectNode(idAllocator.getNextId(), crossJoin, assignmentsBuilder.build());
    }

    private MappedAggregationInfo createAggregationOverNull(AggregationNode referenceAggregation, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
        NullLiteral nullLiteral = new NullLiteral();
        ImmutableList.Builder nullSymbols = ImmutableList.builder();
        ImmutableList.Builder nullLiterals = ImmutableList.builder();
        ImmutableMap.Builder sourcesSymbolMappingBuilder = ImmutableMap.builder();
        for (Symbol sourceSymbol : lookup.resolve(referenceAggregation.getSource()).getOutputSymbols()) {
            nullLiterals.add((Object)nullLiteral);
            Symbol nullSymbol = symbolAllocator.newSymbol((Expression)nullLiteral, symbolAllocator.getTypes().get(sourceSymbol));
            nullSymbols.add((Object)nullSymbol);
            sourcesSymbolMappingBuilder.put((Object)sourceSymbol, (Object)nullSymbol.toSymbolReference());
        }
        ValuesNode nullRow = new ValuesNode(idAllocator.getNextId(), (List<Symbol>)nullSymbols.build(), (List<List<Expression>>)ImmutableList.of((Object)nullLiterals.build()));
        ImmutableMap sourcesSymbolMapping = sourcesSymbolMappingBuilder.build();
        ImmutableMap.Builder aggregationsSymbolMappingBuilder = ImmutableMap.builder();
        ImmutableMap.Builder aggregationsOverNullBuilder = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : referenceAggregation.getAggregations().entrySet()) {
            Symbol aggregationSymbol = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            AggregationNode.Aggregation overNullAggregation = new AggregationNode.Aggregation((FunctionCall)ExpressionSymbolInliner.inlineSymbols((Map<Symbol, ? extends Expression>)sourcesSymbolMapping, (Expression)aggregation.getCall()), aggregation.getSignature(), aggregation.getMask().map(arg_0 -> PushAggregationThroughOuterJoin.lambda$createAggregationOverNull$0((Map)sourcesSymbolMapping, arg_0)));
            Symbol overNullSymbol = symbolAllocator.newSymbol((Expression)overNullAggregation.getCall(), symbolAllocator.getTypes().get(aggregationSymbol));
            aggregationsOverNullBuilder.put((Object)overNullSymbol, (Object)overNullAggregation);
            aggregationsSymbolMappingBuilder.put((Object)aggregationSymbol, (Object)overNullSymbol);
        }
        ImmutableMap aggregationsSymbolMapping = aggregationsSymbolMappingBuilder.build();
        AggregationNode aggregationOverNullRow = new AggregationNode(idAllocator.getNextId(), nullRow, (Map<Symbol, AggregationNode.Aggregation>)aggregationsOverNullBuilder.build(), (List<List<Symbol>>)ImmutableList.of((Object)ImmutableList.of()), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty());
        return new MappedAggregationInfo(aggregationOverNullRow, (Map<Symbol, Symbol>)aggregationsSymbolMapping);
    }

    private static /* synthetic */ Symbol lambda$createAggregationOverNull$0(Map sourcesSymbolMapping, Symbol x) {
        return Symbol.from((Expression)sourcesSymbolMapping.get(x));
    }

    private static class MappedAggregationInfo {
        private final AggregationNode aggregationNode;
        private final Map<Symbol, Symbol> symbolMapping;

        public MappedAggregationInfo(AggregationNode aggregationNode, Map<Symbol, Symbol> symbolMapping) {
            this.aggregationNode = aggregationNode;
            this.symbolMapping = symbolMapping;
        }

        public Map<Symbol, Symbol> getSymbolMapping() {
            return this.symbolMapping;
        }

        public AggregationNode getAggregation() {
            return this.aggregationNode;
        }
    }
}

