/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.execution.warnings.WarningCollector;
import com.facebook.presto.metadata.BuiltInFunctionNamespaceManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.FullyQualifiedName;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.IfExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Plan optimizer that removes the {@link MarkDistinctNode} feeding an aggregation which mixes
 * ordinary aggregates with DISTINCT (masked) aggregates, replacing it with a
 * {@link GroupIdNode}-based plan.
 *
 * <p>The rewrite only applies when the session enables
 * {@code SystemSessionProperties.isOptimizeDistinctAggregationEnabled}, all masked aggregates share
 * one mask, at least one aggregate is unmasked, no aggregate has a FILTER or ORDER BY, the
 * non-distinct arguments and the mask have comparable types, the masked aggregates take exactly
 * one distinct argument, and a {@link MarkDistinctNode} producing that mask is found below the
 * aggregation. For a query of the form
 *
 * <pre>
 * SELECT a1, ..., an, F1(b1), ..., Fm(bm), F(DISTINCT c)
 * FROM t
 * GROUP BY a1, ..., an
 * </pre>
 *
 * the produced plan is, conceptually,
 *
 * <pre>
 * SELECT a1, ..., an,
 *        arbitrary(if(group = 0, f1)), ..., arbitrary(if(group = 0, fm)),
 *        F(if(group = 1, c))
 * FROM (
 *     SELECT a1, ..., an, F1(b1) AS f1, ..., Fm(bm) AS fm, c, group
 *     FROM t GROUPING SETS ((a1, ..., an, b1, ..., bm), (a1, ..., an, c))
 *     GROUP BY a1, ..., an, c, group)
 * GROUP BY a1, ..., an
 * </pre>
 *
 * <p>The outer {@code arbitrary} results of {@code count}, {@code count_if} and
 * {@code approx_distinct} are additionally wrapped in {@code COALESCE(..., 0)}.
 */
public class OptimizeMixedDistinctAggregations implements PlanOptimizer {
    private final Metadata metadata;

