package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.hive.$internal.org.apache.hadoop.fs.shell.Count;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.optimizations.PlanNodeDecorrelator;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LateralJoinNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.class */
/**
 * Rewrites a {@link LateralJoinNode} whose subquery is a scalar {@link AggregationNode} into a plan
 * built from a LEFT {@link JoinNode} followed by a grouped aggregation.
 * <p>
 * The produced plan has the shape
 * {@code Project(Aggregation(LeftJoin(AssignUniqueId(input), Project(decorrelatedSource, non_null := TRUE))))}.
 * When the filters of the aggregation's source cannot be decorrelated, the original lateral join
 * is returned unchanged.
 */
public class ScalarAggregationToJoinRewriter {
    /** Name of the aggregation whose calls are re-targeted at the {@code non_null} marker column. */
    private static final QualifiedName COUNT = QualifiedName.of("count");

    private final FunctionRegistry functionRegistry;
    private final SymbolAllocator symbolAllocator;
    private final PlanNodeIdAllocator idAllocator;
    private final Lookup lookup;
    private final PlanNodeDecorrelator planNodeDecorrelator;

    /**
     * @param functionRegistry used to resolve the signature of the rewritten {@code count} calls
     * @param symbolAllocator allocates the {@code non_null} and {@code unique} symbols and supplies symbol types
     * @param planNodeIdAllocator allocates ids for every plan node created by this rewriter
     * @param lookup resolves plan node references while searching and decorrelating
     * @throws NullPointerException if any argument is null
     */
    public ScalarAggregationToJoinRewriter(FunctionRegistry functionRegistry, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator, Lookup lookup) {
        this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry is null");
        this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        this.idAllocator = Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
        this.lookup = Objects.requireNonNull(lookup, "lookup is null");
        this.planNodeDecorrelator = new PlanNodeDecorrelator(planNodeIdAllocator, lookup);
    }

