package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.DependencyExtractor;
import com.facebook.presto.sql.planner.DeterminismEvaluator;
import com.facebook.presto.sql.planner.EffectivePredicateExtractor;
import com.facebook.presto.sql.planner.EqualityInference;
import com.facebook.presto.sql.planner.ExpressionInterpreter;
import com.facebook.presto.sql.planner.ExpressionSymbolInliner;
import com.facebook.presto.sql.planner.LiteralInterpreter;
import com.facebook.presto.sql.planner.NoOpSymbolResolver;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SampleNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.SortNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.UnnestNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.QualifiedNameReference;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.airlift.log.Logger;
import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/PredicatePushDown.class */
public class PredicatePushDown extends PlanOptimizer {
    // Class-level logger for this optimizer.
    private static final Logger log = Logger.get((Class<?>) PredicatePushDown.class);
    // Collaborators held by the optimizer. Their use is outside this excerpt;
    // NOTE(review): presumably handed to the Rewriter, whose constructor takes both - confirm.
    private final Metadata metadata;
    private final SqlParser sqlParser;

    /* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/PredicatePushDown$Rewriter.class */
    private static class Rewriter extends SimplePlanRewriter<Expression> {
        // Supplies the symbol -> type map and allocates fresh symbols for rewritten join clauses.
        private final SymbolAllocator symbolAllocator;
        // Hands out ids for the Filter/Project nodes this rewriter inserts into the plan.
        private final PlanNodeIdAllocator idAllocator;
        // metadata, sqlParser and session are passed to ExpressionAnalyzer / ExpressionInterpreter
        // when expressions are typed or evaluated (see extractType and simplifyExpression).
        private final Metadata metadata;
        private final SqlParser sqlParser;
        private final Session session;

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/PredicatePushDown$Rewriter$InnerJoinPushDownResult.class */
        public static class InnerJoinPushDownResult {
            private final Expression leftPredicate;
            private final Expression rightPredicate;
            private final Expression joinPredicate;
            private final Expression postJoinPredicate;

            /**
             * Bundles the four predicates produced when an inherited predicate is split around an inner join.
             *
             * @param leftPredicate predicate to push into the left source
             * @param rightPredicate predicate to push into the right source
             * @param joinPredicate predicate to use as the join condition
             * @param postJoinPredicate predicate that must be evaluated above the join
             */
            private InnerJoinPushDownResult(Expression leftPredicate, Expression rightPredicate, Expression joinPredicate, Expression postJoinPredicate) {
                this.leftPredicate = leftPredicate;
                this.rightPredicate = rightPredicate;
                this.joinPredicate = joinPredicate;
                this.postJoinPredicate = postJoinPredicate;
            }

            /** @return the predicate destined for the left join source */
            public Expression getLeftPredicate() {
                return leftPredicate;
            }

            /** @return the predicate destined for the right join source */
            public Expression getRightPredicate() {
                return rightPredicate;
            }

            /** @return the predicate to be used as the join condition */
            public Expression getJoinPredicate() {
                return joinPredicate;
            }

            /** @return the predicate that has to stay above the join */
            public Expression getPostJoinPredicate() {
                return postJoinPredicate;
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/PredicatePushDown$Rewriter$OuterJoinPushDownResult.class */
        public static class OuterJoinPushDownResult {
            private final Expression outerJoinPredicate;
            private final Expression innerJoinPredicate;
            private final Expression postJoinPredicate;

            /**
             * Bundles the three predicates produced when an inherited predicate is split around a LEFT/RIGHT join.
             *
             * @param outerJoinPredicate predicate to push into the outer (row-preserving) source
             * @param innerJoinPredicate predicate to push into the inner (null-producing) source
             * @param postJoinPredicate predicate that must be evaluated above the join
             */
            private OuterJoinPushDownResult(Expression outerJoinPredicate, Expression innerJoinPredicate, Expression postJoinPredicate) {
                this.outerJoinPredicate = outerJoinPredicate;
                this.innerJoinPredicate = innerJoinPredicate;
                this.postJoinPredicate = postJoinPredicate;
            }

            /** @return the predicate destined for the outer side of the join */
            public Expression getOuterJoinPredicate() {
                return outerJoinPredicate;
            }

            /** @return the predicate destined for the inner side of the join */
            public Expression getInnerJoinPredicate() {
                return innerJoinPredicate;
            }

            /** @return the predicate that has to stay above the join */
            public Expression getPostJoinPredicate() {
                return postJoinPredicate;
            }
        }

        /**
         * Creates a rewriter; every collaborator is mandatory.
         *
         * @throws NullPointerException if any argument is null
         */
        private Rewriter(SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator, Metadata metadata, SqlParser sqlParser, Session session) {
            // requireNonNull already returns its argument's type, so no cast is needed.
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.idAllocator = Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.sqlParser = Objects.requireNonNull(sqlParser, "sqlParser is null");
            this.session = Objects.requireNonNull(session, "session is null");
        }

        /**
         * Fallback for node types without a dedicated visitor: the node is default-rewritten with a
         * {@code TRUE} context (nothing is pushed through it), and the inherited predicate, unless it
         * is itself {@code TRUE}, is applied in a new {@link FilterNode} on top of the result.
         */
        @Override
        public PlanNode visitPlan(PlanNode planNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            PlanNode rewrittenNode = rewriteContext.defaultRewrite(planNode, BooleanLiteral.TRUE_LITERAL);
            Expression inheritedPredicate = rewriteContext.get();
            if (inheritedPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
                // Nothing to enforce above this node.
                return rewrittenNode;
            }
            return new FilterNode(this.idAllocator.getNextId(), rewrittenNode, inheritedPredicate);
        }

        /**
         * Pushes the inherited predicate through an exchange into each of its sources.
         *
         * <p>For every source, the predicate is re-expressed in that source's input symbols by
         * substituting each exchange output symbol with the input symbol at the same position.
         * A new {@link ExchangeNode} is created only if at least one source was actually rewritten;
         * otherwise the original node is returned unchanged.
         */
        @Override
        public PlanNode visitExchange(ExchangeNode exchangeNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            boolean modified = false;
            ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
            for (int sourceIndex = 0; sourceIndex < exchangeNode.getSources().size(); sourceIndex++) {
                // Positional mapping: output symbol k corresponds to input symbol k of this source.
                Map<Symbol, Expression> outputsToInputs = new HashMap<>();
                for (int column = 0; column < exchangeNode.getInputs().get(sourceIndex).size(); column++) {
                    outputsToInputs.put(
                            exchangeNode.getOutputSymbols().get(column),
                            exchangeNode.getInputs().get(sourceIndex).get(column).toQualifiedNameReference());
                }

                Expression sourcePredicate = ExpressionTreeRewriter.rewriteWith(new ExpressionSymbolInliner(outputsToInputs), rewriteContext.get());
                PlanNode source = exchangeNode.getSources().get(sourceIndex);
                PlanNode rewrittenSource = rewriteContext.rewrite(source, sourcePredicate);
                if (rewrittenSource != source) {
                    modified = true;
                }
                // The decompiler emitted a cast of the element to ImmutableList.Builder here; the element itself is what must be added.
                rewrittenSources.add(rewrittenSource);
            }

            if (!modified) {
                return exchangeNode;
            }
            return new ExchangeNode(exchangeNode.getId(), exchangeNode.getType(), exchangeNode.getPartitionFunction(), rewrittenSources.build(), exchangeNode.getInputs());
        }

        /**
         * Pushes the inherited predicate below a projection where that is safe.
         *
         * <p>Each conjunct of the inherited predicate is classified by whether every symbol it
         * references is produced by a deterministic assignment of this projection. Conjuncts that
         * qualify are rewritten in terms of the projection's input expressions (symbol inlining) and
         * pushed to the source; the remaining conjuncts are kept in a {@link FilterNode} above the
         * rewritten projection.
         */
        @Override
        public PlanNode visitProject(ProjectNode projectNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Set<Symbol> deterministicSymbols = projectNode.getAssignments().entrySet().stream()
                    .filter(entry -> DeterminismEvaluator.isDeterministic(entry.getValue()))
                    .map(Map.Entry::getKey)
                    .collect(Collectors.toSet());

            // true  -> conjunct only touches deterministic outputs and can be inlined + pushed down
            // false -> conjunct must stay above the projection
            // The decompiled lambda referenced an undefined variable; the intended test is set membership.
            Map<Boolean, List<Expression>> conjuncts = ExpressionUtils.extractConjuncts(rewriteContext.get()).stream()
                    .collect(Collectors.partitioningBy(
                            conjunct -> DependencyExtractor.extractAll(conjunct).stream().allMatch(deterministicSymbols::contains)));

            Expression pushDownPredicate = ExpressionTreeRewriter.rewriteWith(
                    new ExpressionSymbolInliner(projectNode.getAssignments()),
                    ExpressionUtils.combineConjuncts(conjuncts.get(true)));
            PlanNode rewrittenNode = rewriteContext.defaultRewrite(projectNode, pushDownPredicate);

            if (!conjuncts.get(false).isEmpty()) {
                rewrittenNode = new FilterNode(this.idAllocator.getNextId(), rewrittenNode, ExpressionUtils.combineConjuncts(conjuncts.get(false)));
            }
            return rewrittenNode;
        }

        /**
         * Passes the inherited predicate straight through a mark-distinct node.
         *
         * @throws IllegalStateException if the predicate references the node's marker symbol
         */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode markDistinctNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            boolean dependsOnMarker = DependencyExtractor.extractUnique(inheritedPredicate).contains(markDistinctNode.getMarkerSymbol());
            Preconditions.checkState(!dependsOnMarker, "predicate depends on marker symbol");
            return rewriteContext.defaultRewrite(markDistinctNode, inheritedPredicate);
        }

