/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.ImmutableSortedSet;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

/**
 * Rewrites aggregations whose only argument is a projected {@code IF(condition, value, <null>)}
 * into masked aggregations.
 *
 * <p>For each such aggregation, the IF condition is projected into a fresh variable and attached
 * to the aggregation as its mask. The aggregation call itself and its argument are left unchanged.
 *
 * <p>When the aggregation has no non-empty grouping set and every aggregation output was rewritten,
 * a filter {@code OR(masks)} is placed under the aggregation, so rows that no aggregation would
 * consume are dropped early. Otherwise the inserted filter's predicate is {@code TRUE}.
 */
public class RewriteAggregationIfToFilter
implements Rule<AggregationNode> {
    private static final Capture<ProjectNode> CHILD = Capture.newCapture();
    // Matches an aggregation directly on top of a project; the project is captured as CHILD.
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .with(Patterns.source().matching(Patterns.project().capturedAs(CHILD)));
    private final FunctionAndTypeManager functionAndTypeManager;
    private final RowExpressionDeterminismEvaluator rowExpressionDeterminismEvaluator;

    /**
     * @param functionAndTypeManager used to look up aggregation function metadata and to evaluate
     *     expression determinism; must not be null
     */
    public RewriteAggregationIfToFilter(FunctionAndTypeManager functionAndTypeManager) {
        this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.rowExpressionDeterminismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
    }

    /** The rule is gated by the aggregation-if-to-filter session property. */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isAggregationIfToFilterRewriteEnabled(session);
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the eligible aggregations of {@code aggregationNode} into masked aggregations.
     *
     * @return an empty result when no aggregation is eligible; otherwise a new
     *     aggregation -> filter -> project plan
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        ProjectNode sourceProject = captures.get(CHILD);
        Set<AggregationNode.Aggregation> aggregationsToRewrite = aggregationNode.getAggregations().values().stream()
                .filter(candidate -> shouldRewriteAggregation(candidate, sourceProject))
                .collect(ImmutableSet.toImmutableSet());
        if (aggregationsToRewrite.isEmpty()) {
            return Rule.Result.empty();
        }

        // Input variable -> its IF expression in the source project, one entry per distinct input.
        // Several aggregations may share the same input, so each condition is projected only once.
        // The sorted map keeps new-variable allocation order deterministic.
        Map<VariableReferenceExpression, RowExpression> sourceAssignments = aggregationsToRewrite.stream()
                .map(candidate -> (VariableReferenceExpression) candidate.getArguments().get(0))
                .collect(ImmutableSortedMap.toImmutableSortedMap(
                        VariableReferenceExpression::compareTo,
                        Function.identity(),
                        variable -> sourceProject.getAssignments().get(variable),
                        (left, right) -> left));

        // Keep every existing projection and add one new variable per IF condition.
        Assignments.Builder newAssignments = Assignments.builder();
        newAssignments.putAll(sourceProject.getAssignments());
        Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToConditionReference = new HashMap<>();
        for (Map.Entry<VariableReferenceExpression, RowExpression> entry : sourceAssignments.entrySet()) {
            SpecialFormExpression ifExpression = (SpecialFormExpression) entry.getValue();
            RowExpression condition = ifExpression.getArguments().get(0);
            VariableReferenceExpression conditionReference = context.getVariableAllocator().newVariable(condition);
            newAssignments.put(conditionReference, condition);
            aggregationReferenceToConditionReference.put(entry.getKey(), conditionReference);
        }

        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        // A sorted set dedups shared masks and gives the OR predicate a deterministic order.
        ImmutableSortedSet.Builder<VariableReferenceExpression> masks = ImmutableSortedSet.naturalOrder();
        // Tracked per output. aggregationsToRewrite is deduplicated by Aggregation.equals, so its
        // size cannot tell whether every output was rewritten.
        boolean allAggregationsRewritten = true;
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            VariableReferenceExpression output = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (!aggregationsToRewrite.contains(aggregation)) {
                aggregations.put(output, aggregation);
                allAggregationsRewritten = false;
                continue;
            }
            VariableReferenceExpression aggregationReference = (VariableReferenceExpression) aggregation.getArguments().get(0);
            VariableReferenceExpression mask = aggregationReferenceToConditionReference.get(aggregationReference);
            aggregations.put(output, new AggregationNode.Aggregation(
                    aggregation.getCall(),
                    Optional.empty(),
                    aggregation.getOrderBy(),
                    aggregation.isDistinct(),
                    Optional.of(mask)));
            masks.add(mask);
        }

        // Pre-filtering is only attempted without a non-empty grouping set. With grouping keys,
        // dropping rows could remove whole groups that would otherwise still produce an output row.
        RowExpression predicate = LogicalRowExpressions.TRUE_CONSTANT;
        if (!aggregationNode.hasNonEmptyGroupingSet() && allAggregationsRewritten) {
            Collection<RowExpression> maskPredicates = ImmutableSet.<RowExpression>copyOf(masks.build());
            predicate = LogicalRowExpressions.or(maskPredicates);
        }

        // Ids are allocated in the original order: aggregation, then filter, then project.
        return Rule.Result.ofPlanNode(new AggregationNode(
                context.getIdAllocator().getNextId(),
                new FilterNode(
                        context.getIdAllocator().getNextId(),
                        new ProjectNode(context.getIdAllocator().getNextId(), sourceProject.getSource(), newAssignments.build()),
                        predicate),
                aggregations.build(),
                aggregationNode.getGroupingSets(),
                aggregationNode.getPreGroupedVariables(),
                aggregationNode.getStep(),
                aggregationNode.getHashVariable(),
                aggregationNode.getGroupIdVariable()));
    }

    /**
     * Returns whether {@code aggregation} can be rewritten into a masked aggregation.
     *
     * <p>It qualifies when its function is not called on NULL input, it has exactly one argument that
     * is a variable, it has no filter or mask yet, and that variable is projected by
     * {@code sourceProject} as a deterministic {@code IF} whose else argument satisfies
     * {@code Expressions.isNull}.
     */
    private boolean shouldRewriteAggregation(AggregationNode.Aggregation aggregation, ProjectNode sourceProject) {
        // Masking drops rows on which the IF would yield NULL. A function that is called on NULL
        // input could observe those rows, so dropping them could change its result.
        if (functionAndTypeManager.getFunctionMetadata(aggregation.getFunctionHandle()).isCalledOnNullInput()) {
            return false;
        }
        // Only a single variable argument is handled; the IF itself lives in the project below.
        if (aggregation.getArguments().size() != 1 || !(aggregation.getArguments().get(0) instanceof VariableReferenceExpression)) {
            return false;
        }
        // Do not combine with an existing filter or mask.
        if (aggregation.getFilter().isPresent() || aggregation.getMask().isPresent()) {
            return false;
        }
        RowExpression sourceExpression = sourceProject.getAssignments().get((VariableReferenceExpression) aggregation.getArguments().get(0));
        // The condition is projected separately from the IF, so the IF must be deterministic for
        // both evaluations to agree.
        if (!(sourceExpression instanceof SpecialFormExpression) || !rowExpressionDeterminismEvaluator.isDeterministic(sourceExpression)) {
            return false;
        }
        SpecialFormExpression expression = (SpecialFormExpression) sourceExpression;
        return expression.getForm() == SpecialFormExpression.Form.IF && Expressions.isNull(expression.getArguments().get(2));
    }
}