    public OptimizeMixedDistinctAggregations(Metadata metadata) {
        this.metadata = metadata;
    }

    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, PlanVariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        if (!SystemSessionProperties.isOptimizeDistinctAggregationEnabled(session)) {
            return plan;
        }
        return SimplePlanRewriter.rewriteWith(new Optimizer(idAllocator, variableAllocator, this.metadata), plan, Optional.empty());
    }

    /**
     * State passed from {@link Optimizer#visitAggregation} down to the matching
     * {@link Optimizer#visitMarkDistinct}, and filled in there with the variables the rewritten
     * outer aggregation must read.
     */
    private static class AggregateInfo {
        private final List<VariableReferenceExpression> groupByVariables;
        private final VariableReferenceExpression mask;
        private final Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations;
        private final TypeProvider types;
        // Set by createProjectNode: original non-distinct output variable -> projected if(group = 0, ...) variable.
        private Map<VariableReferenceExpression, VariableReferenceExpression> newNonDistinctAggregateVariables;
        // Set by createProjectNode: projected if(group = 1, distinct argument) variable.
        private VariableReferenceExpression newDistinctAggregateVariable;
        private boolean foundMarkDistinct;

        public AggregateInfo(List<VariableReferenceExpression> groupByVariables, VariableReferenceExpression mask, Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations, TypeProvider types) {
            this.groupByVariables = ImmutableList.copyOf(groupByVariables);
            this.mask = mask;
            this.aggregations = ImmutableMap.copyOf(aggregations);
            this.types = types;
        }

        /** Returns the de-duplicated arguments of the unmasked aggregates, as variables. */
        public List<VariableReferenceExpression> getOriginalNonDistinctAggregateArgs() {
            return toVariables(uniqueArguments(false));
        }

        /** Returns the de-duplicated arguments of the masked (DISTINCT) aggregates, as variables. */
        public List<VariableReferenceExpression> getOriginalDistinctAggregateArgs() {
            return toVariables(uniqueArguments(true));
        }

        /** Returns how many different arguments the masked aggregates take in total. */
        public int getDistinctAggregateArgCount() {
            return uniqueArguments(true).size();
        }

        /** Collects the distinct arguments of the aggregates whose mask presence equals {@code masked}. */
        private List<RowExpression> uniqueArguments(boolean masked) {
            return this.aggregations.values().stream()
                    .filter(aggregation -> aggregation.getMask().isPresent() == masked)
                    .flatMap(aggregation -> aggregation.getArguments().stream())
                    .distinct()
                    .collect(Collectors.toList());
        }

        /** Converts symbol-reference arguments into typed variables. */
        private List<VariableReferenceExpression> toVariables(List<RowExpression> arguments) {
            return arguments.stream()
                    .map(argument -> PlannerUtils.toVariableReference(OriginalExpressionUtils.castToExpression(argument), this.types))
                    .collect(Collectors.toList());
        }

        public VariableReferenceExpression getNewDistinctAggregateVariable() {
            return this.newDistinctAggregateVariable;
        }

        public void setNewDistinctAggregateVariable(VariableReferenceExpression newDistinctAggregateVariable) {
            this.newDistinctAggregateVariable = newDistinctAggregateVariable;
        }

        public Map<VariableReferenceExpression, VariableReferenceExpression> getNewNonDistinctAggregateVariables() {
            return this.newNonDistinctAggregateVariables;
        }

        public void setNewNonDistinctAggregateVariables(Map<VariableReferenceExpression, VariableReferenceExpression> newNonDistinctAggregateVariables) {
            this.newNonDistinctAggregateVariables = newNonDistinctAggregateVariables;
        }

        public VariableReferenceExpression getMask() {
            return this.mask;
        }

        public List<VariableReferenceExpression> getGroupByVariables() {
            return this.groupByVariables;
        }

        public Map<VariableReferenceExpression, AggregationNode.Aggregation> getAggregations() {
            return this.aggregations;
        }

        public void foundMarkDistinct() {
            this.foundMarkDistinct = true;
        }

        public boolean isFoundMarkDistinct() {
            return this.foundMarkDistinct;
        }
    }

    /** Rewriter carrying the {@link AggregateInfo} of the enclosing candidate aggregation, if any. */
    private static class Optimizer extends SimplePlanRewriter<Optional<AggregateInfo>> {
        private final PlanNodeIdAllocator idAllocator;
        private final PlanVariableAllocator variableAllocator;
        private final Metadata metadata;

        private Optimizer(PlanNodeIdAllocator idAllocator, PlanVariableAllocator variableAllocator, Metadata metadata) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.variableAllocator = Objects.requireNonNull(variableAllocator, "variableAllocator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        /**
         * Rewrites an eligible mixed distinct/non-distinct aggregation into the outer aggregation of
         * the GroupId-based plan; any other aggregation is rewritten with an empty context.
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            List<VariableReferenceExpression> masks = node.getAggregations().values().stream()
                    .map(AggregationNode.Aggregation::getMask)
                    .filter(Optional::isPresent)
                    .map(Optional::get)
                    .collect(ImmutableList.toImmutableList());
            Set<VariableReferenceExpression> uniqueMasks = ImmutableSet.copyOf(masks);
            // Require exactly one shared mask, and at least one unmasked aggregate.
            if (uniqueMasks.size() != 1 || masks.size() == node.getAggregations().size()) {
                return context.defaultRewrite(node, Optional.empty());
            }
            if (node.getAggregations().values().stream().anyMatch(aggregation -> aggregation.getFilter().isPresent())) {
                return context.defaultRewrite(node, Optional.empty());
            }
            if (node.hasOrderings()) {
                return context.defaultRewrite(node, Optional.empty());
            }
            AggregateInfo aggregateInfo = new AggregateInfo(node.getGroupingKeys(), Iterables.getOnlyElement(uniqueMasks), node.getAggregations(), this.variableAllocator.getTypes());
            // The non-distinct arguments become grouping columns, so they must be comparable.
            if (!this.checkAllEquatableTypes(aggregateInfo)) {
                return context.defaultRewrite(node, Optional.empty());
            }
            // visitMarkDistinct routes exactly one distinct argument through the GroupId node; with any
            // other count it would fail in Iterables.getOnlyElement, so keep the original plan instead.
            if (aggregateInfo.getDistinctAggregateArgCount() != 1) {
                return context.defaultRewrite(node, Optional.empty());
            }
            PlanNode source = context.rewrite(node.getSource(), Optional.of(aggregateInfo));
            if (!aggregateInfo.isFoundMarkDistinct()) {
                return context.defaultRewrite(node, Optional.empty());
            }

            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> coalesceVariablesBuilder = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
                VariableReferenceExpression outputVariable = entry.getKey();
                AggregationNode.Aggregation originalAggregation = entry.getValue();
                if (originalAggregation.getMask().isPresent()) {
                    // Same function over if(group = 1, c). The inner GROUP BY already has c as a key,
                    // so no mask or DISTINCT flag is needed any more.
                    CallExpression call = originalAggregation.getCall();
                    aggregations.put(outputVariable, new AggregationNode.Aggregation(
                            new CallExpression(call.getDisplayName(), call.getFunctionHandle(), call.getType(), ImmutableList.of(symbolRowExpression(aggregateInfo.getNewDistinctAggregateVariable()))),
                            Optional.empty(),
                            Optional.empty(),
                            false,
                            Optional.empty()));
                    continue;
                }
                // Non-distinct aggregate: already computed by the inner aggregation; pick that value.
                VariableReferenceExpression argument = aggregateInfo.getNewNonDistinctAggregateVariables().get(outputVariable);
                AggregationNode.Aggregation arbitrary = new AggregationNode.Aggregation(
                        new CallExpression(
                                "arbitrary",
                                this.metadata.getFunctionManager().lookupFunction("arbitrary", TypeSignatureProvider.fromTypes(ImmutableList.of(argument.getType()))),
                                outputVariable.getType(),
                                ImmutableList.of(symbolRowExpression(argument))),
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        Optional.empty());
                if (this.requiresZeroDefault(originalAggregation)) {
                    // arbitrary() yields NULL when no group-0 row reaches this aggregation (e.g. a global
                    // aggregation over empty input), but these functions must return 0. Aggregate into a
                    // fresh variable and COALESCE it back onto the original output variable below.
                    VariableReferenceExpression newVariable = this.variableAllocator.newVariable("expr", outputVariable.getType());
                    aggregations.put(newVariable, arbitrary);
                    coalesceVariablesBuilder.put(newVariable, outputVariable);
                } else {
                    aggregations.put(outputVariable, arbitrary);
                }
            }
            Map<VariableReferenceExpression, VariableReferenceExpression> coalesceVariables = coalesceVariablesBuilder.build();
            AggregationNode aggregationNode = new AggregationNode(this.idAllocator.getNextId(), source, aggregations.build(), node.getGroupingSets(), ImmutableList.of(), node.getStep(), Optional.empty(), node.getGroupIdVariable());
            if (coalesceVariables.isEmpty()) {
                return aggregationNode;
            }
            Assignments.Builder outputVariables = Assignments.builder();
            for (VariableReferenceExpression variable : aggregationNode.getOutputVariables()) {
                VariableReferenceExpression originalVariable = coalesceVariables.get(variable);
                if (originalVariable == null) {
                    outputVariables.put(AssignmentUtils.identityAsSymbolReference(variable));
                    continue;
                }
                Expression coalesce = new CoalesceExpression(new SymbolReference(variable.getName()), new Cast(new LongLiteral("0"), "bigint"));
                outputVariables.put(originalVariable, OriginalExpressionUtils.castToRowExpression(coalesce));
            }
            return new ProjectNode(this.idAllocator.getNextId(), aggregationNode, outputVariables.build());
        }

        /**
         * Replaces the MarkDistinct node that produces the candidate aggregation's mask with
         * GroupId -> inner aggregation -> projection; other MarkDistinct nodes are left unchanged.
         */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            Optional<AggregateInfo> candidate = context.get();
            if (!candidate.isPresent() || !candidate.get().getMask().equals(node.getMarkerVariable())) {
                return context.defaultRewrite(node, Optional.empty());
            }
            AggregateInfo aggregateInfo = candidate.get();
            aggregateInfo.foundMarkDistinct();
            PlanNode source = context.rewrite(node.getSource(), Optional.empty());

            List<VariableReferenceExpression> groupByVariables = aggregateInfo.getGroupByVariables();
            // Explicit mutable copy: the distinct argument may be swapped for a duplicate below, and
            // Collectors.toList() makes no mutability guarantee.
            List<VariableReferenceExpression> nonDistinctAggregateVariables = new ArrayList<>(aggregateInfo.getOriginalNonDistinctAggregateArgs());
            VariableReferenceExpression distinctVariable = Iterables.getOnlyElement(aggregateInfo.getOriginalDistinctAggregateArgs());

            // When c is also a non-distinct argument it must appear in both grouping sets. A fresh
            // variable stands in for it in group 0, and the GroupId node maps it back to c's input.
            VariableReferenceExpression duplicatedDistinctVariable = distinctVariable;
            if (nonDistinctAggregateVariables.contains(distinctVariable)) {
                VariableReferenceExpression newVariable = this.variableAllocator.newVariable(distinctVariable);
                nonDistinctAggregateVariables.set(nonDistinctAggregateVariables.indexOf(distinctVariable), newVariable);
                duplicatedDistinctVariable = newVariable;
            }

            Set<VariableReferenceExpression> allVariables = new HashSet<>();
            allVariables.addAll(groupByVariables);
            allVariables.addAll(nonDistinctAggregateVariables);
            allVariables.add(distinctVariable);

            VariableReferenceExpression groupVariable = this.variableAllocator.newVariable("group", BigintType.BIGINT);
            GroupIdNode groupIdNode = this.createGroupIdNode(groupByVariables, nonDistinctAggregateVariables, distinctVariable, duplicatedDistinctVariable, groupVariable, allVariables, source);

            // Inner aggregation keys: original group-by keys, the distinct argument and the group id.
            Set<VariableReferenceExpression> groupByKeys = new HashSet<>(groupByVariables);
            groupByKeys.add(distinctVariable);
            groupByKeys.add(groupVariable);

            ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> aggregationOutputVariablesMapBuilder = ImmutableMap.builder();
            AggregationNode aggregationNode = this.createNonDistinctAggregation(aggregateInfo, distinctVariable, duplicatedDistinctVariable, groupByKeys, groupIdNode, node, aggregationOutputVariablesMapBuilder);
            return this.createProjectNode(aggregationNode, aggregateInfo, distinctVariable, groupVariable, groupByVariables, aggregationOutputVariablesMapBuilder.build());
        }

        /** Returns whether every non-distinct argument type and the mask type are comparable. */
        private boolean checkAllEquatableTypes(AggregateInfo aggregateInfo) {
            for (VariableReferenceExpression variable : aggregateInfo.getOriginalNonDistinctAggregateArgs()) {
                if (!variable.getType().isComparable()) {
                    return false;
                }
            }
            return aggregateInfo.getMask().getType().isComparable();
        }

        /**
         * Builds the projection above the inner aggregation and records the projected variables in
         * {@code aggregateInfo}. The projection:
         * <ul>
         *   <li>maps the distinct argument c to {@code if(group = 1, c)};</li>
         *   <li>maps each inner non-distinct result to {@code if(group = 0, result)};</li>
         *   <li>passes the group-by keys through;</li>
         *   <li>assigns NULL to the old mask variable.</li>
         * </ul>
         */
        private ProjectNode createProjectNode(AggregationNode source, AggregateInfo aggregateInfo, VariableReferenceExpression distinctVariable, VariableReferenceExpression groupVariable, List<VariableReferenceExpression> groupByVariables, Map<VariableReferenceExpression, VariableReferenceExpression> aggregationOutputVariablesMap) {
            Assignments.Builder outputVariables = Assignments.builder();
            ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> outputNonDistinctAggregateVariables = ImmutableMap.builder();
            for (VariableReferenceExpression variable : source.getOutputVariables()) {
                if (distinctVariable.equals(variable)) {
                    VariableReferenceExpression newVariable = this.variableAllocator.newVariable("expr", variable.getType());
                    aggregateInfo.setNewDistinctAggregateVariable(newVariable);
                    outputVariables.put(newVariable, keepIfGroupIs(groupVariable, "1", variable));
                } else if (aggregationOutputVariablesMap.containsKey(variable)) {
                    VariableReferenceExpression newVariable = this.variableAllocator.newVariable("expr", variable.getType());
                    outputNonDistinctAggregateVariables.put(aggregationOutputVariablesMap.get(variable), newVariable);
                    outputVariables.put(newVariable, keepIfGroupIs(groupVariable, "0", variable));
                }
                if (groupByVariables.contains(variable)) {
                    outputVariables.put(variable, OriginalExpressionUtils.castToRowExpression(new SymbolReference(variable.getName())));
                }
            }
            // No MarkDistinct computes the mask any more. It is kept in the output as NULL (presumably so
            // any remaining reference still resolves); the rewritten outer aggregation does not read it.
            outputVariables.put(aggregateInfo.getMask(), OriginalExpressionUtils.castToRowExpression(new NullLiteral()));
            aggregateInfo.setNewNonDistinctAggregateVariables(outputNonDistinctAggregateVariables.build());
            return new ProjectNode(this.idAllocator.getNextId(), source, outputVariables.build());
        }

        /**
         * Builds the GroupId node with two grouping sets:
         * <ul>
         *   <li>group 0: the group-by keys plus the non-distinct arguments;</li>
         *   <li>group 1: the group-by keys plus the distinct argument.</li>
         * </ul>
         * The duplicated distinct variable, if any, reads the same input column as the distinct one.
         */
        private GroupIdNode createGroupIdNode(List<VariableReferenceExpression> groupByVariables, List<VariableReferenceExpression> nonDistinctAggregateVariables, VariableReferenceExpression distinctVariable, VariableReferenceExpression duplicatedDistinctVariable, VariableReferenceExpression groupVariable, Set<VariableReferenceExpression> allVariables, PlanNode source) {
            List<List<VariableReferenceExpression>> groups = new ArrayList<>();

            Set<VariableReferenceExpression> group0 = new HashSet<>();
            group0.addAll(groupByVariables);
            group0.addAll(nonDistinctAggregateVariables);
            groups.add(ImmutableList.copyOf(group0));

            Set<VariableReferenceExpression> group1 = new HashSet<>(groupByVariables);
            group1.add(distinctVariable);
            groups.add(ImmutableList.copyOf(group1));

            Map<VariableReferenceExpression, VariableReferenceExpression> groupingColumns = allVariables.stream()
                    .collect(Collectors.toMap(Function.identity(), variable -> variable.equals(duplicatedDistinctVariable) ? distinctVariable : variable));
            return new GroupIdNode(this.idAllocator.getNextId(), source, groups, groupingColumns, ImmutableList.of(), groupVariable);
        }

        /**
         * Builds the inner SINGLE-step aggregation that evaluates every unmasked aggregate over the
         * GroupId output, grouped by {@code groupByKeys}. For each aggregate, records
         * (new output variable -> original output variable) in
         * {@code aggregationOutputSymbolsMapBuilder}. Arguments equal to the distinct argument are
         * redirected to its duplicate when one was introduced.
         */
        private AggregationNode createNonDistinctAggregation(AggregateInfo aggregateInfo, VariableReferenceExpression distinctVariable, VariableReferenceExpression duplicatedDistinctVariable, Set<VariableReferenceExpression> groupByKeys, GroupIdNode groupIdNode, MarkDistinctNode originalNode, ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> aggregationOutputSymbolsMapBuilder) {
            boolean distinctArgumentDuplicated = !duplicatedDistinctVariable.equals(distinctVariable);
            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregateInfo.getAggregations().entrySet()) {
                AggregationNode.Aggregation aggregation = entry.getValue();
                if (aggregation.getMask().isPresent()) {
                    continue;
                }
                VariableReferenceExpression newVariable = this.variableAllocator.newVariable(entry.getKey());
                aggregationOutputSymbolsMapBuilder.put(newVariable, entry.getKey());

                List<RowExpression> arguments = aggregation.getArguments();
                if (distinctArgumentDuplicated && extractVariables(arguments, this.variableAllocator.getTypes()).contains(distinctVariable)) {
                    ImmutableList.Builder<RowExpression> rewritten = ImmutableList.builder();
                    for (RowExpression argument : arguments) {
                        Expression expression = OriginalExpressionUtils.castToExpression(argument);
                        boolean isDistinctArgument = expression instanceof SymbolReference
                                && PlannerUtils.toVariableReference(expression, this.variableAllocator.getTypes()).equals(distinctVariable);
                        rewritten.add(isDistinctArgument ? symbolRowExpression(duplicatedDistinctVariable) : argument);
                    }
                    arguments = rewritten.build();
                }

                CallExpression call = aggregation.getCall();
                aggregations.put(newVariable, new AggregationNode.Aggregation(
                        new CallExpression(call.getDisplayName(), call.getFunctionHandle(), call.getType(), arguments),
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        Optional.empty()));
            }
            return new AggregationNode(this.idAllocator.getNextId(), groupIdNode, aggregations.build(), AggregationNode.singleGroupingSet(ImmutableList.copyOf(groupByKeys)), ImmutableList.of(), AggregationNode.Step.SINGLE, originalNode.getHashVariable(), Optional.empty());
        }

        /** Returns true for the built-in aggregates whose outer result is COALESCEd to 0 (count, count_if, approx_distinct). */
        private boolean requiresZeroDefault(AggregationNode.Aggregation aggregation) {
            FullyQualifiedName functionName = this.metadata.getFunctionManager().getFunctionMetadata(aggregation.getFunctionHandle()).getName();
            return functionName.equals(FullyQualifiedName.of(BuiltInFunctionNamespaceManager.DEFAULT_NAMESPACE, "count"))
                    || functionName.equals(FullyQualifiedName.of(BuiltInFunctionNamespaceManager.DEFAULT_NAMESPACE, "count_if"))
                    || functionName.equals(FullyQualifiedName.of(BuiltInFunctionNamespaceManager.DEFAULT_NAMESPACE, "approx_distinct"));
        }

        /** Collects the symbol-reference arguments as typed variables; other arguments are ignored. */
        private static Set<VariableReferenceExpression> extractVariables(List<RowExpression> arguments, TypeProvider types) {
            ImmutableSet.Builder<VariableReferenceExpression> builder = ImmutableSet.builder();
            for (RowExpression argument : arguments) {
                Expression expression = OriginalExpressionUtils.castToExpression(argument);
                if (expression instanceof SymbolReference) {
                    builder.add(Expressions.variable(((SymbolReference) expression).getName(), types.get(expression)));
                }
            }
            return builder.build();
        }

        /** Wraps a symbol reference to {@code variable} as a RowExpression. */
        private static RowExpression symbolRowExpression(VariableReferenceExpression variable) {
            return OriginalExpressionUtils.castToRowExpression(OriginalExpressionUtils.asSymbolReference(variable));
        }

        /** Builds {@code if(group = CAST(groupId AS bigint), value, CAST(NULL AS <value type>))}. */
        private static RowExpression keepIfGroupIs(VariableReferenceExpression groupVariable, String groupId, VariableReferenceExpression value) {
            IfExpression expression = createIfExpression(new SymbolReference(groupVariable.getName()), new Cast(new LongLiteral(groupId), "bigint"), ComparisonExpression.Operator.EQUAL, new SymbolReference(value.getName()), value.getType());
            return OriginalExpressionUtils.castToRowExpression(expression);
        }

        /** Builds {@code IF(left <operator> right, result, CAST(NULL AS trueValueType))}. */
        private static IfExpression createIfExpression(Expression left, Expression right, ComparisonExpression.Operator operator, Expression result, Type trueValueType) {
            return new IfExpression(new ComparisonExpression(operator, left, right), result, new Cast(new NullLiteral(), trueValueType.getTypeSignature().toString()));
        }
    }
}