        /** Forwards the inherited predicate unchanged through a sort node via the default rewrite. */
        @Override
        public PlanNode visitSort(SortNode sortNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            return rewriteContext.defaultRewrite(sortNode, inheritedPredicate);
        }

        /**
         * Pushes the inherited predicate into every branch of a union.
         *
         * <p>For each source the predicate is translated into that source's symbols using
         * {@code unionNode.sourceSymbolMap(i)}. A new {@link UnionNode} is built only when at least one
         * source changed; otherwise the original node is returned.
         */
        @Override
        public PlanNode visitUnion(UnionNode unionNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            boolean modified = false;
            ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
            for (int sourceIndex = 0; sourceIndex < unionNode.getSources().size(); sourceIndex++) {
                Expression sourcePredicate = ExpressionTreeRewriter.rewriteWith(
                        new ExpressionSymbolInliner(unionNode.sourceSymbolMap(sourceIndex)),
                        rewriteContext.get());
                PlanNode source = unionNode.getSources().get(sourceIndex);
                PlanNode rewrittenSource = rewriteContext.rewrite(source, sourcePredicate);
                if (rewrittenSource != source) {
                    modified = true;
                }
                // The decompiler emitted a cast of the element to ImmutableList.Builder here; the element itself is what must be added.
                rewrittenSources.add(rewrittenSource);
            }

            if (!modified) {
                return unionNode;
            }
            return new UnionNode(unionNode.getId(), rewrittenSources.build(), unionNode.getSymbolMapping(), unionNode.getOutputSymbols());
        }

        /**
         * Absorbs a filter node: its predicate is AND-ed with the inherited predicate and the
         * combination is pushed into the filter's source. The filter node itself is not re-created here.
         */
        @Override
        public PlanNode visitFilter(FilterNode filterNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression combinedPredicate = ExpressionUtils.combineConjuncts(filterNode.getPredicate(), rewriteContext.get());
            return rewriteContext.rewrite(filterNode.getSource(), combinedPredicate);
        }

