/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.ComparisonStatsCalculator;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimateMath;
import com.facebook.presto.cost.ScalarStatsCalculator;
import com.facebook.presto.cost.StatsNormalizer;
import com.facebook.presto.cost.StatsUtil;
import com.facebook.presto.cost.SymbolStatsEstimate;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.analyzer.Scope;
import com.facebook.presto.sql.planner.LiteralInterpreter;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.BetweenPredicate;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.InListExpression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import javax.annotation.Nullable;

/**
 * Estimates the statistics of the rows that pass a filter predicate, given the statistics of
 * the filter's input.
 *
 * <p>The predicate is walked by an {@link AstVisitor}. Sub-expressions the visitor cannot
 * estimate yield {@link Optional#empty()}; when the whole predicate cannot be estimated, the
 * input row count is scaled by {@link #UNKNOWN_FILTER_COEFFICIENT}.
 */
public class FilterStatsCalculator {
    /** Fraction of input rows assumed to survive a predicate whose selectivity cannot be estimated. */
    static final double UNKNOWN_FILTER_COEFFICIENT = 0.9;

    private final Metadata metadata;
    private final ScalarStatsCalculator scalarStatsCalculator;
    private final StatsNormalizer normalizer;

    public FilterStatsCalculator(Metadata metadata, ScalarStatsCalculator scalarStatsCalculator, StatsNormalizer normalizer) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.scalarStatsCalculator = Objects.requireNonNull(scalarStatsCalculator, "scalarStatsCalculator is null");
        this.normalizer = Objects.requireNonNull(normalizer, "normalizer is null");
    }

    /**
     * Computes the statistics of the rows described by {@code statsEstimate} that satisfy
     * {@code predicate}.
     *
     * @param statsEstimate statistics of the filter input
     * @param predicate the filter predicate
     * @param session session used for literal evaluation and expression type analysis
     * @param types types of the symbols that may appear in {@code predicate}
     * @return the normalized estimate; when the predicate cannot be estimated, the input with its
     *     row count scaled by {@link #UNKNOWN_FILTER_COEFFICIENT}
     */
    public PlanNodeStatsEstimate filterStats(PlanNodeStatsEstimate statsEstimate, Expression predicate, Session session, Map<Symbol, Type> types) {
        return new FilterExpressionStatsCalculatingVisitor(statsEstimate, session, types)
                .process(predicate)
                .orElseGet(() -> normalizer.normalize(filterStatsForUnknownExpression(statsEstimate), types));
    }

    /** Applies the fixed {@link #UNKNOWN_FILTER_COEFFICIENT} selectivity to the input row count. */
    private static PlanNodeStatsEstimate filterStatsForUnknownExpression(PlanNodeStatsEstimate inputStatistics) {
        return inputStatistics.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT);
    }

    /**
     * Visitor returning the estimated statistics of the rows of {@code input} that satisfy the
     * visited expression, or {@link Optional#empty()} when the expression cannot be estimated.
     * Every estimate produced through {@link #process(Node, Void)} is passed through the enclosing
     * calculator's normalizer.
     */
    private class FilterExpressionStatsCalculatingVisitor
            extends AstVisitor<Optional<PlanNodeStatsEstimate>, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;
        private final Map<Symbol, Type> types;

        FilterExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, Session session, Map<Symbol, Type> types) {
            this.input = input;
            this.session = session;
            this.types = types;
        }

        @Override
        public Optional<PlanNodeStatsEstimate> process(Node node, @Nullable Void context) {
            return super.process(node, context)
                    .map(estimate -> normalizer.normalize(estimate, types));
        }

        /** Fallback for any expression without a dedicated visit method: no estimate. */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitExpression(Expression node, Void context) {
            return Optional.empty();
        }

        /** Statistics for a predicate that rejects every row: zero rows, zero stats per known symbol. */
        private Optional<PlanNodeStatsEstimate> filterForFalseExpression() {
            PlanNodeStatsEstimate.Builder falseStatsBuilder = PlanNodeStatsEstimate.builder();
            input.getSymbolsWithKnownStatistics()
                    .forEach(symbol -> falseStatsBuilder.addSymbolStatistics(symbol, SymbolStatsEstimate.ZERO_STATS));
            return Optional.of(falseStatsBuilder.setOutputRowCount(0.0).build());
        }

        /**
         * {@code NOT (x IS NULL)} is rewritten to {@code x IS NOT NULL}; any other negation is
         * estimated as the input minus the rows matching the negated expression.
         */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitNotExpression(NotExpression node, Void context) {
            if (node.getValue() instanceof IsNullPredicate) {
                return process(new IsNotNullPredicate(((IsNullPredicate) node.getValue()).getValue()));
            }
            return process(node.getValue())
                    .map(childStats -> PlanNodeStatsEstimateMath.differenceInStats(input, childStats));
        }

        @Override
        protected Optional<PlanNodeStatsEstimate> visitLogicalBinaryExpression(LogicalBinaryExpression node, Void context) {
            switch (node.getType()) {
                case AND:
                    return visitLogicalBinaryAnd(node.getLeft(), node.getRight());
                case OR:
                    return visitLogicalBinaryOr(node.getLeft(), node.getRight());
            }
            throw new IllegalStateException("Unimplemented logical binary operator expression " + node.getType());
        }

        /**
         * Estimates {@code left AND right} by applying {@code right} to the rows that survive
         * {@code left}. If only one side can be estimated, that side's estimate is additionally
         * scaled by the unknown-filter coefficient to account for the other side.
         */
        private Optional<PlanNodeStatsEstimate> visitLogicalBinaryAnd(Expression left, Expression right) {
            Optional<PlanNodeStatsEstimate> leftStats = process(left);
            if (leftStats.isPresent()) {
                Optional<PlanNodeStatsEstimate> andStats =
                        new FilterExpressionStatsCalculatingVisitor(leftStats.get(), session, types).process(right);
                if (andStats.isPresent()) {
                    return andStats;
                }
                return leftStats.map(FilterStatsCalculator::filterStatsForUnknownExpression);
            }
            Optional<PlanNodeStatsEstimate> rightStats = process(right);
            return rightStats.map(FilterStatsCalculator::filterStatsForUnknownExpression);
        }

        /**
         * Estimates {@code left OR right} by inclusion-exclusion: stats(left) + stats(right) minus
         * stats(left AND right). Returns empty if any of the three parts cannot be estimated.
         */
        private Optional<PlanNodeStatsEstimate> visitLogicalBinaryOr(Expression left, Expression right) {
            Optional<PlanNodeStatsEstimate> leftStats = process(left);
            if (!leftStats.isPresent()) {
                return Optional.empty();
            }
            Optional<PlanNodeStatsEstimate> rightStats = process(right);
            if (!rightStats.isPresent()) {
                return Optional.empty();
            }
            Optional<PlanNodeStatsEstimate> andStats =
                    new FilterExpressionStatsCalculatingVisitor(leftStats.get(), session, types).process(right);
            if (!andStats.isPresent()) {
                return Optional.empty();
            }
            PlanNodeStatsEstimate sumStats = PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues(leftStats.get(), rightStats.get());
            return Optional.of(PlanNodeStatsEstimateMath.differenceInNonRangeStats(sumStats, andStats.get()));
        }

        /** {@code TRUE} keeps the input unchanged; any other boolean literal rejects every row. */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitBooleanLiteral(BooleanLiteral node, Void context) {
            if (node.equals(BooleanLiteral.TRUE_LITERAL)) {
                return Optional.of(input);
            }
            return filterForFalseExpression();
        }

        /** {@code symbol IS NOT NULL}: keeps the non-null fraction of rows and clears the symbol's nulls fraction. */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitIsNotNullPredicate(IsNotNullPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return visitExpression(node, context);
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            return Optional.of(input
                    .mapOutputRowCount(rowCount -> rowCount * (1.0 - symbolStats.getNullsFraction()))
                    .mapSymbolColumnStatistics(symbol, statsEstimate -> statsEstimate.mapNullsFraction(x -> 0.0)));
        }

        /** {@code symbol IS NULL}: keeps the null fraction of rows; the symbol becomes all-null with no range. */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitIsNullPredicate(IsNullPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return visitExpression(node, context);
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            return Optional.of(input
                    .mapOutputRowCount(rowCount -> rowCount * symbolStats.getNullsFraction())
                    .mapSymbolColumnStatistics(symbol, statsEstimate -> SymbolStatsEstimate.builder()
                            .setNullsFraction(1.0)
                            .setLowValue(Double.NaN)
                            .setHighValue(Double.NaN)
                            .setDistinctValuesCount(0.0)
                            .build()));
        }

        /**
         * Rewrites {@code symbol BETWEEN min AND max} into {@code symbol >= min AND symbol <= max}.
         * This only applies when each bound is a literal or an expression estimated to a single value.
         */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitBetweenPredicate(BetweenPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return visitExpression(node, context);
            }
            if (!(node.getMin() instanceof Literal) && !isSingleValue(getExpressionStats(node.getMin()))) {
                return visitExpression(node, context);
            }
            if (!(node.getMax() instanceof Literal) && !isSingleValue(getExpressionStats(node.getMax()))) {
                return visitExpression(node, context);
            }
            SymbolStatsEstimate valueStats = input.getSymbolStatistics(Symbol.from(node.getValue()));
            Expression lowerBound = new ComparisonExpression(ComparisonExpressionType.GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin());
            Expression upperBound = new ComparisonExpression(ComparisonExpressionType.LESS_THAN_OR_EQUAL, node.getValue(), node.getMax());
            // visitLogicalBinaryAnd estimates the left conjunct first and the right one on its result.
            // When the value's low end is infinite, the lower-bound comparison is placed first.
            // NOTE(review): presumably so the infinite side is cut to a finite range before the other bound is applied.
            Expression transformed = Double.isInfinite(valueStats.getLowValue())
                    ? ExpressionUtils.and(lowerBound, upperBound)
                    : ExpressionUtils.and(upperBound, lowerBound);
            return process(transformed);
        }

        /**
         * Estimates {@code value IN (v1, ..., vn)} as the sum of the {@code value = vi} estimates.
         * The row count is capped at the input's non-null rows for {@code value}. When {@code value}
         * is a symbol, its distinct count is capped at the input's distinct count.
         */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitInPredicate(InPredicate node, Void context) {
            if (!(node.getValueList() instanceof InListExpression)) {
                return Optional.empty();
            }
            InListExpression inList = (InListExpression) node.getValueList();
            ImmutableList<Optional<PlanNodeStatsEstimate>> valuesEqualityStats = inList.getValues().stream()
                    .map(inValue -> process(new ComparisonExpression(ComparisonExpressionType.EQUAL, node.getValue(), inValue)))
                    .collect(ImmutableList.toImmutableList());
            if (!valuesEqualityStats.stream().allMatch(Optional::isPresent)) {
                return Optional.empty();
            }
            PlanNodeStatsEstimate statsSum = valuesEqualityStats.stream()
                    .map(Optional::get)
                    .reduce(filterForFalseExpression().get(), PlanNodeStatsEstimateMath::addStatsAndSumDistinctValues);
            if (Double.isNaN(statsSum.getOutputRowCount())) {
                return Optional.empty();
            }
            Optional<Symbol> inValueSymbol = asSymbol(node.getValue());
            SymbolStatsEstimate inValueStats = getExpressionStats(node.getValue());
            if (Objects.equals(inValueStats, SymbolStatsEstimate.UNKNOWN_STATS)) {
                return Optional.empty();
            }
            double notNullValuesBeforeIn = input.getOutputRowCount() * (1.0 - inValueStats.getNullsFraction());
            PlanNodeStatsEstimate estimate = input.mapOutputRowCount(rowCount -> Double.min(statsSum.getOutputRowCount(), notNullValuesBeforeIn));
            if (inValueSymbol.isPresent()) {
                SymbolStatsEstimate newSymbolStats = statsSum.getSymbolStatistics(inValueSymbol.get())
                        .mapDistinctValuesCount(newDistinctValuesCount -> Double.min(newDistinctValuesCount, inValueStats.getDistinctValuesCount()));
                estimate = estimate.mapSymbolColumnStatistics(inValueSymbol.get(), oldSymbolStats -> newSymbolStats);
            }
            return Optional.of(estimate);
        }

        /**
         * Estimates a binary comparison. Operands are first normalized so that a symbol (or, if
         * there is none, a non-literal) is on the left and a literal is on the right.
         *
         * @throws IllegalArgumentException if both operands are literals
         */
        @Override
        protected Optional<PlanNodeStatsEstimate> visitComparisonExpression(ComparisonExpression node, Void context) {
            ComparisonExpressionType type = node.getType();
            Expression left = node.getLeft();
            Expression right = node.getRight();
            Preconditions.checkArgument(!(left instanceof Literal && right instanceof Literal), "Literal-to-literal not supported here, should be eliminated earlier");
            if (!(left instanceof SymbolReference) && right instanceof SymbolReference) {
                // normalize so that the symbol is on the left
                return process(new ComparisonExpression(type.flip(), right, left));
            }
            if (left instanceof Literal && !(right instanceof Literal)) {
                // normalize so that the literal is on the right
                return process(new ComparisonExpression(type.flip(), right, left));
            }
            Optional<Symbol> leftSymbol = asSymbol(left);
            SymbolStatsEstimate leftStats = getExpressionStats(left);
            if (Objects.equals(leftStats, SymbolStatsEstimate.UNKNOWN_STATS)) {
                return visitExpression(node, context);
            }
            if (right instanceof Literal) {
                OptionalDouble literal = doubleValueFromLiteral(getType(left), (Literal) right);
                return ComparisonStatsCalculator.comparisonExpressionToLiteralStats(input, leftSymbol, leftStats, literal, type);
            }
            Optional<Symbol> rightSymbol = asSymbol(right);
            SymbolStatsEstimate rightStats = getExpressionStats(right);
            if (Objects.equals(rightStats, SymbolStatsEstimate.UNKNOWN_STATS)) {
                return visitExpression(node, context);
            }
            if (left instanceof SymbolReference && Objects.equals(left, right)) {
                return selfComparisonStats(left, type);
            }
            if (isSingleValue(rightStats)) {
                OptionalDouble value = Double.isNaN(rightStats.getLowValue()) ? OptionalDouble.empty() : OptionalDouble.of(rightStats.getLowValue());
                return ComparisonStatsCalculator.comparisonExpressionToLiteralStats(input, leftSymbol, leftStats, value, type);
            }
            return ComparisonStatsCalculator.comparisonExpressionToExpressionStats(input, leftSymbol, leftStats, rightSymbol, rightStats, type);
        }

        /**
         * Estimates {@code symbol <op> symbol} with the same symbol on both sides.
         * {@code =}, {@code <=} and {@code >=} hold exactly for the non-null rows.
         * {@code <}, {@code >}, {@code <>} and {@code IS DISTINCT FROM} never hold for a value
         * compared with itself. As with the {@code =} case, floating-point NaN is not modelled here.
         */
        private Optional<PlanNodeStatsEstimate> selfComparisonStats(Expression symbolReference, ComparisonExpressionType type) {
            switch (type) {
                case NOT_EQUAL:
                case LESS_THAN:
                case GREATER_THAN:
                case IS_DISTINCT_FROM:
                    return filterForFalseExpression();
                default:
                    return process(new IsNotNullPredicate(symbolReference));
            }
        }

        /** Returns the symbol referenced by {@code expression}, or empty if it is not a symbol reference. */
        private Optional<Symbol> asSymbol(Expression expression) {
            if (expression instanceof SymbolReference) {
                return Optional.of(Symbol.from(expression));
            }
            return Optional.empty();
        }

        /** True when the stats describe exactly one finite value (NDV 1 and low == high). */
        private boolean isSingleValue(SymbolStatsEstimate stats) {
            return stats.getDistinctValuesCount() == 1.0
                    && Double.compare(stats.getLowValue(), stats.getHighValue()) == 0
                    && !Double.isInfinite(stats.getLowValue());
        }

        /** Type of a symbol from {@code types}, or of any other expression via expression analysis. */
        private Type getType(Expression expression) {
            return asSymbol(expression)
                    .map(symbol -> Objects.requireNonNull(types.get(symbol), () -> String.format("No type for symbol %s", symbol)))
                    .orElseGet(() -> {
                        ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries(
                                metadata.getFunctionRegistry(),
                                metadata.getTypeManager(),
                                session,
                                types,
                                ImmutableList.<Expression>of(),
                                node -> new IllegalStateException("Unexpected Subquery"),
                                false);
                        return expressionAnalyzer.analyze(expression, Scope.create());
                    });
        }

        /** Stats of a symbol from the input, or of any other expression via the scalar stats calculator. */
        private SymbolStatsEstimate getExpressionStats(Expression expression) {
            return asSymbol(expression)
                    .map(symbol -> Objects.requireNonNull(input.getSymbolStatistics(symbol), () -> String.format("No statistics for symbol %s", symbol)))
                    .orElseGet(() -> scalarStatsCalculator.calculate(expression, input, session));
        }

        /** Evaluates {@code literal} and converts it to the double-based stats representation of {@code type}. */
        private OptionalDouble doubleValueFromLiteral(Type type, Literal literal) {
            Object literalValue = LiteralInterpreter.evaluate(metadata, session.toConnectorSession(), (Expression) literal);
            return StatsUtil.toStatsRepresentation(metadata, session, type, literalValue);
        }
    }
}

