package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.google.common.collect.ImmutableMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/ImplementFilteredAggregations.class */
/**
 * Plan optimizer that turns aggregation calls carrying a {@code FILTER (WHERE ...)} clause
 * into masked aggregation calls.
 *
 * <p>For each filtered call, the filter expression is evaluated in a new {@link ProjectNode}
 * placed directly below the {@link AggregationNode}. The expression's boolean result is bound to
 * a freshly allocated symbol, which is registered as the call's mask. The filter is then removed
 * from the {@link FunctionCall}. All source symbols are passed through that projection unchanged.
 */
public class ImplementFilteredAggregations implements PlanOptimizer {

    private static class Optimizer extends SimplePlanRewriter<Optional<Symbol>> {
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;

        private Optimizer(PlanNodeIdAllocator planNodeIdAllocator, SymbolAllocator symbolAllocator) {
            this.idAllocator = Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        /**
         * Rewrites {@code aggregationNode} so that filtered calls are expressed through masks.
         *
         * <p>If the node has no filtered call, or a filtered call already has a mask, the node is
         * left as-is and only its children are rewritten.
         *
         * @param aggregationNode the aggregation to rewrite
         * @param rewriteContext rewriter context used to descend into child plan nodes
         * @return the rewritten plan fragment
         */
        @Override
        public PlanNode visitAggregation(AggregationNode aggregationNode, SimplePlanRewriter.RewriteContext<Optional<Symbol>> rewriteContext) {
            if (!canImplementFilters(aggregationNode)) {
                return rewriteContext.defaultRewrite(aggregationNode, Optional.empty());
            }

            // Rewrite the subtree first: returning a freshly built node below would otherwise stop
            // the traversal here, leaving filtered aggregations nested in the source untouched.
            PlanNode source = rewriteContext.rewrite(aggregationNode.getSource(), Optional.empty());

            ImmutableMap.Builder<Symbol, Expression> projections = ImmutableMap.builder();
            ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.<Symbol, Symbol>builder()
                    .putAll(aggregationNode.getMasks());
            ImmutableMap.Builder<Symbol, FunctionCall> calls = ImmutableMap.builder();

            for (Map.Entry<Symbol, FunctionCall> entry : aggregationNode.getAggregations().entrySet()) {
                Symbol output = entry.getKey();
                FunctionCall call = entry.getValue();
                Optional<Expression> filter = call.getFilter();
                if (filter.isPresent()) {
                    // Evaluate the filter once in the projection and use its result as the mask.
                    Symbol maskSymbol = this.symbolAllocator.newSymbol(filter.get(), BooleanType.BOOLEAN);
                    projections.put(maskSymbol, filter.get());
                    masks.put(output, maskSymbol);
                }
                // Same call with the filter stripped; masking now carries that semantics.
                calls.put(output, new FunctionCall(call.getName(), call.getWindow(), Optional.empty(), call.isDistinct(), call.getArguments()));
            }

            // Identity projection so every input of the aggregation remains available.
            for (Symbol symbol : source.getOutputSymbols()) {
                projections.put(symbol, symbol.toSymbolReference());
            }

            ProjectNode projectNode = new ProjectNode(this.idAllocator.getNextId(), source, projections.build());
            return new AggregationNode(
                    this.idAllocator.getNextId(),
                    projectNode,
                    calls.build(),
                    aggregationNode.getFunctions(),
                    masks.build(),
                    aggregationNode.getGroupingSets(),
                    aggregationNode.getStep(),
                    aggregationNode.getHashSymbol(),
                    aggregationNode.getGroupIdSymbol());
        }

        /**
         * Returns {@code true} when at least one call has a filter and no filtered call already has a mask.
         *
         * <p>A filtered call that already has a mask would need a second mask entry for the same
         * output symbol, which the single-mask-per-call map cannot hold. Such nodes are skipped
         * conservatively.
         */
        private static boolean canImplementFilters(AggregationNode aggregationNode) {
            boolean hasFilter = false;
            for (Map.Entry<Symbol, FunctionCall> entry : aggregationNode.getAggregations().entrySet()) {
                if (entry.getValue().getFilter().isPresent()) {
                    if (aggregationNode.getMasks().containsKey(entry.getKey())) {
                        return false;
                    }
                    hasFilter = true;
                }
            }
            return hasFilter;
        }
    }

    /**
     * Applies the filtered-aggregation rewrite to every {@link AggregationNode} in {@code planNode}.
     *
     * @param planNode root of the plan to optimize
     * @param session current session (unused by this optimizer)
     * @param map symbol types (unused by this optimizer)
     * @param symbolAllocator allocator for the new mask symbols
     * @param planNodeIdAllocator allocator for ids of newly created plan nodes
     * @return the rewritten plan
     */
    @Override
    public PlanNode optimize(PlanNode planNode, Session session, Map<Symbol, Type> map, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator) {
        return SimplePlanRewriter.rewriteWith(new Optimizer(planNodeIdAllocator, symbolAllocator), planNode, Optional.empty());
    }
}
