package com.facebook.presto.sql.planner;

import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.DistinctLimitNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SortNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.TopNNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.util.ImmutableCollectors;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

/* loaded from: input_file:com/facebook/presto/sql/planner/EffectivePredicateExtractor.class */
/**
 * Plan visitor that derives an "effective predicate" for a plan node: a conjunctive expression over
 * the node's output symbols that is implied by the node itself and by its sources.
 *
 * <p>Node types without a dedicated visit method contribute {@link BooleanLiteral#TRUE_LITERAL},
 * i.e. no information.
 */
public class EffectivePredicateExtractor extends PlanVisitor<Void, Expression> {
    /** Matches identity assignments such as {@code x := x}, which add no predicate information. */
    private static final Predicate<Map.Entry<Symbol, ? extends Expression>> SYMBOL_MATCHES_EXPRESSION =
            entry -> entry.getValue().equals(entry.getKey().toSymbolReference());

    /**
     * Turns an assignment {@code symbol := expression} into the conjunct {@code symbol = expression}.
     *
     * <p>NOTE(review): when {@code expression} evaluates to NULL this comparison is NULL, not TRUE;
     * an {@code IS NOT DISTINCT FROM} form would be exact — confirm consumers tolerate this.
     */
    private static final Function<Map.Entry<Symbol, ? extends Expression>, Expression> ENTRY_TO_EQUALITY =
            entry -> new ComparisonExpression(ComparisonExpressionType.EQUAL, entry.getKey().toSymbolReference(), entry.getValue());

    /** Symbol types supplied at construction; not read by any method in this class. */
    private final Map<Symbol, Type> symbolTypes;

    /**
     * Computes the effective predicate of {@code node}.
     *
     * @param node the plan node to analyze
     * @param symbolTypes symbol types for the plan
     * @return an expression over the node's symbols that holds for the rows it produces
     */
    public static Expression extract(PlanNode node, Map<Symbol, Type> symbolTypes) {
        return node.accept(new EffectivePredicateExtractor(symbolTypes), null);
    }

    public EffectivePredicateExtractor(Map<Symbol, Type> symbolTypes) {
        this.symbolTypes = symbolTypes;
    }

    /** Fallback for node types this visitor does not understand: nothing is known, so TRUE. */
    @Override
    public Expression visitPlan(PlanNode node, Void context) {
        return BooleanLiteral.TRUE_LITERAL;
    }

    @Override
    public Expression visitAggregation(AggregationNode node, Void context) {
        // A global aggregation (no grouping keys) emits a row even when its input is empty,
        // so nothing about the source's rows can be assumed for its output.
        if (node.getGroupingKeys().isEmpty()) {
            return BooleanLiteral.TRUE_LITERAL;
        }
        Expression underlyingPredicate = node.getSource().accept(this, context);
        return pullExpressionThroughSymbols(underlyingPredicate, node.getGroupingKeys());
    }

    @Override
    public Expression visitFilter(FilterNode node, Void context) {
        Expression underlyingPredicate = node.getSource().accept(this, context);
        // Non-deterministic conjuncts cannot be assumed to hold when re-evaluated elsewhere.
        Expression predicate = ExpressionUtils.stripNonDeterministicConjuncts(node.getPredicate());
        return ExpressionUtils.combineConjuncts(predicate, underlyingPredicate);
    }

    @Override
    public Expression visitExchange(ExchangeNode node, Void context) {
        // Output symbol i of the exchange corresponds positionally to symbol i of each source's inputs.
        return deriveCommonPredicates(node, source -> {
            List<Symbol> inputs = node.getInputs().get(source);
            Map<Symbol, SymbolReference> mappings = new HashMap<>();
            for (int i = 0; i < inputs.size(); i++) {
                mappings.put(node.getOutputSymbols().get(i), inputs.get(i).toSymbolReference());
            }
            return mappings.entrySet();
        });
    }

