package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.hive.$internal.org.apache.hadoop.fs.shell.Count;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.DependencyExtractor;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.DefaultTraversalVisitor;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.util.ImmutableCollectors;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/TransformCorrelatedScalarAggregationToJoin.class */
public class TransformCorrelatedScalarAggregationToJoin implements PlanOptimizer {
    private final Metadata metadata;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/TransformCorrelatedScalarAggregationToJoin$DecorrelatedNode.class */
    public static class DecorrelatedNode {
        private final List<Expression> correlatedPredicates;
        private final PlanNode node;

        public DecorrelatedNode(List<Expression> list, PlanNode planNode) {
            Objects.requireNonNull(list, "correlatedPredicates is null");
            this.correlatedPredicates = ImmutableList.copyOf((Collection) list);
            this.node = (PlanNode) Objects.requireNonNull(planNode, "node is null");
        }

        Optional<Expression> getCorrelatedPredicates() {
            return this.correlatedPredicates.isEmpty() ? Optional.empty() : Optional.of(ExpressionUtils.and(this.correlatedPredicates));
        }

        public PlanNode getNode() {
            return this.node;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/TransformCorrelatedScalarAggregationToJoin$ExtendProjectionRewriter.class */
    public static class ExtendProjectionRewriter extends SimplePlanRewriter<PlanNode> {
        private final PlanNodeIdAllocator idAllocator;
        private final Set<Symbol> symbols;

        ExtendProjectionRewriter(PlanNodeIdAllocator planNodeIdAllocator, Set<Symbol> set) {
            this.idAllocator = (PlanNodeIdAllocator) Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
            this.symbols = (Set) Objects.requireNonNull(set, "symbols is null");
        }

        @Override // com.facebook.presto.sql.planner.plan.PlanVisitor
        public PlanNode visitProject(ProjectNode projectNode, SimplePlanRewriter.RewriteContext<PlanNode> rewriteContext) {
            ProjectNode projectNode2 = (ProjectNode) rewriteContext.defaultRewrite(projectNode, rewriteContext.get());
            Stream<Symbol> stream = this.symbols.stream();
            List<Symbol> outputSymbols = projectNode2.getSource().getOutputSymbols();
            outputSymbols.getClass();
            return new ProjectNode(this.idAllocator.getNextId(), projectNode2.getSource(), ImmutableMap.builder().putAll(projectNode2.getAssignments()).putAll(TransformCorrelatedScalarAggregationToJoin.toAssignments((List) stream.filter((v1) -> {
                return r1.contains(v1);
            }).filter(symbol -> {
                return !projectNode2.getOutputSymbols().contains(symbol);
            }).collect(ImmutableCollectors.toImmutableList()))).build());
        }
    }

    /* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/TransformCorrelatedScalarAggregationToJoin$Rewriter.class */
    private static class Rewriter extends SimplePlanRewriter<PlanNode> {
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;
        private final Metadata metadata;

        public Rewriter(PlanNodeIdAllocator planNodeIdAllocator, SymbolAllocator symbolAllocator, Metadata metadata) {
            this.idAllocator = (PlanNodeIdAllocator) Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
            this.symbolAllocator = (SymbolAllocator) Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.metadata = (Metadata) Objects.requireNonNull(metadata, "metadata is null");
        }

        @Override // com.facebook.presto.sql.planner.plan.PlanVisitor
        public PlanNode visitApply(ApplyNode applyNode, SimplePlanRewriter.RewriteContext<PlanNode> rewriteContext) {
            ApplyNode applyNode2 = (ApplyNode) rewriteContext.defaultRewrite(applyNode, rewriteContext.get());
            if (!applyNode2.getCorrelation().isEmpty() && applyNode2.isResolvedScalarSubquery()) {
                PlanNodeSearcher searchFrom = PlanNodeSearcher.searchFrom(applyNode2.getSubquery());
                Class<AggregationNode> cls = AggregationNode.class;
                AggregationNode.class.getClass();
                Optional findFirst = searchFrom.where((v1) -> {
                    return r1.isInstance(v1);
                }).skipOnlyWhen(Predicates.isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class)).findFirst();
                if (findFirst.isPresent() && ((AggregationNode) findFirst.get()).getGroupingKeys().isEmpty()) {
                    return rewriteScalarAggregation(applyNode2, (AggregationNode) findFirst.get());
                }
            }
            return applyNode2;
        }

        private PlanNode rewriteScalarAggregation(ApplyNode applyNode, AggregationNode aggregationNode) {
            Optional<DecorrelatedNode> decorrelateFilters = decorrelateFilters(aggregationNode.getSource(), applyNode.getCorrelation());
            if (!decorrelateFilters.isPresent()) {
                return applyNode;
            }
            Symbol newSymbol = this.symbolAllocator.newSymbol("non_null", BooleanType.BOOLEAN);
            return rewriteScalarAggregation(applyNode, aggregationNode, new ProjectNode(this.idAllocator.getNextId(), decorrelateFilters.get().getNode(), ImmutableMap.builder().putAll(TransformCorrelatedScalarAggregationToJoin.toAssignments(decorrelateFilters.get().getNode().getOutputSymbols())).put(newSymbol, BooleanLiteral.TRUE_LITERAL).build()), decorrelateFilters.get().getCorrelatedPredicates(), newSymbol);
        }

        private PlanNode rewriteScalarAggregation(ApplyNode applyNode, AggregationNode aggregationNode, PlanNode planNode, Optional<Expression> optional, Symbol symbol) {
            Optional<AggregationNode> createAggregationNode = createAggregationNode(aggregationNode, new JoinNode(this.idAllocator.getNextId(), JoinNode.Type.LEFT, new AssignUniqueId(this.idAllocator.getNextId(), applyNode.getInput(), this.symbolAllocator.newSymbol("unique", BigintType.BIGINT)), planNode, ImmutableList.of(), optional, Optional.empty(), Optional.empty()), symbol);
            if (!createAggregationNode.isPresent()) {
                return applyNode;
            }
            PlanNodeSearcher searchFrom = PlanNodeSearcher.searchFrom(applyNode.getSubquery());
            Class<ProjectNode> cls = ProjectNode.class;
            ProjectNode.class.getClass();
            Optional findFirst = searchFrom.where((v1) -> {
                return r1.isInstance(v1);
            }).findFirst();
            if (!findFirst.isPresent()) {
                return createAggregationNode.get();
            }
            return new ProjectNode(this.idAllocator.getNextId(), createAggregationNode.get(), ImmutableMap.builder().putAll(TransformCorrelatedScalarAggregationToJoin.toAssignments(createAggregationNode.get().getOutputSymbols())).putAll(((ProjectNode) findFirst.get()).getAssignments()).build());
        }

        private Optional<AggregationNode> createAggregationNode(AggregationNode aggregationNode, JoinNode joinNode, Symbol symbol) {
            ImmutableMap.Builder builder = ImmutableMap.builder();
            ImmutableMap.Builder builder2 = ImmutableMap.builder();
            FunctionRegistry functionRegistry = this.metadata.getFunctionRegistry();
            for (Map.Entry<Symbol, FunctionCall> entry : aggregationNode.getAggregations().entrySet()) {
                FunctionCall value = entry.getValue();
                QualifiedName of = QualifiedName.of(Count.NAME);
                Symbol key = entry.getKey();
                if (value.getName().equals(of)) {
                    builder.put(key, new FunctionCall(of, ImmutableList.of(symbol.toSymbolReference())));
                    builder2.put(key, functionRegistry.resolveFunction(of, TypeSignatureProvider.fromTypeSignatures(ImmutableList.of(this.symbolAllocator.getTypes().get(symbol).getTypeSignature()))));
                } else {
                    builder.put(key, entry.getValue());
                    builder2.put(key, aggregationNode.getFunctions().get(key));
                }
            }
            return Optional.of(new AggregationNode(this.idAllocator.getNextId(), joinNode, builder.build(), builder2.build(), aggregationNode.getMasks(), ImmutableList.of(joinNode.getLeft().getOutputSymbols()), aggregationNode.getStep(), aggregationNode.getHashSymbol(), Optional.empty()));
        }

        private Optional<DecorrelatedNode> decorrelateFilters(PlanNode planNode, List<Symbol> list) {
            PlanNodeSearcher searchFrom = PlanNodeSearcher.searchFrom(planNode);
            Class<FilterNode> cls = FilterNode.class;
            FilterNode.class.getClass();
            PlanNodeSearcher skipOnlyWhen = searchFrom.where((v1) -> {
                return r1.isInstance(v1);
            }).skipOnlyWhen(Predicates.isInstanceOfAny(ProjectNode.class, LimitNode.class));
            List findAll = skipOnlyWhen.findAll();
            if (findAll.isEmpty()) {
                return decorrelatedNode(ImmutableList.of(), planNode, list);
            }
            if (findAll.size() > 1) {
                return Optional.empty();
            }
            Expression predicate = ((FilterNode) findAll.get(0)).getPredicate();
            if (isSupportedPredicate(predicate) && DependencyExtractor.extractUnique(predicate).containsAll(list)) {
                Map map = (Map) ExpressionUtils.extractConjuncts(predicate).stream().collect(Collectors.partitioningBy(isUsingPredicate(list)));
                ImmutableList copyOf = ImmutableList.copyOf((Collection) map.get(true));
                PlanNode updateFilterNode = updateFilterNode(skipOnlyWhen, ImmutableList.copyOf((Collection) map.get(false)));
                if (!copyOf.isEmpty()) {
                    updateFilterNode = removeLimitNode(updateFilterNode);
                }
                return decorrelatedNode(copyOf, ensureJoinSymbolsAreReturned(updateFilterNode, copyOf), list);
            }
            return Optional.empty();
        }

        private static Optional<DecorrelatedNode> decorrelatedNode(List<Expression> list, PlanNode planNode, List<Symbol> list2) {
            Stream<Symbol> stream = DependencyExtractor.extractUnique(planNode).stream();
            list2.getClass();
            return stream.anyMatch((v1) -> {
                return r1.contains(v1);
            }) ? Optional.empty() : Optional.of(new DecorrelatedNode(list, planNode));
        }

        private static Predicate<Expression> isUsingPredicate(List<Symbol> list) {
            return expression -> {
                Stream stream = list.stream();
                Set<Symbol> extractUnique = DependencyExtractor.extractUnique(expression);
                extractUnique.getClass();
                return stream.anyMatch((v1) -> {
                    return r1.contains(v1);
                });
            };
        }

        private PlanNode updateFilterNode(PlanNodeSearcher planNodeSearcher, List<Expression> list) {
            if (list.isEmpty()) {
                return planNodeSearcher.removeAll();
            }
            return planNodeSearcher.replaceAll(new FilterNode(this.idAllocator.getNextId(), ((FilterNode) Iterables.getOnlyElement(planNodeSearcher.findAll())).getSource(), ExpressionUtils.combineConjuncts(list)));
        }

        private static PlanNode removeLimitNode(PlanNode planNode) {
            PlanNodeSearcher searchFrom = PlanNodeSearcher.searchFrom(planNode);
            Class<LimitNode> cls = LimitNode.class;
            LimitNode.class.getClass();
            PlanNodeSearcher where = searchFrom.where((v1) -> {
                return r1.isInstance(v1);
            });
            Class<ProjectNode> cls2 = ProjectNode.class;
            ProjectNode.class.getClass();
            return where.skipOnlyWhen((v1) -> {
                return r1.isInstance(v1);
            }).removeFirst();
        }

        private PlanNode ensureJoinSymbolsAreReturned(PlanNode planNode, List<Expression> list) {
            return rewriteWith(new ExtendProjectionRewriter(this.idAllocator, DependencyExtractor.extractUnique(list)), planNode);
        }

        private static boolean isSupportedPredicate(Expression expression) {
            AtomicBoolean atomicBoolean = new AtomicBoolean(true);
            new DefaultTraversalVisitor<Void, AtomicBoolean>() { // from class: com.facebook.presto.sql.planner.optimizations.TransformCorrelatedScalarAggregationToJoin.Rewriter.1
                /* JADX INFO: Access modifiers changed from: protected */
                @Override // com.facebook.presto.sql.tree.DefaultTraversalVisitor, com.facebook.presto.sql.tree.AstVisitor
                public Void visitLogicalBinaryExpression(LogicalBinaryExpression logicalBinaryExpression, AtomicBoolean atomicBoolean2) {
                    if (logicalBinaryExpression.getType() == LogicalBinaryExpression.Type.AND) {
                        return null;
                    }
                    atomicBoolean2.set(false);
                    return null;
                }
            }.process(expression, atomicBoolean);
            return atomicBoolean.get();
        }
    }

    public TransformCorrelatedScalarAggregationToJoin(Metadata metadata) {
        this.metadata = (Metadata) Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override // com.facebook.presto.sql.planner.optimizations.PlanOptimizer
    public PlanNode optimize(PlanNode planNode, Session session, Map<Symbol, Type> map, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator) {
        return SimplePlanRewriter.rewriteWith(new Rewriter(planNodeIdAllocator, symbolAllocator, this.metadata), planNode, null);
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static Map<Symbol, Expression> toAssignments(Collection<Symbol> collection) {
        return (Map) collection.stream().collect(ImmutableCollectors.toImmutableMap(symbol -> {
            return symbol;
        }, (v0) -> {
            return v0.toSymbolReference();
        }));
    }
}
