/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.Ordering;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.RowExpressionVariableInliner;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.DistinctOutputQueryUtil;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.Expressions;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Optimizer rule that pushes an aggregation sitting directly on top of a LEFT or RIGHT outer join
 * down onto the join's inner side.
 * <p>
 * The rule fires only when all of the following hold:
 * <ul>
 *   <li>the join has no filter;</li>
 *   <li>the join type is LEFT or RIGHT;</li>
 *   <li>the aggregation's grouping keys are exactly the output variables of the outer side;</li>
 *   <li>the outer side is known to produce distinct rows.</li>
 * </ul>
 * The aggregation is re-planned over the inner side, grouped by the inner side's equi-join keys,
 * and the join is rebuilt on top of it. Outer rows with no inner match would then see NULL for the
 * aggregation outputs, so the rewritten join is cross-joined with a single-row aggregation evaluated
 * over a row of NULL literals, and each aggregation output is projected as
 * {@code COALESCE(pushed_down_result, result_over_nulls)}.
 * <p>
 * The rewrite is abandoned when some aggregation does not reference any output variable of the
 * inner side (see {@code isUsingVariables}).
 */
public class PushAggregationThroughOuterJoin implements Rule<AggregationNode> {
    private static final Capture<JoinNode> JOIN = Capture.newCapture();
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .with(Patterns.source().matching(Patterns.join().capturedAs(JOIN)));

    private final FunctionAndTypeManager functionAndTypeManager;

    /**
     * @param functionAndTypeManager used to resolve each aggregation function's name, which seeds the
     *        names of the variables produced by the aggregation over NULLs; must not be null
     */
    public PushAggregationThroughOuterJoin(FunctionAndTypeManager functionAndTypeManager) {
        this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /** Enabled according to {@code SystemSessionProperties.shouldPushAggregationThroughJoin}. */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.shouldPushAggregationThroughJoin(session);
    }

    @Override
    public Rule.Result apply(AggregationNode aggregation, Captures captures, Rule.Context context) {
        JoinNode join = captures.get(JOIN);

        // Cheap structural checks first; getOuterTable() requires a LEFT/RIGHT join, so it must not be
        // reached for any other join type.
        if (join.getFilter().isPresent() || !isLeftOrRightJoin(join)) {
            return Rule.Result.empty();
        }
        PlanNode outerTable = getOuterTable(join);
        if (!groupsOnAllColumns(aggregation, outerTable.getOutputVariables())) {
            return Rule.Result.empty();
        }
        if (!DistinctOutputQueryUtil.isDistinct(context.getLookup().resolve(outerTable), context.getLookup()::resolve)) {
            return Rule.Result.empty();
        }

        // The pushed-down aggregation groups on the inner side's equi-join keys:
        // the left key for a RIGHT join, the right key for a LEFT join.
        boolean innerIsLeft = join.getType() == JoinNode.Type.RIGHT;
        List<VariableReferenceExpression> groupingKeys = join.getCriteria().stream()
                .map(clause -> innerIsLeft ? clause.getLeft() : clause.getRight())
                .collect(ImmutableList.toImmutableList());

        AggregationNode rewrittenAggregation = new AggregationNode(
                aggregation.getId(),
                getInnerTable(join),
                aggregation.getAggregations(),
                AggregationNode.singleGroupingSet(groupingKeys),
                ImmutableList.of(),
                aggregation.getStep(),
                aggregation.getHashVariable(),
                aggregation.getGroupIdVariable());
        JoinNode rewrittenJoin = replaceInnerTable(join, rewrittenAggregation);

        Optional<PlanNode> resultNode = coalesceWithNullAggregation(
                rewrittenAggregation,
                rewrittenJoin,
                context.getVariableAllocator(),
                context.getIdAllocator());
        if (!resultNode.isPresent()) {
            return Rule.Result.empty();
        }
        return Rule.Result.ofPlanNode(resultNode.get());
    }

    /** Returns true when the join is a LEFT or RIGHT outer join. */
    private static boolean isLeftOrRightJoin(JoinNode join) {
        return join.getType() == JoinNode.Type.LEFT || join.getType() == JoinNode.Type.RIGHT;
    }

    /**
     * Returns the side of an outer join whose rows may be absent (right for LEFT, left for RIGHT).
     *
     * @throws IllegalStateException if the join is neither LEFT nor RIGHT
     */
    private static PlanNode getInnerTable(JoinNode join) {
        Preconditions.checkState(isLeftOrRightJoin(join), "expected LEFT or RIGHT JOIN");
        return join.getType() == JoinNode.Type.LEFT ? join.getRight() : join.getLeft();
    }

    /**
     * Returns the side of an outer join whose rows are all preserved (left for LEFT, right for RIGHT).
     *
     * @throws IllegalStateException if the join is neither LEFT nor RIGHT
     */
    private static PlanNode getOuterTable(JoinNode join) {
        Preconditions.checkState(isLeftOrRightJoin(join), "expected LEFT or RIGHT JOIN");
        return join.getType() == JoinNode.Type.LEFT ? join.getLeft() : join.getRight();
    }