        /**
         * Pushes the inherited predicate through a join.
         *
         * <ol>
         *   <li>The join is first normalized: an outer join is narrowed towards an inner join when the
         *       inherited predicate allows it (see {@code tryNormalizeToOuterToInnerJoin}).</li>
         *   <li>Depending on the resulting join type, the inherited predicate is split into a left
         *       predicate, a right predicate, a post-join predicate and (for inner joins only) a new
         *       join predicate. FULL joins push nothing down.</li>
         *   <li>Both sources are rewritten with their share of the predicate.</li>
         *   <li>If the join predicate changed, each simplified equality conjunct is turned into an
         *       equi-join clause over freshly allocated symbols, which are computed by new projections
         *       placed on top of both sources.</li>
         *   <li>Whatever could not be pushed down is applied in a {@link FilterNode} above the join.</li>
         * </ol>
         *
         * @throws UnsupportedOperationException if the join type is none of INNER, LEFT, RIGHT, FULL
         * @throws IllegalStateException if a simplified join conjunct is not a valid join equality
         */
        @Override
        public PlanNode visitJoin(JoinNode joinNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();

            // See if the outer join can be expressed as a more restrictive join type.
            JoinNode node = tryNormalizeToOuterToInnerJoin(joinNode, inheritedPredicate);

            Expression leftEffectivePredicate = EffectivePredicateExtractor.extract(node.getLeft(), this.symbolAllocator.getTypes());
            Expression rightEffectivePredicate = EffectivePredicateExtractor.extract(node.getRight(), this.symbolAllocator.getTypes());
            Expression joinPredicate = extractJoinPredicate(node);

            Expression leftPredicate;
            Expression rightPredicate;
            Expression postJoinPredicate;
            Expression newJoinPredicate;
            switch (node.getType()) {
                case INNER:
                    InnerJoinPushDownResult innerResult = processInnerJoin(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate, node.getLeft().getOutputSymbols());
                    leftPredicate = innerResult.getLeftPredicate();
                    rightPredicate = innerResult.getRightPredicate();
                    postJoinPredicate = innerResult.getPostJoinPredicate();
                    newJoinPredicate = innerResult.getJoinPredicate();
                    break;
                case LEFT:
                    // Left side is the outer (row-preserving) side.
                    OuterJoinPushDownResult leftOuterResult = processLimitedOuterJoin(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate, node.getLeft().getOutputSymbols());
                    leftPredicate = leftOuterResult.getOuterJoinPredicate();
                    rightPredicate = leftOuterResult.getInnerJoinPredicate();
                    postJoinPredicate = leftOuterResult.getPostJoinPredicate();
                    newJoinPredicate = joinPredicate;
                    break;
                case RIGHT:
                    // Right side is the outer (row-preserving) side, so the roles are swapped.
                    OuterJoinPushDownResult rightOuterResult = processLimitedOuterJoin(inheritedPredicate, rightEffectivePredicate, leftEffectivePredicate, joinPredicate, node.getRight().getOutputSymbols());
                    leftPredicate = rightOuterResult.getInnerJoinPredicate();
                    rightPredicate = rightOuterResult.getOuterJoinPredicate();
                    postJoinPredicate = rightOuterResult.getPostJoinPredicate();
                    newJoinPredicate = joinPredicate;
                    break;
                case FULL:
                    // Nothing can be pushed below a full outer join; keep everything above it.
                    leftPredicate = BooleanLiteral.TRUE_LITERAL;
                    rightPredicate = BooleanLiteral.TRUE_LITERAL;
                    postJoinPredicate = inheritedPredicate;
                    newJoinPredicate = joinPredicate;
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported join type: " + node.getType());
            }

            PlanNode leftSource = rewriteContext.rewrite(node.getLeft(), leftPredicate);
            PlanNode rightSource = rewriteContext.rewrite(node.getRight(), rightPredicate);

            PlanNode output = node;
            boolean joinPredicateChanged = !newJoinPredicate.equals(joinPredicate);
            if (leftSource != node.getLeft() || rightSource != node.getRight() || joinPredicateChanged) {
                List<JoinNode.EquiJoinClause> criteria = node.getCriteria();
                // True only when the join predicate was rebuilt and every conjunct simplified to TRUE.
                boolean allJoinConjunctsSimplifiedAway = false;

                if (joinPredicateChanged) {
                    // Start from identity projections for all existing symbols on each side.
                    ImmutableMap.Builder<Symbol, Expression> leftProjections = ImmutableMap.builder();
                    leftProjections.putAll(node.getLeft().getOutputSymbols().stream()
                            .collect(Collectors.toMap(symbol -> symbol, Symbol::toQualifiedNameReference)));
                    ImmutableMap.Builder<Symbol, Expression> rightProjections = ImmutableMap.builder();
                    rightProjections.putAll(node.getRight().getOutputSymbols().stream()
                            .collect(Collectors.toMap(symbol -> symbol, Symbol::toQualifiedNameReference)));

                    // Iterables.transform/filter are lazy views; copy once so simplifyExpression is
                    // not re-run when the result is inspected again after the loop.
                    Iterable<Expression> simplified = Iterables.transform(ExpressionUtils.extractConjuncts(newJoinPredicate), this::simplifyExpression);
                    List<Expression> simplifiedJoinConjuncts = ImmutableList.copyOf(
                            Iterables.filter(simplified, Predicates.not(Predicates.<Expression>equalTo(BooleanLiteral.TRUE_LITERAL))));

                    ImmutableList.Builder<JoinNode.EquiJoinClause> newCriteria = ImmutableList.builder();
                    for (Expression conjunct : simplifiedJoinConjuncts) {
                        Preconditions.checkState(joinEqualityExpression(node.getLeft().getOutputSymbols()).apply(conjunct), "Expected join predicate to be a valid join equality");
                        ComparisonExpression equality = (ComparisonExpression) conjunct;

                        // "Aligned" means the equality's left operand only uses left-source symbols.
                        boolean alignedComparison = Iterables.all(DependencyExtractor.extractUnique(equality.getLeft()), Predicates.in(node.getLeft().getOutputSymbols()));
                        Expression leftExpression = alignedComparison ? equality.getLeft() : equality.getRight();
                        Expression rightExpression = alignedComparison ? equality.getRight() : equality.getLeft();

                        Symbol leftSymbol = this.symbolAllocator.newSymbol(leftExpression, extractType(leftExpression));
                        leftProjections.put(leftSymbol, leftExpression);
                        Symbol rightSymbol = this.symbolAllocator.newSymbol(rightExpression, extractType(rightExpression));
                        rightProjections.put(rightSymbol, rightExpression);

                        // The decompiler emitted a cast of the clause to ImmutableList.Builder here; add the clause itself.
                        newCriteria.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol));
                    }

                    leftSource = new ProjectNode(this.idAllocator.getNextId(), leftSource, leftProjections.build());
                    rightSource = new ProjectNode(this.idAllocator.getNextId(), rightSource, rightProjections.build());
                    criteria = newCriteria.build();
                    allJoinConjunctsSimplifiedAway = simplifiedJoinConjuncts.isEmpty();
                }

                if (allJoinConjunctsSimplifiedAway) {
                    // No join clause survived: emit an INNER join with empty criteria and no hash symbols.
                    output = new JoinNode(node.getId(), JoinNode.Type.INNER, leftSource, rightSource, criteria, Optional.empty(), Optional.empty());
                } else {
                    output = new JoinNode(node.getId(), node.getType(), leftSource, rightSource, criteria, node.getLeftHashSymbol(), node.getRightHashSymbol());
                }
            }

            if (!postJoinPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
                output = new FilterNode(this.idAllocator.getNextId(), output, postJoinPredicate);
            }
            return output;
        }

