/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner;

import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.DependencyExtractor;
import com.facebook.presto.sql.planner.DeterminismEvaluator;
import com.facebook.presto.sql.planner.DomainTranslator;
import com.facebook.presto.sql.planner.DomainUtils;
import com.facebook.presto.sql.planner.EqualityInference;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.DistinctLimitNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SortNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.TopNNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.util.ImmutableCollectors;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

/**
 * Plan visitor that derives, for a plan node, a predicate over the node's output symbols that is
 * intended to hold for every row the node produces (its "effective predicate").
 *
 * <p>Each visit method combines the predicate pulled up from the node's source(s) with what the
 * node itself contributes:
 * <ul>
 *   <li>filter predicates;</li>
 *   <li>projection, exchange and union symbol equalities;</li>
 *   <li>equi-join criteria;</li>
 *   <li>table-scan constraints.</li>
 * </ul>
 * Node types without a dedicated visit method yield {@code TRUE}, which is always a valid (if
 * uninformative) answer.
 */
public class EffectivePredicateExtractor extends PlanVisitor<Void, Expression> {
    /** Matches identity assignments ({@code s := s}), which carry no extra information. */
    private static final Predicate<Map.Entry<Symbol, ? extends Expression>> SYMBOL_MATCHES_EXPRESSION =
            entry -> entry.getValue().equals(entry.getKey().toSymbolReference());

    /** Converts an assignment {@code s := expr} into the equality {@code s = expr}. */
    private static final Function<Map.Entry<Symbol, ? extends Expression>, Expression> ENTRY_TO_EQUALITY =
            entry -> new ComparisonExpression(
                    ComparisonExpressionType.EQUAL,
                    entry.getKey().toSymbolReference(),
                    entry.getValue());

    // NOTE(review): assigned by the constructor but not read by any method in this class.
    private final Map<Symbol, Type> symbolTypes;

    /**
     * Computes the effective predicate of {@code node}.
     *
     * @param node the plan node to analyse
     * @param symbolTypes symbol types of the plan (passed through to the visitor instance)
     * @return a conjunction over {@code node}'s output symbols; {@code TRUE} when nothing is known
     */
    public static Expression extract(PlanNode node, Map<Symbol, Type> symbolTypes) {
        return node.accept(new EffectivePredicateExtractor(symbolTypes), null);
    }

    public EffectivePredicateExtractor(Map<Symbol, Type> symbolTypes) {
        this.symbolTypes = symbolTypes;
    }

    /** Fallback for node types without a specific rule: nothing is known, so return {@code TRUE}. */
    @Override
    protected Expression visitPlan(PlanNode node, Void context) {
        return BooleanLiteral.TRUE_LITERAL;
    }

    /**
     * For grouped aggregations, the source predicate is restricted to the grouping keys.
     *
     * <p>A global aggregation (no grouping keys) yields {@code TRUE}. Such an aggregation can emit a
     * row even when its source produces none, so source predicates need not hold for its output.
     */
    @Override
    public Expression visitAggregation(AggregationNode node, Void context) {
        if (node.getGroupingKeys().isEmpty()) {
            return BooleanLiteral.TRUE_LITERAL;
        }
        Expression underlyingPredicate = node.getSource().accept(this, context);
        return pullExpressionThroughSymbols(underlyingPredicate, node.getGroupingKeys());
    }

    /** Adds the deterministic conjuncts of the filter predicate to the source predicate. */
    @Override
    public Expression visitFilter(FilterNode node, Void context) {
        Expression underlyingPredicate = node.getSource().accept(this, context);
        Expression deterministicPredicate = ExpressionUtils.stripNonDeterministicConjuncts(node.getPredicate());
        return ExpressionUtils.combineConjuncts(deterministicPredicate, underlyingPredicate);
    }

    /**
     * Keeps only the conjuncts common to every exchange input.
     *
     * <p>Each input's symbols are mapped positionally onto the exchange's output symbols.
     */
    @Override
    public Expression visitExchange(ExchangeNode node, Void context) {
        return deriveCommonPredicates(node, source -> {
            List<Symbol> inputs = node.getInputs().get(source);
            Map<Symbol, SymbolReference> mappings = new HashMap<>();
            for (int i = 0; i < inputs.size(); ++i) {
                mappings.put(node.getOutputSymbols().get(i), inputs.get(i).toSymbolReference());
            }
            return mappings.entrySet();
        });
    }

    /** Adds an equality per non-identity assignment and restricts the result to the output symbols. */
    @Override
    public Expression visitProject(ProjectNode node, Void context) {
        Expression underlyingPredicate = node.getSource().accept(this, context);
        return pullThroughAssignments(
                node.getAssignments().entrySet(),
                underlyingPredicate,
                node.getOutputSymbols());
    }