    /**
     * Rebuilds {@code join} with its inner side replaced by {@code rewrittenAggregation}. Outputs are
     * the outer side's variables plus the aggregation outputs, in left-to-right order. All other join
     * properties (criteria, filter, hash variables, distribution, dynamic filters) are copied unchanged.
     * Assumes the join is LEFT or RIGHT; any non-LEFT join is treated as RIGHT.
     */
    private static JoinNode replaceInnerTable(JoinNode join, AggregationNode rewrittenAggregation) {
        ImmutableList.Builder<VariableReferenceExpression> outputVariables = ImmutableList.builder();
        PlanNode left;
        PlanNode right;
        if (join.getType() == JoinNode.Type.LEFT) {
            left = join.getLeft();
            right = rewrittenAggregation;
            outputVariables.addAll(join.getLeft().getOutputVariables());
            outputVariables.addAll(rewrittenAggregation.getAggregations().keySet());
        } else {
            left = rewrittenAggregation;
            right = join.getRight();
            outputVariables.addAll(rewrittenAggregation.getAggregations().keySet());
            outputVariables.addAll(join.getRight().getOutputVariables());
        }
        return new JoinNode(
                join.getId(),
                join.getType(),
                left,
                right,
                join.getCriteria(),
                outputVariables.build(),
                join.getFilter(),
                join.getLeftHashVariable(),
                join.getRightHashVariable(),
                join.getDistributionType(),
                join.getDynamicFilters());
    }

    /** Returns true when the aggregation's grouping keys are exactly {@code columns}, ignoring order. */
    private static boolean groupsOnAllColumns(AggregationNode node, List<VariableReferenceExpression> columns) {
        Set<VariableReferenceExpression> groupingKeys = new HashSet<>(node.getGroupingKeys());
        return groupingKeys.equals(new HashSet<>(columns));
    }

