/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.IfExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Plan optimizer that rewrites an aggregation mixing DISTINCT and non-DISTINCT aggregate functions so that the
 * DISTINCT part no longer relies on a {@link MarkDistinctNode}.
 *
 * <p>The rewrite applies to an {@link AggregationNode} when all of the following hold:
 * <ul>
 *   <li>some, but not all, of its aggregations carry a mask, and all those masks are the same symbol;</li>
 *   <li>no aggregation has a FILTER clause;</li>
 *   <li>the DISTINCT aggregations reference exactly one argument symbol;</li>
 *   <li>the relevant argument types are comparable;</li>
 *   <li>a {@code MarkDistinctNode} whose marker is that mask is found in the aggregation's source.</li>
 * </ul>
 *
 * <p>The matching {@code MarkDistinctNode} is replaced by:
 * <pre>
 * Project       new distinct symbol   := IF(group = 1, distinct arg, NULL)
 *               new non-distinct sym  := IF(group = 0, inner aggregate result, NULL)
 *               group-by symbols kept; original mask symbol := NULL
 *   Aggregation SINGLE step: the non-distinct aggregates,
 *               grouped by (group-by keys, distinct arg, group)
 *     GroupId   grouping sets {group-by keys + non-distinct args}, {group-by keys + distinct arg}
 *       (rewritten source of the MarkDistinct)
 * </pre>
 *
 * <p>The original aggregation is then rebuilt on top of that project:
 * <ul>
 *   <li>each DISTINCT aggregate becomes the same function, without DISTINCT and without a mask,
 *       over the new distinct symbol;</li>
 *   <li>each non-distinct aggregate becomes {@code arbitrary(...)} over its new non-distinct symbol.</li>
 * </ul>
 *
 * <p>NOTE(review): the "group = 0" / "group = 1" tests assume GroupIdNode numbers grouping sets by their
 * position in the list passed to it (group0 first, group1 second) — confirm against GroupIdNode.
 *
 * <p>The optimizer does nothing unless
 * {@link SystemSessionProperties#isOptimizeDistinctAggregationEnabled(Session)} returns true.
 */
public class OptimizeMixedDistinctAggregations implements PlanOptimizer {
    private final Metadata metadata;

    public OptimizeMixedDistinctAggregations(Metadata metadata) {
        this.metadata = metadata;
    }

    /**
     * Rewrites {@code plan} when the optimize-distinct-aggregation session property is enabled.
     * Otherwise returns {@code plan} unchanged.
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        if (!SystemSessionProperties.isOptimizeDistinctAggregationEnabled(session)) {
            return plan;
        }
        return SimplePlanRewriter.rewriteWith(new Optimizer(idAllocator, symbolAllocator, metadata), plan, Optional.empty());
    }

    /**
     * State for one candidate aggregation.
     *
     * <p>{@link Optimizer#visitAggregation} creates it and passes it down as the rewrite context. The
     * {@link Optimizer#visitMarkDistinct} call that owns the mask records its findings here, and
     * {@code visitAggregation} reads them back afterwards.
     */
    private static class AggregateInfo {
        private final List<Symbol> groupBySymbols;
        private final Symbol mask;
        private final Map<Symbol, AggregationNode.Aggregation> aggregations;
        // Set by createProjectNode: original non-distinct aggregation output -> projected symbol holding its result.
        private Map<Symbol, Symbol> newNonDistinctAggregateSymbols;
        // Set by createProjectNode: projected symbol that carries the distinct argument.
        private Symbol newDistinctAggregateSymbol;
        // Set once the MarkDistinct whose marker equals `mask` has been rewritten.
        private boolean foundMarkDistinct;

        public AggregateInfo(List<Symbol> groupBySymbols, Symbol mask, Map<Symbol, AggregationNode.Aggregation> aggregations) {
            this.groupBySymbols = ImmutableList.copyOf(groupBySymbols);
            this.mask = mask;
            this.aggregations = ImmutableMap.copyOf(aggregations);
        }

        /** Returns the de-duplicated argument symbols of the non-DISTINCT calls, as an immutable list. */
        public List<Symbol> getOriginalNonDistinctAggregateArgs() {
            return collectArguments(false);
        }

        /** Returns the de-duplicated argument symbols of the DISTINCT calls, as an immutable list. */
        public List<Symbol> getOriginalDistinctAggregateArgs() {
            return collectArguments(true);
        }

        /**
         * Collects the arguments of every call whose DISTINCT flag equals {@code distinct}.
         * The result is in encounter order, de-duplicated, and immutable.
         */
        private List<Symbol> collectArguments(boolean distinct) {
            return aggregations.values().stream()
                    .map(AggregationNode.Aggregation::getCall)
                    .filter(call -> call.isDistinct() == distinct)
                    .flatMap(call -> call.getArguments().stream())
                    .distinct()
                    .map(Symbol::from)
                    .collect(ImmutableList.toImmutableList());
        }

        public Symbol getNewDistinctAggregateSymbol() {
            return newDistinctAggregateSymbol;
        }

        public void setNewDistinctAggregateSymbol(Symbol newDistinctAggregateSymbol) {
            this.newDistinctAggregateSymbol = newDistinctAggregateSymbol;
        }

        public Map<Symbol, Symbol> getNewNonDistinctAggregateSymbols() {
            return newNonDistinctAggregateSymbols;
        }

        public void setNewNonDistinctAggregateSymbols(Map<Symbol, Symbol> newNonDistinctAggregateSymbols) {
            this.newNonDistinctAggregateSymbols = newNonDistinctAggregateSymbols;
        }

        public Symbol getMask() {
            return mask;
        }

        public List<Symbol> getGroupBySymbols() {
            return groupBySymbols;
        }

        public Map<Symbol, AggregationNode.Aggregation> getAggregations() {
            return aggregations;
        }

        public void foundMarkDistinct() {
            foundMarkDistinct = true;
        }

        public boolean isFoundMarkDistinct() {
            return foundMarkDistinct;
        }
    }

    /**
     * Plan rewriter that carries an {@link AggregateInfo} from a candidate aggregation down to its MarkDistinct.
     * Any other subtree is rewritten with an empty context.
     */
    private static class Optimizer extends SimplePlanRewriter<Optional<AggregateInfo>> {
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;
        private final Metadata metadata;

        private Optimizer(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Metadata metadata) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        /**
         * Checks whether {@code node} qualifies for the rewrite.
         *
         * <p>If it does, this rewrites its source with an {@link AggregateInfo} context and then rebuilds the
         * aggregation over the rewritten MarkDistinct replacement. If the node does not qualify, or no matching
         * MarkDistinct is found, the node is rewritten with an empty context instead.
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            Map<Symbol, AggregationNode.Aggregation> originalAggregations = node.getAggregations();
            List<Symbol> masks = originalAggregations.values().stream()
                    .map(AggregationNode.Aggregation::getMask)
                    .filter(Optional::isPresent)
                    .map(Optional::get)
                    .collect(ImmutableList.toImmutableList());
            Set<Symbol> uniqueMasks = ImmutableSet.copyOf(masks);
            // Requires exactly one shared mask AND at least one unmasked aggregation;
            // the all-masked case is not handled here.
            if (uniqueMasks.size() != 1 || masks.size() == originalAggregations.size()) {
                return context.defaultRewrite(node, Optional.empty());
            }

            // The rebuilt FunctionCalls below carry no filter, so aggregations with FILTER cannot be rewritten.
            boolean anyFilter = originalAggregations.values().stream()
                    .map(AggregationNode.Aggregation::getCall)
                    .map(FunctionCall::getFilter)
                    .anyMatch(Optional::isPresent);
            if (anyFilter) {
                return context.defaultRewrite(node, Optional.empty());
            }

            AggregateInfo aggregateInfo = new AggregateInfo(node.getGroupingKeys(), Iterables.getOnlyElement(uniqueMasks), originalAggregations);

            // The GroupId layout carries a single distinct argument. Skip the rewrite otherwise, rather than
            // failing later in Iterables.getOnlyElement inside visitMarkDistinct.
            if (aggregateInfo.getOriginalDistinctAggregateArgs().size() != 1) {
                return context.defaultRewrite(node, Optional.empty());
            }
            // Every symbol the rewrite groups on must be comparable.
            if (!checkAllEquatableTypes(aggregateInfo)) {
                return context.defaultRewrite(node, Optional.empty());
            }

            PlanNode source = context.rewrite(node.getSource(), Optional.of(aggregateInfo));
            // Without the matching MarkDistinct there is nothing to rebuild on; fall back to a plain rewrite.
            if (!aggregateInfo.isFoundMarkDistinct()) {
                return context.defaultRewrite(node, Optional.empty());
            }

            QualifiedName arbitrary = QualifiedName.of("arbitrary");
            ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : originalAggregations.entrySet()) {
                FunctionCall functionCall = entry.getValue().getCall();
                if (functionCall.isDistinct()) {
                    // The inner aggregation already groups by the distinct argument, so the call is rebuilt
                    // without DISTINCT and without a mask, over the projected distinct symbol.
                    List<Expression> arguments = ImmutableList.<Expression>of(aggregateInfo.getNewDistinctAggregateSymbol().toSymbolReference());
                    FunctionCall rewritten = new FunctionCall(functionCall.getName(), functionCall.getWindow(), false, arguments);
                    aggregations.put(entry.getKey(), new AggregationNode.Aggregation(rewritten, entry.getValue().getSignature(), Optional.empty()));
                }
                else {
                    // The non-distinct result was computed by the inner aggregation; arbitrary() lifts the
                    // projected value (set only on group-0 rows) into this aggregation.
                    Symbol argument = aggregateInfo.getNewNonDistinctAggregateSymbols().get(entry.getKey());
                    List<Expression> arguments = ImmutableList.<Expression>of(argument.toSymbolReference());
                    FunctionCall rewritten = new FunctionCall(arbitrary, functionCall.getWindow(), false, arguments);
                    aggregations.put(entry.getKey(), new AggregationNode.Aggregation(rewritten, getFunctionSignature(arbitrary, argument), Optional.empty()));
                }
            }
            return new AggregationNode(idAllocator.getNextId(), source, aggregations.build(), node.getGroupingSets(), node.getStep(), Optional.empty(), node.getGroupIdSymbol());
        }

        /**
         * Replaces the MarkDistinct whose marker equals the context's mask with
         * Project(Aggregation(GroupId(source))).
         *
         * <p>Any other MarkDistinct, or one reached without a context, is rewritten with an empty context.
         */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            Optional<AggregateInfo> candidate = context.get();
            if (!candidate.isPresent() || !candidate.get().getMask().equals(node.getMarkerSymbol())) {
                return context.defaultRewrite(node, Optional.empty());
            }
            AggregateInfo aggregateInfo = candidate.get();
            aggregateInfo.foundMarkDistinct();
            PlanNode source = context.rewrite(node.getSource(), Optional.empty());

            List<Symbol> groupBySymbols = aggregateInfo.getGroupBySymbols();
            // Explicit mutable copy: the duplicate-symbol substitution below edits this list in place.
            List<Symbol> nonDistinctAggregateSymbols = new ArrayList<>(aggregateInfo.getOriginalNonDistinctAggregateArgs());
            // visitAggregation guarantees exactly one distinct argument.
            Symbol distinctSymbol = Iterables.getOnlyElement(aggregateInfo.getOriginalDistinctAggregateArgs());

            // When the distinct argument is also a non-distinct argument (e.g. sum(a) alongside count(DISTINCT a)),
            // the non-distinct side gets a fresh symbol so GroupId can place it in grouping set 0 while the
            // original stays in grouping set 1.
            Symbol duplicatedDistinctSymbol = distinctSymbol;
            if (nonDistinctAggregateSymbols.contains(distinctSymbol)) {
                Symbol newSymbol = symbolAllocator.newSymbol(distinctSymbol.getName(), symbolAllocator.getTypes().get(distinctSymbol));
                nonDistinctAggregateSymbols.set(nonDistinctAggregateSymbols.indexOf(distinctSymbol), newSymbol);
                duplicatedDistinctSymbol = newSymbol;
            }

            Set<Symbol> allSymbols = new HashSet<>();
            allSymbols.addAll(groupBySymbols);
            allSymbols.addAll(nonDistinctAggregateSymbols);
            allSymbols.add(distinctSymbol);

            Symbol groupSymbol = symbolAllocator.newSymbol("group", BigintType.BIGINT);
            GroupIdNode groupIdNode = createGroupIdNode(groupBySymbols, nonDistinctAggregateSymbols, distinctSymbol, duplicatedDistinctSymbol, groupSymbol, allSymbols, source);

            Set<Symbol> groupByKeys = new HashSet<>();
            groupByKeys.addAll(groupBySymbols);
            groupByKeys.add(distinctSymbol);
            groupByKeys.add(groupSymbol);

            ImmutableMap.Builder<Symbol, Symbol> aggregationOutputSymbolsMapBuilder = ImmutableMap.builder();
            AggregationNode aggregationNode = createNonDistinctAggregation(aggregateInfo, distinctSymbol, duplicatedDistinctSymbol, groupByKeys, groupIdNode, node, aggregationOutputSymbolsMapBuilder);
            Map<Symbol, Symbol> aggregationOutputSymbolsMap = aggregationOutputSymbolsMapBuilder.build();
            return createProjectNode(aggregationNode, aggregateInfo, distinctSymbol, groupSymbol, groupBySymbols, aggregationOutputSymbolsMap);
        }

        /**
         * Returns true when every symbol the rewrite will group on has a comparable type.
         * That covers the non-distinct arguments, the distinct arguments, and the mask.
         */
        private boolean checkAllEquatableTypes(AggregateInfo aggregateInfo) {
            for (Symbol symbol : aggregateInfo.getOriginalNonDistinctAggregateArgs()) {
                if (!symbolAllocator.getTypes().get(symbol).isComparable()) {
                    return false;
                }
            }
            // The distinct argument becomes a grouping key of both GroupId and the inner aggregation.
            for (Symbol symbol : aggregateInfo.getOriginalDistinctAggregateArgs()) {
                if (!symbolAllocator.getTypes().get(symbol).isComparable()) {
                    return false;
                }
            }
            return symbolAllocator.getTypes().get(aggregateInfo.getMask()).isComparable();
        }

        /**
         * Builds the Project above the inner aggregation.
         *
         * <p>Side effect: records the new distinct symbol and the new non-distinct symbols on {@code aggregateInfo}.
         *
         * @param source the inner (non-distinct) aggregation
         * @param aggregationOutputSymbolsMap inner aggregation output symbol -> original aggregation output symbol
         * @return a project that:
         *     <ul>
         *       <li>exposes the distinct argument only for group-1 rows;</li>
         *       <li>exposes each non-distinct aggregate result only for group-0 rows;</li>
         *       <li>passes group-by symbols through;</li>
         *       <li>assigns NULL to the original mask symbol.</li>
         *     </ul>
         */
        private ProjectNode createProjectNode(AggregationNode source, AggregateInfo aggregateInfo, Symbol distinctSymbol, Symbol groupSymbol, List<Symbol> groupBySymbols, Map<Symbol, Symbol> aggregationOutputSymbolsMap) {
            Assignments.Builder outputSymbols = Assignments.builder();
            ImmutableMap.Builder<Symbol, Symbol> outputNonDistinctAggregateSymbols = ImmutableMap.builder();
            for (Symbol symbol : source.getOutputSymbols()) {
                if (distinctSymbol.equals(symbol)) {
                    Type type = symbolAllocator.getTypes().get(symbol);
                    Symbol newSymbol = symbolAllocator.newSymbol("expr", type);
                    aggregateInfo.setNewDistinctAggregateSymbol(newSymbol);
                    Expression expression = createIfExpression(groupSymbol.toSymbolReference(), new Cast(new LongLiteral("1"), "bigint"), ComparisonExpressionType.EQUAL, symbol.toSymbolReference(), type);
                    outputSymbols.put(newSymbol, expression);
                }
                else if (aggregationOutputSymbolsMap.containsKey(symbol)) {
                    Type type = symbolAllocator.getTypes().get(symbol);
                    Symbol newSymbol = symbolAllocator.newSymbol("expr", type);
                    // Key: output of an aggregation in the outer AggregationNode; value: symbol it now aggregates over.
                    outputNonDistinctAggregateSymbols.put(aggregationOutputSymbolsMap.get(symbol), newSymbol);
                    Expression expression = createIfExpression(groupSymbol.toSymbolReference(), new Cast(new LongLiteral("0"), "bigint"), ComparisonExpressionType.EQUAL, symbol.toSymbolReference(), type);
                    outputSymbols.put(newSymbol, expression);
                }
                // A symbol may be both a group-by key and an aggregate argument, so this is not an else-branch.
                if (groupBySymbols.contains(symbol)) {
                    outputSymbols.put(symbol, symbol.toSymbolReference());
                }
            }
            // The mask is no longer produced by a MarkDistinct; keep the symbol defined, always NULL.
            outputSymbols.put(aggregateInfo.getMask(), new NullLiteral());
            aggregateInfo.setNewNonDistinctAggregateSymbols(outputNonDistinctAggregateSymbols.build());
            return new ProjectNode(idAllocator.getNextId(), source, outputSymbols.build());
        }

        /**
         * Builds a GroupId node with two grouping sets and {@code groupSymbol} as its group-id output.
         * <ul>
         *   <li>set 0: group-by symbols + non-distinct aggregate arguments;</li>
         *   <li>set 1: group-by symbols + the distinct argument.</li>
         * </ul>
         *
         * <p>Each output symbol maps to itself, except the duplicated distinct symbol. That one is freshly
         * allocated and so absent from {@code source}; it maps to the original distinct symbol.
         */
        private GroupIdNode createGroupIdNode(List<Symbol> groupBySymbols, List<Symbol> nonDistinctAggregateSymbols, Symbol distinctSymbol, Symbol duplicatedDistinctSymbol, Symbol groupSymbol, Set<Symbol> allSymbols, PlanNode source) {
            List<List<Symbol>> groups = new ArrayList<>();

            Set<Symbol> group0 = new HashSet<>();
            group0.addAll(groupBySymbols);
            group0.addAll(nonDistinctAggregateSymbols);
            groups.add(ImmutableList.copyOf(group0));

            Set<Symbol> group1 = new HashSet<>();
            group1.addAll(groupBySymbols);
            group1.add(distinctSymbol);
            groups.add(ImmutableList.copyOf(group1));

            Map<Symbol, Symbol> symbolMapping = allSymbols.stream()
                    .collect(Collectors.toMap(symbol -> symbol, symbol -> symbol.equals(duplicatedDistinctSymbol) ? distinctSymbol : symbol));
            return new GroupIdNode(idAllocator.getNextId(), source, groups, symbolMapping, ImmutableMap.<Symbol, Symbol>of(), groupSymbol);
        }

        /**
         * Builds the SINGLE-step aggregation over {@code groupIdNode}. It computes every non-DISTINCT aggregate,
         * grouped by {@code groupByKeys}.
         *
         * <p>Each result gets a fresh output symbol, recorded in {@code aggregationOutputSymbolsMapBuilder} as
         * new symbol -> original output symbol. When the distinct symbol was duplicated, references to it in
         * the non-distinct calls are redirected to the duplicate.
         */
        private AggregationNode createNonDistinctAggregation(AggregateInfo aggregateInfo, Symbol distinctSymbol, Symbol duplicatedDistinctSymbol, Set<Symbol> groupByKeys, GroupIdNode groupIdNode, MarkDistinctNode originalNode, ImmutableMap.Builder<Symbol, Symbol> aggregationOutputSymbolsMapBuilder) {
            boolean distinctWasDuplicated = !duplicatedDistinctSymbol.equals(distinctSymbol);
            SymbolReference distinctReference = distinctSymbol.toSymbolReference();
            ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregateInfo.getAggregations().entrySet()) {
                FunctionCall functionCall = entry.getValue().getCall();
                if (functionCall.isDistinct()) {
                    continue;
                }
                Symbol newSymbol = symbolAllocator.newSymbol(entry.getKey().toSymbolReference(), symbolAllocator.getTypes().get(entry.getKey()));
                aggregationOutputSymbolsMapBuilder.put(newSymbol, entry.getKey());
                if (distinctWasDuplicated && functionCall.getArguments().contains(distinctReference)) {
                    List<Expression> arguments = functionCall.getArguments().stream()
                            .map(argument -> distinctReference.equals(argument) ? duplicatedDistinctSymbol.toSymbolReference() : argument)
                            .collect(ImmutableList.toImmutableList());
                    functionCall = new FunctionCall(functionCall.getName(), functionCall.getWindow(), false, arguments);
                }
                aggregations.put(newSymbol, new AggregationNode.Aggregation(functionCall, entry.getValue().getSignature(), Optional.empty()));
            }
            return new AggregationNode(idAllocator.getNextId(), groupIdNode, aggregations.build(), ImmutableList.<List<Symbol>>of(ImmutableList.copyOf(groupByKeys)), AggregationNode.Step.SINGLE, originalNode.getHashSymbol(), Optional.empty());
        }

        /** Resolves the signature of {@code functionName} applied to a single argument of {@code argument}'s type. */
        private Signature getFunctionSignature(QualifiedName functionName, Symbol argument) {
            Type argumentType = symbolAllocator.getTypes().get(argument);
            List<TypeSignatureProvider> argumentTypes = ImmutableList.of(new TypeSignatureProvider(argumentType.getTypeSignature()));
            return metadata.getFunctionRegistry().resolveFunction(functionName, argumentTypes);
        }

        /** Builds {@code IF(left <type> right, result, CAST(NULL AS trueValueType))}. */
        private static IfExpression createIfExpression(Expression left, Expression right, ComparisonExpressionType type, Expression result, Type trueValueType) {
            Expression condition = new ComparisonExpression(type, left, right);
            Expression nullOfResultType = new Cast(new NullLiteral(), trueValueType.getTypeSignature().toString());
            return new IfExpression(condition, result, nullOfResultType);
        }
    }
}

