package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Pattern;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.optimizations.Predicates;
import com.facebook.presto.sql.planner.optimizations.ScalarAggregationToJoinRewriter;
import com.facebook.presto.sql.planner.optimizations.ScalarQueryUtil;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.LateralJoinNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.class */
/**
 * Iterative optimizer rule that tries to rewrite a correlated {@link LateralJoinNode} whose subquery
 * is a scalar aggregation (an {@link AggregationNode} with no grouping keys) into a join-based plan.
 *
 * <p>The plan rewrite itself is delegated to {@link ScalarAggregationToJoinRewriter}. This rule only
 * decides whether the rewrite applies, and whether the rewriter actually produced a different plan.
 */
public class TransformCorrelatedScalarAggregationToJoin implements Rule {
    private static final Pattern PATTERN = Pattern.node(LateralJoinNode.class);

    private final FunctionRegistry functionRegistry;

    /**
     * Creates the rule.
     *
     * @param functionRegistry registry passed to {@link ScalarAggregationToJoinRewriter}; must not be null
     * @throws NullPointerException if {@code functionRegistry} is null
     */
    public TransformCorrelatedScalarAggregationToJoin(FunctionRegistry functionRegistry) {
        this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry is null");
    }

    @Override
    public Pattern getPattern() {
        return PATTERN;
    }

    /**
     * Attempts the scalar-aggregation-to-join rewrite on {@code planNode}.
     *
     * <p>The rewrite is attempted only when all of the following hold:
     * <ul>
     *   <li>the node is a {@link LateralJoinNode} with a non-empty correlation;</li>
     *   <li>its resolved subquery is classified as scalar by {@link ScalarQueryUtil#isScalar};</li>
     *   <li>an {@link AggregationNode} with no grouping keys is found in the subquery.</li>
     * </ul>
     *
     * @param planNode the candidate node; anything other than a {@link LateralJoinNode} is ignored
     * @param lookup used to resolve the subquery and to search it
     * @param planNodeIdAllocator forwarded to the rewriter for new plan nodes
     * @param symbolAllocator forwarded to the rewriter for new symbols
     * @param session unused by this rule
     * @return the rewritten plan, or {@link Optional#empty()} if the rule does not apply
     *         or the rewriter handed back a {@link LateralJoinNode}
     */
    @Override
    public Optional<PlanNode> apply(PlanNode planNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, SymbolAllocator symbolAllocator, Session session) {
        if (!(planNode instanceof LateralJoinNode)) {
            return Optional.empty();
        }
        LateralJoinNode lateralJoinNode = (LateralJoinNode) planNode;

        PlanNode subquery = lookup.resolve(lateralJoinNode.getSubquery());
        // Uncorrelated subqueries, and subqueries not classified as scalar, are out of scope for this rule.
        if (lateralJoinNode.getCorrelation().isEmpty() || !ScalarQueryUtil.isScalar(subquery, lookup)) {
            return Optional.empty();
        }

        // Only an aggregation without grouping keys is treated as a scalar aggregation here.
        Optional<AggregationNode> aggregation = findAggregation(subquery, lookup)
                .filter(candidate -> candidate.getGroupingKeys().isEmpty());
        if (!aggregation.isPresent()) {
            return Optional.empty();
        }

        ScalarAggregationToJoinRewriter rewriter = new ScalarAggregationToJoinRewriter(
                functionRegistry, symbolAllocator, planNodeIdAllocator, lookup);
        PlanNode rewritten = rewriter.rewriteScalarAggregation(lateralJoinNode, aggregation.get());

        // A LateralJoinNode result is treated as "no rewrite happened".
        // NOTE(review): presumably the rewriter returns the lateral join when it cannot rewrite -- confirm.
        if (rewritten instanceof LateralJoinNode) {
            return Optional.empty();
        }
        return Optional.of(rewritten);
    }

    /**
     * Returns the first {@link AggregationNode} found when searching from {@code root}.
     *
     * <p>The search is configured with {@code skipOnlyWhen(ProjectNode, EnforceSingleRowNode)}.
     * Presumably it only descends through those node types -- confirm against {@link PlanNodeSearcher}.
     *
     * @param root node to start searching from (the resolved subquery)
     * @param lookup used by the searcher to resolve child references
     * @return the matching aggregation, or {@link Optional#empty()} if none is found
     */
    private static Optional<AggregationNode> findAggregation(PlanNode root, Lookup lookup) {
        // The decompiled form referenced an undefined variable; the intended predicate is a bound
        // method reference on AggregationNode.class.
        return PlanNodeSearcher.searchFrom(root, lookup)
                .where(AggregationNode.class::isInstance)
                .skipOnlyWhen(Predicates.isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class))
                .findFirst();
    }
}
