/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.google.common.collect.ImmutableMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Plan optimizer that rewrites aggregations carrying a {@code FILTER} clause into masked
 * aggregations.
 * <p>
 * For each aggregation call with a filter, the filter expression is projected into a newly
 * allocated {@code BOOLEAN} symbol by a {@link ProjectNode} inserted directly below the
 * {@link AggregationNode}. That symbol is registered as the call's mask, and the filter is
 * stripped from the call. All source output symbols are passed through the projection unchanged.
 */
public class ImplementFilteredAggregations
        implements PlanOptimizer {
    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        return SimplePlanRewriter.rewriteWith(new Optimizer(idAllocator, symbolAllocator), plan, Optional.empty());
    }

    /**
     * Rewriter performing the filter-to-mask transformation on every {@link AggregationNode}.
     * The {@code Optional<Symbol>} rewrite context is never read here; it is always passed as empty.
     */
    private static class Optimizer
            extends SimplePlanRewriter<Optional<Symbol>> {
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;

        private Optimizer(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        /**
         * Replaces the filters of {@code node}'s aggregation calls with masks.
         * <p>
         * The node is left structurally unchanged (only its children are rewritten) when it has no
         * filtered calls, or when any filtered call already has a mask. In the second case a second
         * mask cannot be attached.
         *
         * @param node the aggregation to rewrite
         * @param context rewrite context, used to recurse into the source subtree
         * @return the rewritten plan node
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Optional<Symbol>> context) {
            boolean hasFilters = false;
            for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) {
                if (!entry.getValue().getFilter().isPresent()) {
                    continue;
                }
                if (node.getMasks().containsKey(entry.getKey())) {
                    // Already masked: combining an existing mask with the filter is not supported.
                    return context.defaultRewrite(node, Optional.empty());
                }
                hasFilters = true;
            }
            if (!hasFilters) {
                return context.defaultRewrite(node, Optional.empty());
            }

            // Rewrite the subtree first; otherwise filtered aggregations nested below this one
            // would never be visited, because this method builds a new node instead of calling
            // defaultRewrite.
            PlanNode source = context.rewrite(node.getSource(), Optional.empty());

            ImmutableMap.Builder<Symbol, Expression> newProjections = ImmutableMap.builder();
            // No duplicate keys are possible: filtered outputs that already had a mask returned early above.
            ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.<Symbol, Symbol>builder().putAll(node.getMasks());
            ImmutableMap.Builder<Symbol, FunctionCall> calls = ImmutableMap.builder();
            for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) {
                Symbol output = entry.getKey();
                FunctionCall call = entry.getValue();
                Optional<Expression> filter = call.getFilter();
                if (filter.isPresent()) {
                    // NOTE(review): assumes newSymbol always yields a symbol distinct from the source outputs.
                    Symbol mask = this.symbolAllocator.newSymbol(filter.get(), BooleanType.BOOLEAN);
                    newProjections.put(mask, filter.get());
                    masks.put(output, mask);
                }
                // Same call with the filter removed; the mask now carries the filtering.
                calls.put(output, new FunctionCall(call.getName(), call.getWindow(), Optional.empty(), call.isDistinct(), call.getArguments()));
            }
            // Pass every source column through so the aggregation still sees its inputs.
            for (Symbol symbol : source.getOutputSymbols()) {
                newProjections.put(symbol, symbol.toSymbolReference());
            }

            ProjectNode projection = new ProjectNode(this.idAllocator.getNextId(), source, newProjections.build());
            return new AggregationNode(
                    this.idAllocator.getNextId(),
                    projection,
                    calls.build(),
                    node.getFunctions(),
                    masks.build(),
                    node.getGroupingSets(),
                    node.getStep(),
                    node.getHashSymbol(),
                    node.getGroupIdSymbol());
        }
    }
}