    /** TopN only drops rows, so the source predicate still holds. */
    @Override
    public Expression visitTopN(TopNNode node, Void context) {
        return node.getSource().accept(this, context);
    }

    /** Limit only drops rows, so the source predicate still holds. */
    @Override
    public Expression visitLimit(LimitNode node, Void context) {
        return node.getSource().accept(this, context);
    }

    /** DistinctLimit only drops rows, so the source predicate still holds. */
    @Override
    public Expression visitDistinctLimit(DistinctLimitNode node, Void context) {
        return node.getSource().accept(this, context);
    }

    /**
     * Translates the scan's current constraint into a predicate over the scan's output symbols.
     *
     * <p>Each domain is first simplified (see {@link #spanTupleDomain}).
     */
    @Override
    public Expression visitTableScan(TableScanNode node, Void context) {
        Map<ColumnHandle, Symbol> columnToSymbol = ImmutableBiMap.copyOf(node.getAssignments()).inverse();
        return DomainTranslator.toPredicate(
                spanTupleDomain(node.getCurrentConstraint()).transform(columnToSymbol::get));
    }

    /**
     * Applies {@code DomainUtils.simplifyDomain} to every column domain.
     *
     * <p>A "none" tuple domain is returned unchanged.
     */
    private static TupleDomain<ColumnHandle> spanTupleDomain(TupleDomain<ColumnHandle> tupleDomain) {
        if (tupleDomain.isNone()) {
            return tupleDomain;
        }
        return TupleDomain.withColumnDomains(
                Maps.transformValues(tupleDomain.getDomains().get(), DomainUtils::simplifyDomain));
    }

    /** Sorting does not change the set of rows, so the source predicate still holds. */
    @Override
    public Expression visitSort(SortNode node, Void context) {
        return node.getSource().accept(this, context);
    }

    /** Returns the source predicate unchanged. */
    @Override
    public Expression visitWindow(WindowNode node, Void context) {
        return node.getSource().accept(this, context);
    }

    /**
     * Keeps only the conjuncts common to every union branch.
     *
     * <p>The branches are translated through the union's output symbol map.
     */
    @Override
    public Expression visitUnion(UnionNode node, Void context) {
        return deriveCommonPredicates(node, source -> node.outputSymbolMap(source).entries());
    }