    /**
     * Cross-joins {@code outerJoin} with a single-row aggregation over NULLs and projects every output
     * of {@code aggregationNode} as {@code COALESCE(output, output_over_nulls)}. All other outputs of
     * {@code outerJoin} pass through unchanged.
     *
     * @return the projection, or empty when the aggregation over NULLs cannot be built
     */
    private Optional<PlanNode> coalesceWithNullAggregation(
            AggregationNode aggregationNode,
            PlanNode outerJoin,
            PlanVariableAllocator variableAllocator,
            PlanNodeIdAllocator idAllocator) {
        Optional<MappedAggregationInfo> aggregationOverNullInfo = createAggregationOverNull(aggregationNode, variableAllocator, idAllocator);
        if (!aggregationOverNullInfo.isPresent()) {
            return Optional.empty();
        }
        AggregationNode aggregationOverNull = aggregationOverNullInfo.get().getAggregation();
        Map<VariableReferenceExpression, VariableReferenceExpression> sourceAggregationToOverNullMapping =
                aggregationOverNullInfo.get().getVariableMapping();

        // INNER join with no criteria and no filter: a cross join against the single NULL-aggregation row.
        JoinNode crossJoin = new JoinNode(
                idAllocator.getNextId(),
                JoinNode.Type.INNER,
                outerJoin,
                aggregationOverNull,
                ImmutableList.of(),
                ImmutableList.<VariableReferenceExpression>builder()
                        .addAll(outerJoin.getOutputVariables())
                        .addAll(aggregationOverNull.getOutputVariables())
                        .build(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of());

        Set<VariableReferenceExpression> aggregationOutputs = aggregationNode.getAggregations().keySet();
        Assignments.Builder assignmentsBuilder = Assignments.builder();
        for (VariableReferenceExpression variable : outerJoin.getOutputVariables()) {
            if (aggregationOutputs.contains(variable)) {
                RowExpression overNull = sourceAggregationToOverNullMapping.get(variable);
                assignmentsBuilder.put(variable, coalesce(ImmutableList.<RowExpression>of(variable, overNull)));
            } else {
                assignmentsBuilder.put(variable, variable);
            }
        }
        return Optional.of(new ProjectNode(idAllocator.getNextId(), crossJoin, assignmentsBuilder.build()));
    }

    /** Builds a COALESCE special form typed after its first argument. */
    private static RowExpression coalesce(List<RowExpression> expressions) {
        return new SpecialFormExpression(SpecialFormExpression.Form.COALESCE, expressions.get(0).getType(), expressions);
    }

    /**
     * Builds a global, single-step aggregation that evaluates every aggregation of
     * {@code referenceAggregation} over one row of NULL literals (one NULL per source output variable).
     *
     * @return the new aggregation plus a mapping from each original aggregation output variable to the
     *         corresponding output of the NULL aggregation, or empty when some aggregation does not
     *         reference any source variable
     */
    private Optional<MappedAggregationInfo> createAggregationOverNull(
            AggregationNode referenceAggregation,
            PlanVariableAllocator variableAllocator,
            PlanNodeIdAllocator idAllocator) {
        // A single-row VALUES of NULL literals, one column per source output, and a map from each source
        // variable to the VALUES column standing in for it.
        ImmutableList.Builder<VariableReferenceExpression> nullVariables = ImmutableList.builder();
        ImmutableList.Builder<RowExpression> nullLiterals = ImmutableList.builder();
        ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> sourcesVariableMappingBuilder = ImmutableMap.builder();
        for (VariableReferenceExpression sourceVariable : referenceAggregation.getSource().getOutputVariables()) {
            ConstantExpression nullLiteral = Expressions.constantNull(sourceVariable.getType());
            nullLiterals.add(nullLiteral);
            VariableReferenceExpression nullVariable = variableAllocator.newVariable(nullLiteral);
            nullVariables.add(nullVariable);
            sourcesVariableMappingBuilder.put(sourceVariable, nullVariable);
        }
        ValuesNode nullRow = new ValuesNode(
                idAllocator.getNextId(),
                nullVariables.build(),
                ImmutableList.of(nullLiterals.build()));
        Map<VariableReferenceExpression, VariableReferenceExpression> sourcesVariableMapping = sourcesVariableMappingBuilder.build();

        // Re-target each aggregation (arguments, filter, ORDER BY, mask) at the NULL row.
        ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> aggregationsVariableMappingBuilder = ImmutableMap.builder();
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsOverNullBuilder = ImmutableMap.builder();
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : referenceAggregation.getAggregations().entrySet()) {
            VariableReferenceExpression aggregationVariable = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (!isUsingVariables(aggregation, sourcesVariableMapping.keySet())) {
                return Optional.empty();
            }
            List<RowExpression> arguments = aggregation.getArguments().stream()
                    .map(argument -> RowExpressionVariableInliner.inlineVariables(sourcesVariableMapping, argument))
                    .collect(ImmutableList.toImmutableList());
            AggregationNode.Aggregation overNullAggregation = new AggregationNode.Aggregation(
                    new CallExpression(
                            aggregation.getCall().getDisplayName(),
                            aggregation.getCall().getFunctionHandle(),
                            aggregation.getCall().getType(),
                            arguments),
                    aggregation.getFilter().map(filter -> RowExpressionVariableInliner.inlineVariables(sourcesVariableMapping, filter)),
                    aggregation.getOrderBy().map(orderBy -> inlineOrderByVariables(sourcesVariableMapping, orderBy)),
                    aggregation.isDistinct(),
                    aggregation.getMask().map(mask -> new VariableReferenceExpression(sourcesVariableMapping.get(mask).getName(), mask.getType())));
            QualifiedObjectName functionName = functionAndTypeManager.getFunctionMetadata(overNullAggregation.getFunctionHandle()).getName();
            VariableReferenceExpression overNull = variableAllocator.newVariable(functionName.getObjectName(), aggregationVariable.getType());
            aggregationsOverNullBuilder.put(overNull, overNullAggregation);
            aggregationsVariableMappingBuilder.put(aggregationVariable, overNull);
        }

        AggregationNode aggregationOverNullRow = new AggregationNode(
                idAllocator.getNextId(),
                nullRow,
                aggregationsOverNullBuilder.build(),
                AggregationNode.globalAggregation(),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
        return Optional.of(new MappedAggregationInfo(aggregationOverNullRow, aggregationsVariableMappingBuilder.build()));
    }

    /**
     * Rewrites an ordering scheme so that every ORDER BY variable is replaced by its image in
     * {@code variableMapping}, keeping each variable's sort order and the original ordering sequence.
     */
    private static OrderingScheme inlineOrderByVariables(
            Map<VariableReferenceExpression, VariableReferenceExpression> variableMapping,
            OrderingScheme orderingScheme) {
        ImmutableList.Builder<Ordering> orderings = ImmutableList.builder();
        for (VariableReferenceExpression variable : orderingScheme.getOrderByVariables()) {
            SortOrder sortOrder = orderingScheme.getOrdering(variable);
            orderings.add(new Ordering(variableMapping.get(variable), sortOrder));
        }
        return new OrderingScheme(orderings.build());
    }

    /**
     * Returns true when at least one argument of {@code aggregation} is a plain variable reference
     * contained in {@code sourceVariables}.
     */
    private static boolean isUsingVariables(AggregationNode.Aggregation aggregation, Set<VariableReferenceExpression> sourceVariables) {
        return aggregation.getArguments().stream()
                .filter(VariableReferenceExpression.class::isInstance)
                .anyMatch(sourceVariables::contains);
    }

    /** Pairs the aggregation over NULLs with the mapping from original to over-NULL output variables. */
    private static final class MappedAggregationInfo {
        private final AggregationNode aggregationNode;
        private final Map<VariableReferenceExpression, VariableReferenceExpression> variableMapping;

        MappedAggregationInfo(AggregationNode aggregationNode, Map<VariableReferenceExpression, VariableReferenceExpression> variableMapping) {
            this.aggregationNode = aggregationNode;
            this.variableMapping = variableMapping;
        }

        Map<VariableReferenceExpression, VariableReferenceExpression> getVariableMapping() {
            return variableMapping;
        }

        AggregationNode getAggregation() {
            return aggregationNode;
        }
    }
}