    /**
     * Attempts to replace {@code lateralJoinNode} (whose subquery aggregates with {@code aggregationNode})
     * with a LEFT join plus aggregation.
     *
     * @return the rewritten plan, or {@code lateralJoinNode} itself when the aggregation source's
     *         filters cannot be decorrelated
     */
    public PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, AggregationNode aggregationNode) {
        Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelated = planNodeDecorrelator.decorrelateFilters(
                lookup.resolve(aggregationNode.getSource()),
                lateralJoinNode.getCorrelation());
        if (!decorrelated.isPresent()) {
            return lateralJoinNode;
        }
        PlanNode decorrelatedSource = decorrelated.get().getNode();

        // Marker column that is TRUE on every subquery row. After the LEFT join it is null exactly for
        // input rows with no matching subquery row, which lets count() ignore those rows.
        Symbol nonNull = symbolAllocator.newSymbol("non_null", BooleanType.BOOLEAN);
        ProjectNode sourceWithNonNull = new ProjectNode(
                idAllocator.getNextId(),
                decorrelatedSource,
                Assignments.builder()
                        .putIdentities(decorrelatedSource.getOutputSymbols())
                        .put(nonNull, BooleanLiteral.TRUE_LITERAL)
                        .build());

        return rewriteScalarAggregation(
                lateralJoinNode,
                aggregationNode,
                sourceWithNonNull,
                decorrelated.get().getCorrelatedPredicates(),
                nonNull);
    }

    /**
     * Builds {@code Project(Aggregation(LeftJoin(AssignUniqueId(input), scalarAggregationSource)))}.
     *
     * @param scalarAggregation the original aggregation from the lateral subquery
     * @param scalarAggregationSource decorrelated aggregation source, already projecting {@code nonNull}
     * @param joinExpression correlated predicates used as the LEFT join filter
     * @param nonNull the TRUE-valued marker symbol produced by {@code scalarAggregationSource}
     */
    private PlanNode rewriteScalarAggregation(
            LateralJoinNode lateralJoinNode,
            AggregationNode scalarAggregation,
            PlanNode scalarAggregationSource,
            Optional<Expression> joinExpression,
            Symbol nonNull) {
        AssignUniqueId inputWithUniqueId = new AssignUniqueId(
                idAllocator.getNextId(),
                lateralJoinNode.getInput(),
                symbolAllocator.newSymbol("unique", BigintType.BIGINT));

        JoinNode leftOuterJoin = new JoinNode(
                idAllocator.getNextId(),
                JoinNode.Type.LEFT,
                inputWithUniqueId,
                scalarAggregationSource,
                ImmutableList.of(),
                ImmutableList.<Symbol>builder()
                        .addAll(inputWithUniqueId.getOutputSymbols())
                        .addAll(scalarAggregationSource.getOutputSymbols())
                        .build(),
                joinExpression,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());

        Optional<AggregationNode> aggregation = createAggregationNode(scalarAggregation, leftOuterJoin, nonNull);
        if (!aggregation.isPresent()) {
            return lateralJoinNode;
        }

        // First ProjectNode in the subquery, searching only through EnforceSingleRowNode parents;
        // its assignments are re-applied on top of the new aggregation.
        Optional<ProjectNode> subqueryProjection = PlanNodeSearcher.searchFrom(lateralJoinNode.getSubquery(), lookup)
                .where(ProjectNode.class::isInstance)
                .recurseOnlyWhen(EnforceSingleRowNode.class::isInstance)
                .findFirst();

        List<Symbol> aggregationOutputSymbols = getTruncatedAggregationSymbols(lateralJoinNode, aggregation.get());

        Assignments assignments = subqueryProjection
                .map(projection -> Assignments.builder()
                        .putIdentities(aggregationOutputSymbols)
                        .putAll(projection.getAssignments())
                        .build())
                .orElseGet(() -> Assignments.identity(aggregationOutputSymbols));
        return new ProjectNode(idAllocator.getNextId(), aggregation.get(), assignments);
    }

    /**
     * Returns the aggregation's output symbols that are also outputs of the lateral join, in the
     * aggregation's output order.
     */
    private static List<Symbol> getTruncatedAggregationSymbols(LateralJoinNode lateralJoinNode, AggregationNode aggregationNode) {
        Set<Symbol> lateralJoinOutputs = new HashSet<>(lateralJoinNode.getOutputSymbols());
        return aggregationNode.getOutputSymbols().stream()
                .filter(lateralJoinOutputs::contains)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Re-creates {@code scalarAggregation} on top of {@code leftOuterJoin}, grouped by every symbol of
     * the join's left side. Every {@code count} call is replaced by {@code count(nonNull)}. All other
     * aggregations are copied unchanged.
     *
     * @return the new aggregation (currently always present)
     */
    private Optional<AggregationNode> createAggregationNode(AggregationNode scalarAggregation, JoinNode leftOuterJoin, Symbol nonNull) {
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : scalarAggregation.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (aggregation.getCall().getName().equals(COUNT)) {
                // Count the marker instead of the original arguments so unmatched LEFT-join rows
                // (where the marker is null) contribute nothing.
                FunctionCall countNonNull = new FunctionCall(COUNT, ImmutableList.of(nonNull.toSymbolReference()));
                aggregations.put(entry.getKey(), new AggregationNode.Aggregation(
                        countNonNull,
                        functionRegistry.resolveFunction(
                                COUNT,
                                TypeSignatureProvider.fromTypeSignatures(
                                        ImmutableList.of(symbolAllocator.getTypes().get(nonNull).getTypeSignature()))),
                        aggregation.getMask()));
            } else {
                aggregations.put(entry.getKey(), aggregation);
            }
        }
        return Optional.of(new AggregationNode(
                idAllocator.getNextId(),
                leftOuterJoin,
                aggregations.build(),
                // Single grouping set over the left side (input columns plus the unique id);
                // presumably the unique id makes each input row its own group.
                ImmutableList.of(leftOuterJoin.getLeft().getOutputSymbols()),
                scalarAggregation.getStep(),
                scalarAggregation.getHashSymbol(),
                Optional.empty()));
    }
}