    /**
     * Combines both sides' predicates with the equi-join criteria.
     *
     * <p>For outer joins, conjuncts that mention the null-extended side are wrapped by
     * {@link ExpressionUtils#expressionOrNullSymbols}. This presumably keeps them true when those
     * symbols are NULL — confirm in ExpressionUtils.
     */
    @Override
    public Expression visitJoin(JoinNode node, Void context) {
        Expression leftPredicate = node.getLeft().accept(this, context);
        Expression rightPredicate = node.getRight().accept(this, context);

        List<Expression> joinConjuncts = new ArrayList<>();
        for (JoinNode.EquiJoinClause clause : node.getCriteria()) {
            joinConjuncts.add(new ComparisonExpression(
                    ComparisonExpressionType.EQUAL,
                    clause.getLeft().toSymbolReference(),
                    clause.getRight().toSymbolReference()));
        }

        Predicate<Symbol> leftScope = node.getLeft().getOutputSymbols()::contains;
        Predicate<Symbol> rightScope = node.getRight().getOutputSymbols()::contains;

        switch (node.getType()) {
            case INNER:
                return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                        .add(leftPredicate)
                        .add(rightPredicate)
                        .addAll(joinConjuncts)
                        .build());
            case LEFT:
                return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                        .add(leftPredicate)
                        .addAll(pullNullableConjunctsThroughOuterJoin(
                                ExpressionUtils.extractConjuncts(rightPredicate), rightScope))
                        .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, rightScope))
                        .build());
            case RIGHT:
                return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                        .add(rightPredicate)
                        .addAll(pullNullableConjunctsThroughOuterJoin(
                                ExpressionUtils.extractConjuncts(leftPredicate), leftScope))
                        .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, leftScope))
                        .build());
            case FULL:
                return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                        .addAll(pullNullableConjunctsThroughOuterJoin(
                                ExpressionUtils.extractConjuncts(leftPredicate), leftScope))
                        .addAll(pullNullableConjunctsThroughOuterJoin(
                                ExpressionUtils.extractConjuncts(rightPredicate), rightScope))
                        .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, leftScope, rightScope))
                        .build());
            default:
                throw new UnsupportedOperationException("Unknown join type: " + node.getType());
        }
    }

    /**
     * Prepares conjuncts from the null-extended side(s) of an outer join.
     *
     * <p>Conjuncts that reference no symbols are replaced by {@code TRUE}. A constant such as
     * {@code FALSE} on the nullable side does not constrain the null-extended rows the join emits.
     * Every remaining conjunct is mapped through {@code expressionOrNullSymbols(nullSymbolScopes)}.
     *
     * @param conjuncts conjuncts to translate
     * @param nullSymbolScopes membership tests for symbols that may be null-extended
     * @return the translated conjuncts
     */
    @SafeVarargs
    private static Iterable<Expression> pullNullableConjunctsThroughOuterJoin(
            List<Expression> conjuncts, Predicate<Symbol>... nullSymbolScopes) {
        return conjuncts.stream()
                .map(expression -> DependencyExtractor.extractAll(expression).isEmpty()
                        ? BooleanLiteral.TRUE_LITERAL
                        : expression)
                .map(ExpressionUtils.expressionOrNullSymbols(nullSymbolScopes))
                .collect(ImmutableCollectors.toImmutableList());
    }

    /** A semi join keeps all source columns and rows, so the source predicate is returned as-is. */
    @Override
    public Expression visitSemiJoin(SemiJoinNode node, Void context) {
        return node.getSource().accept(this, context);
    }

    /**
     * Computes the conjuncts that hold for every source of a multi-source node (union, exchange).
     *
     * <p>Each source's predicate is translated onto {@code node}'s output symbols. Only conjuncts
     * that are structurally identical across all sources are kept.
     *
     * @param node node whose sources are intersected
     * @param mapping for source index {@code i}, the (output symbol, source symbol reference) pairs
     * @return the conjunction of the common conjuncts; {@code TRUE} if {@code node} has no sources
     */
    private Expression deriveCommonPredicates(
            PlanNode node, Function<Integer, Collection<Map.Entry<Symbol, SymbolReference>>> mapping) {
        List<PlanNode> sources = node.getSources();
        if (sources.isEmpty()) {
            // Nothing to intersect; previously this threw NoSuchElementException from iterator.next().
            return BooleanLiteral.TRUE_LITERAL;
        }

        List<Set<Expression>> sourceOutputConjuncts = new ArrayList<>();
        for (int i = 0; i < sources.size(); ++i) {
            Expression underlyingPredicate = sources.get(i).accept(this, null);
            Expression translated = pullThroughAssignments(
                    mapping.apply(i), underlyingPredicate, node.getOutputSymbols());
            sourceOutputConjuncts.add(ImmutableSet.copyOf(ExpressionUtils.extractConjuncts(translated)));
        }

        // Matching is purely structural (Expression.equals); e.g. commuted comparisons are not unified.
        Iterator<Set<Expression>> iterator = sourceOutputConjuncts.iterator();
        Set<Expression> commonConjuncts = iterator.next();
        while (iterator.hasNext()) {
            commonConjuncts = Sets.intersection(commonConjuncts, iterator.next());
        }
        return ExpressionUtils.combineConjuncts(commonConjuncts);
    }

    /**
     * Builds a predicate from a set of assignments and restricts it to the given symbols.
     *
     * <p>Conjoins {@code underlyingPredicate} with an equality {@code s = expr} for every
     * non-identity assignment, then restricts the result to {@code outputSymbols}.
     */
    private static Expression pullThroughAssignments(
            Collection<? extends Map.Entry<Symbol, ? extends Expression>> assignments,
            Expression underlyingPredicate,
            Collection<Symbol> outputSymbols) {
        List<Expression> equalities = assignments.stream()
                .filter(SYMBOL_MATCHES_EXPRESSION.negate())
                .map(ENTRY_TO_EQUALITY)
                .collect(ImmutableCollectors.toImmutableList());
        return pullExpressionThroughSymbols(
                ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                        .addAll(equalities)
                        .add(underlyingPredicate)
                        .build()),
                outputSymbols);
    }

    /**
     * Rewrites {@code expression} so that it only references {@code symbols}.
     *
     * <p>Equality inference is used to substitute symbols that fall outside the scope.
     * Non-deterministic conjuncts, and conjuncts that cannot be rewritten into the scope, are
     * dropped. Equalities inferred within the scope are appended.
     *
     * @param expression predicate to translate
     * @param symbols the symbols the result may reference
     * @return the translated conjunction
     */
    private static Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> symbols) {
        EqualityInference equalityInference = EqualityInference.createEqualityInference(expression);
        com.google.common.base.Predicate<Symbol> inScope = Predicates.in(symbols);

        ImmutableList.Builder<Expression> effectiveConjuncts = ImmutableList.builder();
        for (Expression conjunct : EqualityInference.nonInferrableConjuncts(expression)) {
            if (!DeterminismEvaluator.isDeterministic(conjunct)) {
                continue;
            }
            Expression rewritten = equalityInference.rewriteExpression(conjunct, inScope);
            if (rewritten != null) {
                effectiveConjuncts.add(rewritten);
            }
        }
        effectiveConjuncts.addAll(equalityInference.generateEqualitiesPartitionedBy(inScope).getScopeEqualities());
        return ExpressionUtils.combineConjuncts(effectiveConjuncts.build());
    }
}