    @Override
    public Expression visitProject(ProjectNode node, Void context) {
        Expression underlyingPredicate = node.getSource().accept(this, context);

        // Each non-identity assignment contributes "symbol = expression".
        List<Expression> projectionEqualities = node.getAssignments().entrySet().stream()
                .filter(SYMBOL_MATCHES_EXPRESSION.negate())
                .map(ENTRY_TO_EQUALITY)
                .collect(ImmutableCollectors.toImmutableList());

        Expression combined = ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                .addAll(projectionEqualities)
                .add(underlyingPredicate)
                .build());
        return pullExpressionThroughSymbols(combined, node.getOutputSymbols());
    }

    /** TopN only selects and orders source rows, so the source predicate still holds. */
    @Override
    public Expression visitTopN(TopNNode node, Void context) {
        return node.getSource().accept(this, context);
    }

    /** Limit only drops source rows, so the source predicate still holds. */
    @Override
    public Expression visitLimit(LimitNode node, Void context) {
        return node.getSource().accept(this, context);
    }

    /** Returns the source's predicate unchanged. */
    @Override
    public Expression visitDistinctLimit(DistinctLimitNode node, Void context) {
        return node.getSource().accept(this, context);
    }

    /**
     * Translates the scan's current constraint into a predicate over the scan's output symbols,
     * mapping each column handle back to the symbol it is assigned to.
     */
    @Override
    public Expression visitTableScan(TableScanNode node, Void context) {
        Map<ColumnHandle, Symbol> columnToSymbol = ImmutableBiMap.copyOf(node.getAssignments()).inverse();
        return DomainTranslator.toPredicate(spanTupleDomain(node.getCurrentConstraint()).transform(columnToSymbol::get));
    }

    /**
     * Simplifies every column domain of {@code tupleDomain} with {@code DomainUtils.simplifyDomain}.
     * A NONE domain is returned as-is (it has no per-column domains to simplify).
     */
    private static TupleDomain<ColumnHandle> spanTupleDomain(TupleDomain<ColumnHandle> tupleDomain) {
        if (tupleDomain.isNone()) {
            return tupleDomain;
        }
        // getDomains() is present here because the NONE case returned above.
        return TupleDomain.withColumnDomains(Maps.transformValues(tupleDomain.getDomains().get(), DomainUtils::simplifyDomain));
    }

    /** Sort only reorders source rows, so the source predicate still holds. */
    @Override
    public Expression visitSort(SortNode node, Void context) {
        return node.getSource().accept(this, context);
    }

    /** Returns the source's predicate unchanged. */
    @Override
    public Expression visitWindow(WindowNode node, Void context) {
        return node.getSource().accept(this, context);
    }

    @Override
    public Expression visitUnion(UnionNode node, Void context) {
        // outputSymbolMap(i) presumably maps each union output symbol to source i's symbol — see UnionNode.
        return deriveCommonPredicates(node, source -> node.outputSymbolMap(source).entries());
    }

    @Override
    public Expression visitJoin(JoinNode node, Void context) {
        Expression leftPredicate = node.getLeft().accept(this, context);
        Expression rightPredicate = node.getRight().accept(this, context);

        // Equi-join criteria as "left = right" conjuncts.
        List<Expression> joinConjuncts = new ArrayList<>();
        for (JoinNode.EquiJoinClause clause : node.getCriteria()) {
            joinConjuncts.add(new ComparisonExpression(
                    ComparisonExpressionType.EQUAL,
                    clause.getLeft().toSymbolReference(),
                    clause.getRight().toSymbolReference()));
        }

        Predicate<Symbol> leftSymbols = node.getLeft().getOutputSymbols()::contains;
        Predicate<Symbol> rightSymbols = node.getRight().getOutputSymbols()::contains;

        switch (node.getType()) {
            case INNER:
                return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                        .add(leftPredicate)
                        .add(rightPredicate)
                        .addAll(joinConjuncts)
                        .build());
            case LEFT:
                // Right-side symbols may be null-padded, so right-side facts are weakened accordingly.
                return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                        .add(leftPredicate)
                        .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(rightPredicate), rightSymbols))
                        .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, rightSymbols))
                        .build());
            case RIGHT:
                // Left-side symbols may be null-padded, so left-side facts are weakened accordingly.
                return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                        .add(rightPredicate)
                        .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(leftPredicate), leftSymbols))
                        .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, leftSymbols))
                        .build());
            case FULL:
                // Either side may be null-padded.
                return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                        .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(leftPredicate), leftSymbols))
                        .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(rightPredicate), rightSymbols))
                        .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, leftSymbols, rightSymbols))
                        .build());
            default:
                throw new UnsupportedOperationException("Unknown join type: " + node.getType());
        }
    }

    /**
     * Prepares conjuncts from the null-padded side(s) of an outer join.
     *
     * <p>Conjuncts that reference no symbols (e.g. a FALSE literal) are replaced by TRUE, since they
     * cannot be tied to any possibly-null symbol. Each conjunct is then passed through
     * {@code ExpressionUtils.expressionOrNullSymbols(nullSymbolScopes)} — presumably rewriting it to
     * also accept NULLs for symbols matched by the scopes; confirm against ExpressionUtils.
     *
     * @param conjuncts conjuncts to weaken
     * @param nullSymbolScopes predicates selecting the symbols that may be null-padded
     * @return the rewritten conjuncts, in input order
     */
    private static Iterable<Expression> pullNullableConjunctsThroughOuterJoin(List<Expression> conjuncts, Predicate<Symbol>... nullSymbolScopes) {
        return conjuncts.stream()
                .map(expression -> DependencyExtractor.extractAll(expression).isEmpty() ? BooleanLiteral.TRUE_LITERAL : expression)
                .map(ExpressionUtils.expressionOrNullSymbols(nullSymbolScopes))
                .collect(ImmutableCollectors.toImmutableList());
    }

    /** Returns the source's predicate unchanged. */
    @Override
    public Expression visitSemiJoin(SemiJoinNode node, Void context) {
        return node.getSource().accept(this, context);
    }

    /**
     * Computes the conjuncts that every source of a multi-source node agrees on.
     *
     * <p>For each source {@code i}, the non-identity entries of {@code mapping.apply(i)} become
     * "output = input" equalities, are combined with that source's predicate, pulled through the
     * node's output symbols, and split into a set of conjuncts. The result is the conjunction of the
     * intersection of those sets (matched by {@code equals}, so syntactically different but equivalent
     * conjuncts are not recognized as common).
     *
     * @param node node whose sources are analyzed
     * @param mapping for source index {@code i}, entries mapping node output symbols to source-i references
     * @return the common predicate; TRUE when the node has no sources
     */
    private Expression deriveCommonPredicates(PlanNode node, Function<Integer, Collection<Map.Entry<Symbol, SymbolReference>>> mapping) {
        List<PlanNode> sources = node.getSources();
        if (sources.isEmpty()) {
            // Nothing to intersect; TRUE asserts nothing.
            return BooleanLiteral.TRUE_LITERAL;
        }

        List<Set<Expression>> perSourceConjuncts = new ArrayList<>(sources.size());
        for (int i = 0; i < sources.size(); i++) {
            List<Expression> equalities = mapping.apply(i).stream()
                    .filter(SYMBOL_MATCHES_EXPRESSION.negate())
                    .map(ENTRY_TO_EQUALITY)
                    .collect(ImmutableCollectors.toImmutableList());
            Expression underlyingPredicate = sources.get(i).accept(this, null);

            Expression sourcePredicate = ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                    .addAll(equalities)
                    .add(underlyingPredicate)
                    .build());
            perSourceConjuncts.add(ImmutableSet.copyOf(ExpressionUtils.extractConjuncts(
                    pullExpressionThroughSymbols(sourcePredicate, node.getOutputSymbols()))));
        }

        Iterator<Set<Expression>> iterator = perSourceConjuncts.iterator();
        Set<Expression> common = iterator.next();
        while (iterator.hasNext()) {
            common = Sets.intersection(common, iterator.next());
        }
        return ExpressionUtils.combineConjuncts(common);
    }

    /**
     * Rewrites {@code expression} so that it only references {@code symbols}, using equality inference.
     *
     * <p>Deterministic non-inferrable conjuncts are kept if they can be rewritten in terms of
     * {@code symbols}; conjuncts that cannot be rewritten (rewriteExpression returns null) are dropped.
     * Equalities among {@code symbols} implied by the expression are then appended.
     */
    private static Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> symbols) {
        EqualityInference equalityInference = EqualityInference.createEqualityInference(expression);

        ImmutableList.Builder<Expression> effectiveConjuncts = ImmutableList.builder();
        for (Expression conjunct : EqualityInference.nonInferrableConjuncts(expression)) {
            if (!DeterminismEvaluator.isDeterministic(conjunct)) {
                continue;
            }
            Expression rewritten = equalityInference.rewriteExpression(conjunct, Predicates.in(symbols));
            if (rewritten != null) {
                effectiveConjuncts.add(rewritten);
            }
        }

        effectiveConjuncts.addAll(equalityInference.generateEqualitiesPartitionedBy(Predicates.in(symbols)).getScopeEqualities());
        return ExpressionUtils.combineConjuncts(effectiveConjuncts.build());
    }
}