        /**
         * Splits an inherited predicate around a LEFT or RIGHT outer join.
         *
         * <p>Non-deterministic conjuncts of the inherited predicate are never moved; they go to the
         * post-join predicate. Deterministic conjuncts are pushed to the outer side when they can be
         * rewritten purely in outer symbols, and additionally to the inner side when that outer form can
         * be rewritten purely in non-outer symbols. Conjuncts that cannot be rewritten for the outer side
         * stay above the join.
         *
         * @param expression the inherited predicate arriving from above the join
         * @param expression2 effective predicate of the outer (row-preserving) source; must reference only {@code collection} symbols
         * @param expression3 effective predicate of the inner source; must not reference {@code collection} symbols
         * @param expression4 the join predicate
         * @param collection output symbols of the outer source
         * @return the outer-side, inner-side and post-join predicates
         * @throws IllegalArgumentException if either effective predicate violates its symbol-scope requirement
         */
        private OuterJoinPushDownResult processLimitedOuterJoin(Expression expression, Expression expression2, Expression expression3, Expression expression4, Collection<Symbol> collection) {
            Predicate<Symbol> outerScope = Predicates.in(collection);
            Predicate<Symbol> innerScope = Predicates.not(outerScope);

            Preconditions.checkArgument(Iterables.all(DependencyExtractor.extractUnique(expression2), outerScope), "outerEffectivePredicate must only contain symbols from outerSymbols");
            Preconditions.checkArgument(Iterables.all(DependencyExtractor.extractUnique(expression3), innerScope), "innerEffectivePredicate must not contain symbols from outerSymbols");

            ImmutableList.Builder<Expression> outerPushdownConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> innerPushdownConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> postJoinConjuncts = ImmutableList.builder();

            // Non-deterministic conjuncts must keep their position: evaluate them after the join.
            postJoinConjuncts.addAll(Iterables.filter(ExpressionUtils.extractConjuncts(expression), Predicates.not(DeterminismEvaluator::isDeterministic)));
            Expression inheritedPredicate = ExpressionUtils.stripNonDeterministicConjuncts(expression);
            Expression outerEffectivePredicate = ExpressionUtils.stripNonDeterministicConjuncts(expression2);
            Expression innerEffectivePredicate = ExpressionUtils.stripNonDeterministicConjuncts(expression3);
            Expression joinPredicate = ExpressionUtils.stripNonDeterministicConjuncts(expression4);

            // Equality inferences over progressively larger sets of known predicates.
            EqualityInference inheritedInference = EqualityInference.createEqualityInference(inheritedPredicate);
            EqualityInference outerInference = EqualityInference.createEqualityInference(inheritedPredicate, outerEffectivePredicate);

            EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(outerScope);
            Expression outerOnlyInheritedEqualities = ExpressionUtils.combineConjuncts(equalityPartition.getScopeEqualities());
            EqualityInference potentialNullSymbolInference = EqualityInference.createEqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, innerEffectivePredicate, joinPredicate);
            EqualityInference potentialNullSymbolInferenceWithoutInnerInferred = EqualityInference.createEqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, joinPredicate);

            // Sort through inherited conjuncts that were not consumed by equality inference.
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
                Expression outerRewritten = outerInference.rewriteExpression(conjunct, outerScope);
                if (outerRewritten != null) {
                    outerPushdownConjuncts.add(outerRewritten);

                    // Only a conjunct expressible on the outer side is a candidate for the inner side.
                    Expression innerRewritten = potentialNullSymbolInference.rewriteExpression(outerRewritten, innerScope);
                    if (innerRewritten != null) {
                        innerPushdownConjuncts.add(innerRewritten);
                    }
                } else {
                    postJoinConjuncts.add(conjunct);
                }
            }

            // Outer effective predicates and join predicates may also restrict the inner side.
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(ExpressionUtils.and(outerEffectivePredicate, joinPredicate))) {
                Expression innerRewritten = potentialNullSymbolInference.rewriteExpression(conjunct, innerScope);
                if (innerRewritten != null) {
                    innerPushdownConjuncts.add(innerRewritten);
                }
            }

            // Add the equalities derived by inference back in.
            outerPushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
            postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
            postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
            innerPushdownConjuncts.addAll(potentialNullSymbolInferenceWithoutInnerInferred.generateEqualitiesPartitionedBy(innerScope).getScopeEqualities());

            return new OuterJoinPushDownResult(
                    ExpressionUtils.combineConjuncts(outerPushdownConjuncts.build()),
                    ExpressionUtils.combineConjuncts(innerPushdownConjuncts.build()),
                    ExpressionUtils.combineConjuncts(postJoinConjuncts.build()));
        }

        /**
         * Splits an inherited predicate and a join predicate around an inner join.
         *
         * <p>Non-deterministic conjuncts of both the inherited and the join predicate are collected as
         * join conjuncts without being moved to either source. Deterministic conjuncts are pushed to the
         * left and/or right source whenever equality inference can rewrite them purely in that side's
         * symbols; conjuncts that fit neither side remain join conjuncts. Finally the collected join
         * conjuncts are divided into those accepted by {@code joinEqualityExpression} (the new join
         * predicate) and the rest (the post-join predicate).
         *
         * @param expression the inherited predicate arriving from above the join
         * @param expression2 effective predicate of the left source; must reference only {@code collection} symbols
         * @param expression3 effective predicate of the right source; must not reference {@code collection} symbols
         * @param expression4 the current join predicate
         * @param collection output symbols of the left source
         * @return the left, right, join and post-join predicates
         * @throws IllegalArgumentException if either effective predicate violates its symbol-scope requirement
         */
        private InnerJoinPushDownResult processInnerJoin(Expression expression, Expression expression2, Expression expression3, Expression expression4, Collection<Symbol> collection) {
            Predicate<Symbol> leftScope = Predicates.in(collection);
            Predicate<Symbol> rightScope = Predicates.not(leftScope);

            Preconditions.checkArgument(Iterables.all(DependencyExtractor.extractUnique(expression2), leftScope), "leftEffectivePredicate must only contain symbols from leftSymbols");
            Preconditions.checkArgument(Iterables.all(DependencyExtractor.extractUnique(expression3), rightScope), "rightEffectivePredicate must not contain symbols from leftSymbols");

            ImmutableList.Builder<Expression> leftPushDownConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> rightPushDownConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> joinConjuncts = ImmutableList.builder();

            // Non-deterministic conjuncts are kept with the join; only deterministic ones take part in inference.
            joinConjuncts.addAll(Iterables.filter(ExpressionUtils.extractConjuncts(expression), Predicates.not(DeterminismEvaluator::isDeterministic)));
            Expression inheritedPredicate = ExpressionUtils.stripNonDeterministicConjuncts(expression);
            joinConjuncts.addAll(Iterables.filter(ExpressionUtils.extractConjuncts(expression4), Predicates.not(DeterminismEvaluator::isDeterministic)));
            Expression joinPredicate = ExpressionUtils.stripNonDeterministicConjuncts(expression4);
            Expression leftEffectivePredicate = ExpressionUtils.stripNonDeterministicConjuncts(expression2);
            Expression rightEffectivePredicate = ExpressionUtils.stripNonDeterministicConjuncts(expression3);

            // Full inference, plus variants that omit one side's effective predicate so that a side
            // is not handed equalities it already guarantees itself.
            EqualityInference allInference = EqualityInference.createEqualityInference(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate);
            EqualityInference allInferenceWithoutLeftInferred = EqualityInference.createEqualityInference(inheritedPredicate, rightEffectivePredicate, joinPredicate);
            EqualityInference allInferenceWithoutRightInferred = EqualityInference.createEqualityInference(inheritedPredicate, leftEffectivePredicate, joinPredicate);

            // Inherited conjuncts: push to whichever sides can express them, else keep with the join.
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
                Expression leftRewritten = allInference.rewriteExpression(conjunct, leftScope);
                if (leftRewritten != null) {
                    leftPushDownConjuncts.add(leftRewritten);
                }
                Expression rightRewritten = allInference.rewriteExpression(conjunct, rightScope);
                if (rightRewritten != null) {
                    rightPushDownConjuncts.add(rightRewritten);
                }
                if (leftRewritten == null && rightRewritten == null) {
                    joinConjuncts.add(conjunct);
                }
            }

            // Right effective predicates that can be expressed on the left side.
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(rightEffectivePredicate)) {
                Expression leftRewritten = allInference.rewriteExpression(conjunct, leftScope);
                if (leftRewritten != null) {
                    leftPushDownConjuncts.add(leftRewritten);
                }
            }

            // Left effective predicates that can be expressed on the right side.
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(leftEffectivePredicate)) {
                Expression rightRewritten = allInference.rewriteExpression(conjunct, rightScope);
                if (rightRewritten != null) {
                    rightPushDownConjuncts.add(rightRewritten);
                }
            }

            // Join conjuncts: same treatment as inherited conjuncts.
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(joinPredicate)) {
                Expression leftRewritten = allInference.rewriteExpression(conjunct, leftScope);
                if (leftRewritten != null) {
                    leftPushDownConjuncts.add(leftRewritten);
                }
                Expression rightRewritten = allInference.rewriteExpression(conjunct, rightScope);
                if (rightRewritten != null) {
                    rightPushDownConjuncts.add(rightRewritten);
                }
                if (leftRewritten == null && rightRewritten == null) {
                    joinConjuncts.add(conjunct);
                }
            }

            // Add the equalities derived by inference back in.
            leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(leftScope).getScopeEqualities());
            rightPushDownConjuncts.addAll(allInferenceWithoutRightInferred.generateEqualitiesPartitionedBy(rightScope).getScopeEqualities());
            joinConjuncts.addAll(allInference.generateEqualitiesPartitionedBy(leftScope).getScopeStraddlingEqualities());

            // Conjuncts accepted by joinEqualityExpression become the join predicate; the rest are filtered after the join.
            List<Expression> allJoinConjuncts = joinConjuncts.build();
            List<Expression> equalityJoinConjuncts = ImmutableList.copyOf(Iterables.filter(allJoinConjuncts, joinEqualityExpression(collection)));
            List<Expression> postJoinConjuncts = ImmutableList.copyOf(Iterables.filter(allJoinConjuncts, Predicates.not(joinEqualityExpression(collection))));

            return new InnerJoinPushDownResult(
                    ExpressionUtils.combineConjuncts(leftPushDownConjuncts.build()),
                    ExpressionUtils.combineConjuncts(rightPushDownConjuncts.build()),
                    ExpressionUtils.combineConjuncts(equalityJoinConjuncts),
                    ExpressionUtils.combineConjuncts(postJoinConjuncts));
        }

        /**
         * Builds the join predicate of a join node as the conjunction of one {@code left = right}
         * comparison per equi-join clause.
         *
         * @param joinNode the join whose criteria are converted
         * @return the combined conjunction as produced by {@code ExpressionUtils.combineConjuncts}
         */
        private static Expression extractJoinPredicate(JoinNode joinNode) {
            ImmutableList.Builder<Expression> equalities = ImmutableList.builder();
            for (JoinNode.EquiJoinClause clause : joinNode.getCriteria()) {
                // The decompiler emitted a cast of the expression to ImmutableList.Builder here; add the expression itself.
                equalities.add(equalsExpression(clause.getLeft(), clause.getRight()));
            }
            return ExpressionUtils.combineConjuncts(equalities.build());
        }

        /** Creates the comparison {@code symbol = symbol2} over name references to the two symbols. */
        private static Expression equalsExpression(Symbol symbol, Symbol symbol2) {
            QualifiedNameReference leftReference = new QualifiedNameReference(symbol.toQualifiedName());
            QualifiedNameReference rightReference = new QualifiedNameReference(symbol2.toQualifiedName());
            return new ComparisonExpression(ComparisonExpression.Type.EQUAL, leftReference, rightReference);
        }

        /**
         * Looks up the type of {@code expression} by analyzing it against the current symbol types.
         *
         * @return the analyzer's entry for the expression itself
         */
        private Type extractType(Expression expression) {
            IdentityHashMap<Expression, Type> expressionTypes = ExpressionAnalyzer.getExpressionTypes(
                    this.session,
                    this.metadata,
                    this.sqlParser,
                    this.symbolAllocator.getTypes(),
                    expression);
            return expressionTypes.get(expression);
        }

        /**
         * Narrows an outer join to a more restrictive join type when the inherited predicate allows it.
         *
         * <ul>
         *   <li>INNER joins are returned as-is.</li>
         *   <li>LEFT/RIGHT joins become INNER when {@code canConvertOuterToInner} holds for the symbols
         *       of the inner side (right for LEFT, left for RIGHT).</li>
         *   <li>FULL joins become INNER when it holds for both sides, LEFT when it holds only for the
         *       left symbols, RIGHT when it holds only for the right symbols.</li>
         * </ul>
         * When no narrowing applies the original node instance is returned.
         *
         * @param joinNode the join to normalize
         * @param expression the predicate inherited from above the join
         * @throws IllegalArgumentException if the join type is not INNER, RIGHT, LEFT or FULL
         */
        private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode joinNode, Expression expression) {
            JoinNode.Type joinType = joinNode.getType();
            Preconditions.checkArgument(EnumSet.of(JoinNode.Type.INNER, JoinNode.Type.RIGHT, JoinNode.Type.LEFT, JoinNode.Type.FULL).contains(joinType), "Unsupported join type: %s", joinType);

            if (joinType == JoinNode.Type.INNER) {
                return joinNode;
            }

            JoinNode.Type normalizedType;
            if (joinType == JoinNode.Type.FULL) {
                boolean canConvertToLeftJoin = canConvertOuterToInner(joinNode.getLeft().getOutputSymbols(), expression);
                boolean canConvertToRightJoin = canConvertOuterToInner(joinNode.getRight().getOutputSymbols(), expression);
                if (!canConvertToLeftJoin && !canConvertToRightJoin) {
                    return joinNode;
                }
                if (canConvertToLeftJoin && canConvertToRightJoin) {
                    normalizedType = JoinNode.Type.INNER;
                } else {
                    normalizedType = canConvertToLeftJoin ? JoinNode.Type.LEFT : JoinNode.Type.RIGHT;
                }
            } else {
                // Only LEFT and RIGHT reach this point; test the symbols of the side opposite the outer one.
                List<Symbol> innerSymbols = (joinType == JoinNode.Type.LEFT)
                        ? joinNode.getRight().getOutputSymbols()
                        : joinNode.getLeft().getOutputSymbols();
                if (!canConvertOuterToInner(innerSymbols, expression)) {
                    return joinNode;
                }
                normalizedType = JoinNode.Type.INNER;
            }

            return new JoinNode(joinNode.getId(), normalizedType, joinNode.getLeft(), joinNode.getRight(), joinNode.getCriteria(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol());
        }

        /**
         * Decides whether {@code expression} contains a conjunct that cannot hold when all of the
         * given symbols are null.
         *
         * <p>Each deterministic conjunct is evaluated through {@link #nullInputEvaluator} with the
         * symbols in {@code list} bound to null. The method answers {@code true} as soon as one such
         * evaluation yields {@code null}, a {@link NullLiteral}, or {@code Boolean.FALSE}.
         * Non-deterministic conjuncts are ignored.
         *
         * @param list       symbols to bind to null during evaluation
         * @param expression predicate whose conjuncts are inspected
         * @return {@code true} if at least one deterministic conjunct evaluates to null or false
         */
        private boolean canConvertOuterToInner(List<Symbol> list, Expression expression) {
            Set<Symbol> nullBoundSymbols = ImmutableSet.copyOf(list);
            for (Expression conjunct : ExpressionUtils.extractConjuncts(expression)) {
                if (!DeterminismEvaluator.isDeterministic(conjunct)) {
                    continue;
                }
                Object evaluated = nullInputEvaluator(nullBoundSymbols, conjunct);
                if (evaluated == null || evaluated instanceof NullLiteral || Boolean.FALSE.equals(evaluated)) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Runs the expression optimizer over {@code expression} without resolving any symbols and
         * converts the optimized value back into an {@link Expression}.
         *
         * @param expression expression to simplify
         * @return the optimized result, rendered as an expression of the original expression's type
         */
        private Expression simplifyExpression(Expression expression) {
            IdentityHashMap<Expression, Type> expressionTypes = ExpressionAnalyzer.getExpressionTypes(
                    this.session,
                    this.metadata,
                    this.sqlParser,
                    this.symbolAllocator.getTypes(),
                    expression);
            ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, this.metadata, this.session, expressionTypes);
            Object optimized = optimizer.optimize(NoOpSymbolResolver.INSTANCE);
            return LiteralInterpreter.toExpression(optimized, expressionTypes.get(expression));
        }

        /**
         * Optimizes {@code expression} with every symbol contained in {@code collection} resolved to
         * null; all other symbols are left as references to their qualified names.
         *
         * @param collection symbols to treat as null
         * @param expression expression to evaluate
         * @return whatever the expression optimizer produces under that symbol binding
         */
        private Object nullInputEvaluator(Collection<Symbol> collection, Expression expression) {
            IdentityHashMap<Expression, Type> expressionTypes = ExpressionAnalyzer.getExpressionTypes(
                    this.session,
                    this.metadata,
                    this.sqlParser,
                    this.symbolAllocator.getTypes(),
                    expression);
            ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, this.metadata, this.session, expressionTypes);
            return optimizer.optimize(symbol -> collection.contains(symbol)
                    ? null
                    : new QualifiedNameReference(symbol.toQualifiedName()));
        }

        /**
         * Creates a predicate that recognizes equalities usable as join criteria relative to
         * {@code collection}.
         *
         * <p>An expression matches when it is deterministic, is an {@code EQUAL} comparison, and the
         * symbols of one operand all belong to {@code collection} while the symbols of the other
         * operand all lie outside it (in either orientation).
         *
         * @param collection symbols that define one side of the split
         * @return predicate testing whether an expression is such an equality
         */
        private static Predicate<Expression> joinEqualityExpression(Collection<Symbol> collection) {
            return expression -> {
                if (DeterminismEvaluator.isDeterministic(expression) && expression instanceof ComparisonExpression) {
                    ComparisonExpression comparison = (ComparisonExpression) expression;
                    if (comparison.getType() == ComparisonExpression.Type.EQUAL) {
                        Set<Symbol> leftOperandSymbols = DependencyExtractor.extractUnique(comparison.getLeft());
                        Set<Symbol> rightOperandSymbols = DependencyExtractor.extractUnique(comparison.getRight());
                        // Orientation 1: left operand inside the collection, right operand outside.
                        if (Iterables.all(leftOperandSymbols, Predicates.in(collection)) && Iterables.all(rightOperandSymbols, Predicates.not(Predicates.in(collection)))) {
                            return true;
                        }
                        // Orientation 2: right operand inside the collection, left operand outside.
                        return Iterables.all(rightOperandSymbols, Predicates.in(collection)) && Iterables.all(leftOperandSymbols, Predicates.not(Predicates.in(collection)));
                    }
                }
                return false;
            };
        }

        /**
         * Pushes the inherited predicate through a semi join.
         *
         * <p>Three groups of conjuncts are produced:
         * <ul>
         *   <li>conjuncts for the filtering source, derived from the inherited predicate, the source's
         *       effective predicate and the equality between the two join symbols; each one is passed
         *       through {@code ExpressionUtils.expressionOrNullSymbols} for the filtering join symbol;</li>
         *   <li>conjuncts for the source, i.e. inherited conjuncts that can be rewritten purely in terms
         *       of the source's output symbols;</li>
         *   <li>conjuncts that cannot be rewritten that way and therefore stay in a {@link FilterNode}
         *       above the semi join.</li>
         * </ul>
         *
         * @param semiJoinNode   semi join being rewritten
         * @param rewriteContext context carrying the inherited predicate
         * @return the original node if neither child changed and nothing remains to filter; otherwise a
         *         rebuilt semi join, wrapped in a filter when conjuncts remain
         */
        @Override // com.facebook.presto.sql.planner.plan.PlanVisitor
        public PlanNode visitSemiJoin(SemiJoinNode semiJoinNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            Expression sourceEffectivePredicate = EffectivePredicateExtractor.extract(semiJoinNode.getSource(), this.symbolAllocator.getTypes());

            List<Expression> sourceConjuncts = new ArrayList<>();
            List<Expression> filteringSourceConjuncts = new ArrayList<>();
            List<Expression> postJoinConjuncts = new ArrayList<>();

            // Derive filtering-source conjuncts through the equality of the two join symbols.
            Expression joinPredicate = equalsExpression(semiJoinNode.getSourceJoinSymbol(), semiJoinNode.getFilteringSourceJoinSymbol());
            EqualityInference joinInference = EqualityInference.createEqualityInference(inheritedPredicate, sourceEffectivePredicate, joinPredicate);
            for (Expression conjunct : Iterables.concat(EqualityInference.nonInferrableConjuncts(inheritedPredicate), EqualityInference.nonInferrableConjuncts(sourceEffectivePredicate))) {
                Expression rewritten = joinInference.rewriteExpression(conjunct, Predicates.equalTo(semiJoinNode.getFilteringSourceJoinSymbol()));
                // Only deterministic conjuncts expressible over the filtering join symbol are pushed.
                if (rewritten != null && DeterminismEvaluator.isDeterministic(rewritten)) {
                    filteringSourceConjuncts.add(ExpressionUtils.expressionOrNullSymbols(Predicates.equalTo(semiJoinNode.getFilteringSourceJoinSymbol())).apply(rewritten));
                }
            }
            EqualityInference.EqualityPartition joinPartition = joinInference.generateEqualitiesPartitionedBy(Predicates.equalTo(semiJoinNode.getFilteringSourceJoinSymbol()));
            filteringSourceConjuncts.addAll(ImmutableList.copyOf(Iterables.transform(
                    joinPartition.getScopeEqualities(),
                    ExpressionUtils.expressionOrNullSymbols(Predicates.equalTo(semiJoinNode.getFilteringSourceJoinSymbol())))));

            // Split the inherited predicate between the source and the post-join filter.
            EqualityInference inheritedInference = EqualityInference.createEqualityInference(inheritedPredicate);
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
                Expression rewritten = inheritedInference.rewriteExpression(conjunct, Predicates.in(semiJoinNode.getSource().getOutputSymbols()));
                if (rewritten != null) {
                    sourceConjuncts.add(rewritten);
                }
                else {
                    postJoinConjuncts.add(conjunct);
                }
            }
            EqualityInference.EqualityPartition sourcePartition = inheritedInference.generateEqualitiesPartitionedBy(Predicates.in(semiJoinNode.getSource().getOutputSymbols()));
            sourceConjuncts.addAll(sourcePartition.getScopeEqualities());
            postJoinConjuncts.addAll(sourcePartition.getScopeComplementEqualities());
            postJoinConjuncts.addAll(sourcePartition.getScopeStraddlingEqualities());

            PlanNode rewrittenSource = rewriteContext.rewrite(semiJoinNode.getSource(), ExpressionUtils.combineConjuncts(sourceConjuncts));
            PlanNode rewrittenFilteringSource = rewriteContext.rewrite(semiJoinNode.getFilteringSource(), ExpressionUtils.combineConjuncts(filteringSourceConjuncts));

            PlanNode result = semiJoinNode;
            // Rebuild the node only when a child was actually replaced.
            if (rewrittenSource != semiJoinNode.getSource() || rewrittenFilteringSource != semiJoinNode.getFilteringSource()) {
                result = new SemiJoinNode(semiJoinNode.getId(), rewrittenSource, rewrittenFilteringSource, semiJoinNode.getSourceJoinSymbol(), semiJoinNode.getFilteringSourceJoinSymbol(), semiJoinNode.getSemiJoinOutput(), semiJoinNode.getSourceHashSymbol(), semiJoinNode.getFilteringSourceHashSymbol());
            }
            if (!postJoinConjuncts.isEmpty()) {
                result = new FilterNode(this.idAllocator.getNextId(), result, ExpressionUtils.combineConjuncts(postJoinConjuncts));
            }
            return result;
        }

        /**
         * Pushes the parts of the inherited predicate that can be expressed over the group-by symbols
         * below an aggregation.
         *
         * <p>Aggregations without group-by symbols are delegated to {@code visitPlan}. Otherwise,
         * non-deterministic conjuncts and conjuncts that cannot be rewritten in terms of the group-by
         * symbols stay in a {@link FilterNode} above the aggregation; the rest is handed to the source.
         *
         * @param aggregationNode aggregation being rewritten
         * @param rewriteContext  context carrying the inherited predicate
         * @return the original node if the source is unchanged and nothing remains to filter; otherwise
         *         a rebuilt aggregation, wrapped in a filter when conjuncts remain
         */
        @Override // com.facebook.presto.sql.planner.plan.PlanVisitor
        public PlanNode visitAggregation(AggregationNode aggregationNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            if (aggregationNode.getGroupBy().isEmpty()) {
                return visitPlan((PlanNode) aggregationNode, rewriteContext);
            }

            Expression inheritedPredicate = rewriteContext.get();
            EqualityInference equalityInference = EqualityInference.createEqualityInference(inheritedPredicate);

            List<Expression> pushdownConjuncts = new ArrayList<>();
            List<Expression> postAggregationConjuncts = new ArrayList<>();

            // Non-deterministic conjuncts are never pushed below the aggregation.
            postAggregationConjuncts.addAll(ImmutableList.copyOf(Iterables.filter(ExpressionUtils.extractConjuncts(inheritedPredicate), Predicates.not(DeterminismEvaluator::isDeterministic))));

            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(ExpressionUtils.stripNonDeterministicConjuncts(inheritedPredicate))) {
                Expression rewritten = equalityInference.rewriteExpression(conjunct, Predicates.in(aggregationNode.getGroupBy()));
                if (rewritten != null) {
                    pushdownConjuncts.add(rewritten);
                }
                else {
                    postAggregationConjuncts.add(conjunct);
                }
            }

            EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(Predicates.in(aggregationNode.getGroupBy()));
            pushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
            postAggregationConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
            postAggregationConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

            PlanNode rewrittenSource = rewriteContext.rewrite(aggregationNode.getSource(), ExpressionUtils.combineConjuncts(pushdownConjuncts));

            PlanNode result = aggregationNode;
            // Rebuild the node only when the source was actually replaced.
            if (rewrittenSource != aggregationNode.getSource()) {
                result = new AggregationNode(aggregationNode.getId(), rewrittenSource, aggregationNode.getGroupBy(), aggregationNode.getAggregations(), aggregationNode.getFunctions(), aggregationNode.getMasks(), aggregationNode.getStep(), aggregationNode.getSampleWeight(), aggregationNode.getConfidence(), aggregationNode.getHashSymbol());
            }
            if (!postAggregationConjuncts.isEmpty()) {
                result = new FilterNode(this.idAllocator.getNextId(), result, ExpressionUtils.combineConjuncts(postAggregationConjuncts));
            }
            return result;
        }

        /**
         * Pushes the parts of the inherited predicate that can be expressed over the replicate symbols
         * below an unnest.
         *
         * <p>Non-deterministic conjuncts and conjuncts that cannot be rewritten in terms of the
         * replicate symbols stay in a {@link FilterNode} above the unnest; the rest is handed to the
         * source.
         *
         * @param unnestNode     unnest being rewritten
         * @param rewriteContext context carrying the inherited predicate
         * @return the original node if the source is unchanged and nothing remains to filter; otherwise
         *         a rebuilt unnest, wrapped in a filter when conjuncts remain
         */
        @Override // com.facebook.presto.sql.planner.plan.PlanVisitor
        public PlanNode visitUnnest(UnnestNode unnestNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            EqualityInference equalityInference = EqualityInference.createEqualityInference(inheritedPredicate);

            List<Expression> pushdownConjuncts = new ArrayList<>();
            List<Expression> postUnnestConjuncts = new ArrayList<>();

            // Non-deterministic conjuncts are never pushed below the unnest.
            postUnnestConjuncts.addAll(ImmutableList.copyOf(Iterables.filter(ExpressionUtils.extractConjuncts(inheritedPredicate), Predicates.not(DeterminismEvaluator::isDeterministic))));

            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(ExpressionUtils.stripNonDeterministicConjuncts(inheritedPredicate))) {
                Expression rewritten = equalityInference.rewriteExpression(conjunct, Predicates.in(unnestNode.getReplicateSymbols()));
                if (rewritten != null) {
                    pushdownConjuncts.add(rewritten);
                }
                else {
                    postUnnestConjuncts.add(conjunct);
                }
            }

            EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(Predicates.in(unnestNode.getReplicateSymbols()));
            pushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
            postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
            postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

            PlanNode rewrittenSource = rewriteContext.rewrite(unnestNode.getSource(), ExpressionUtils.combineConjuncts(pushdownConjuncts));

            PlanNode result = unnestNode;
            // Rebuild the node only when the source was actually replaced.
            if (rewrittenSource != unnestNode.getSource()) {
                result = new UnnestNode(unnestNode.getId(), rewrittenSource, unnestNode.getReplicateSymbols(), unnestNode.getUnnestSymbols(), unnestNode.getOrdinalitySymbol());
            }
            if (!postUnnestConjuncts.isEmpty()) {
                result = new FilterNode(this.idAllocator.getNextId(), result, ExpressionUtils.combineConjuncts(postUnnestConjuncts));
            }
            return result;
        }

        /**
         * Forwards the inherited predicate unchanged through a sample node via the context's
         * default rewrite.
         *
         * @param sampleNode     sample node being rewritten
         * @param rewriteContext context carrying the inherited predicate
         * @return the result of the default rewrite
         */
        @Override
        public PlanNode visitSample(SampleNode sampleNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            return rewriteContext.defaultRewrite(sampleNode, inheritedPredicate);
        }

        /**
         * Places the simplified inherited predicate in a {@link FilterNode} directly above the table
         * scan, unless it simplifies to the literal {@code TRUE}.
         *
         * @param tableScanNode  table scan reached by the rewrite
         * @param rewriteContext context carrying the inherited predicate
         * @return the scan itself when the predicate is trivially true, otherwise a filter over the scan
         */
        @Override
        public PlanNode visitTableScan(TableScanNode tableScanNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression simplifiedPredicate = simplifyExpression(rewriteContext.get());
            if (BooleanLiteral.TRUE_LITERAL.equals(simplifiedPredicate)) {
                return tableScanNode;
            }
            return new FilterNode(this.idAllocator.getNextId(), tableScanNode, simplifiedPredicate);
        }
    }

    /**
     * Creates the optimizer.
     *
     * @param metadata  metadata handle used when analyzing and optimizing expressions; must not be null
     * @param sqlParser parser handed to the expression analyzer; must not be null
     * @throws NullPointerException if either argument is null
     */
    public PredicatePushDown(Metadata metadata, SqlParser sqlParser) {
        Objects.requireNonNull(metadata, "metadata is null");
        this.metadata = metadata;
        Objects.requireNonNull(sqlParser, "sqlParser is null");
        this.sqlParser = sqlParser;
    }

    /**
     * Rewrites {@code planNode} with a {@code Rewriter}, starting from the trivially true predicate.
     *
     * @param planNode            root of the plan to optimize
     * @param session             current session
     * @param map                 symbol types; only checked for null here
     * @param symbolAllocator     allocator whose symbol types the rewriter consults
     * @param planNodeIdAllocator source of ids for newly created plan nodes
     * @return the rewritten plan
     * @throws NullPointerException if any argument is null
     */
    @Override // com.facebook.presto.sql.planner.optimizations.PlanOptimizer
    public PlanNode optimize(PlanNode planNode, Session session, Map<Symbol, Type> map, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator) {
        Objects.requireNonNull(planNode, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(map, "types is null");
        // Fail fast like the other collaborators: the rewriter dereferences the allocator
        // (symbolAllocator.getTypes()) while visiting nodes.
        Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
        return SimplePlanRewriter.rewriteWith(new Rewriter(symbolAllocator, planNodeIdAllocator, this.metadata, this.sqlParser, session), planNode, BooleanLiteral.TRUE_LITERAL);
    }
}
