/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.ComparisonStatsCalculator;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimateMath;
import com.facebook.presto.cost.StatsUtil;
import com.facebook.presto.cost.SymbolStatsEstimate;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.LiteralInterpreter;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.BetweenPredicate;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.InListExpression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalDouble;

/**
 * Estimates the statistics of a plan node's output after a filter predicate has been applied.
 *
 * <p>The predicate is walked by an {@link AstVisitor}. Predicate shapes the visitor cannot
 * reason about fall back to {@link #filterStatsForUnknownExpression(PlanNodeStatsEstimate)}.
 */
public class FilterStatsCalculator {
    /** Fraction of input rows assumed to pass a predicate whose selectivity cannot be estimated. */
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9;
    private final Metadata metadata;

    public FilterStatsCalculator(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Estimates the statistics of the rows described by {@code statsEstimate} that satisfy {@code predicate}.
     *
     * @param statsEstimate statistics of the filter's input
     * @param predicate the filter predicate to estimate
     * @param session session used when evaluating literals that appear in the predicate
     * @param types symbol types, used to convert literals compared against symbols
     * @return the estimated statistics of the filtered output
     */
    public PlanNodeStatsEstimate filterStats(PlanNodeStatsEstimate statsEstimate, Expression predicate, Session session, Map<Symbol, Type> types) {
        return new FilterExpressionStatsCalculatingVisitor(statsEstimate, session, types).process(predicate);
    }

    /**
     * Fallback estimate for a predicate whose selectivity cannot be derived.
     *
     * @param inputStatistics statistics of the filter's input
     * @return the input statistics with the output row count scaled by {@code UNKNOWN_FILTER_COEFFICIENT}
     */
    public static PlanNodeStatsEstimate filterStatsForUnknownExpression(PlanNodeStatsEstimate inputStatistics) {
        // Use the named constant instead of repeating its value as a literal, so the two cannot drift apart.
        return inputStatistics.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT);
    }

    /**
     * Visits one predicate and returns the estimated statistics of {@code input} after that predicate is applied.
     *
     * <p>This is deliberately a non-static inner class: {@link #doubleValueFromLiteral} needs the enclosing
     * instance's {@code metadata}.
     */
    private class FilterExpressionStatsCalculatingVisitor
    extends AstVisitor<PlanNodeStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;
        private final Map<Symbol, Type> types;

        FilterExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, Session session, Map<Symbol, Type> types) {
            this.input = input;
            this.session = session;
            this.types = types;
        }

        /** Catch-all for expression kinds without a dedicated visit method: uses the unknown-filter estimate. */
        @Override
        protected PlanNodeStatsEstimate visitExpression(Expression node, Void context) {
            return filterForUnknownExpression();
        }

        private PlanNodeStatsEstimate filterForUnknownExpression() {
            return FilterStatsCalculator.filterStatsForUnknownExpression(input);
        }

        /** Statistics for a predicate that rejects every row: zero rows, and ZERO_STATS for each symbol with known statistics. */
        private PlanNodeStatsEstimate filterForFalseExpression() {
            PlanNodeStatsEstimate.Builder falseStatsBuilder = PlanNodeStatsEstimate.builder();
            input.getSymbolsWithKnownStatistics()
                    .forEach(symbol -> falseStatsBuilder.addSymbolStatistics(symbol, SymbolStatsEstimate.ZERO_STATS));
            return falseStatsBuilder.setOutputRowCount(0.0).build();
        }

        /** NOT p: the input minus the rows estimated to satisfy p. */
        @Override
        protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void context) {
            return PlanNodeStatsEstimateMath.differenceInStats(input, process(node.getValue()));
        }

        /**
         * AND / OR of two predicates.
         *
         * <p>The left operand is processed first. The right operand is then evaluated against the left
         * operand's output, which gives the AND estimate. OR is estimated by inclusion-exclusion:
         * {@code left + right - (left AND right)}.
         */
        @Override
        protected PlanNodeStatsEstimate visitLogicalBinaryExpression(LogicalBinaryExpression node, Void context) {
            PlanNodeStatsEstimate leftStats = process(node.getLeft());
            PlanNodeStatsEstimate andStats = new FilterExpressionStatsCalculatingVisitor(leftStats, session, types)
                    .process(node.getRight());
            switch (node.getType()) {
                case AND:
                    return andStats;
                case OR: {
                    PlanNodeStatsEstimate rightStats = process(node.getRight());
                    PlanNodeStatsEstimate sumStats = PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues(leftStats, rightStats);
                    return PlanNodeStatsEstimateMath.differenceInNonRangeStats(sumStats, andStats);
                }
                default:
                    throw new IllegalStateException("Unimplemented logical binary operator expression " + node.getType());
            }
        }

        /** TRUE keeps the input unchanged. Any other boolean literal rejects all rows. */
        @Override
        protected PlanNodeStatsEstimate visitBooleanLiteral(BooleanLiteral node, Void context) {
            if (node.equals(BooleanLiteral.TRUE_LITERAL)) {
                return input;
            }
            return filterForFalseExpression();
        }

        /**
         * {@code symbol IS NOT NULL}: scales the row count by the symbol's non-null fraction and sets its
         * nulls fraction to 0. Non-symbol operands use the unknown-filter estimate.
         */
        @Override
        protected PlanNodeStatsEstimate visitIsNotNullPredicate(IsNotNullPredicate node, Void context) {
            if (node.getValue() instanceof SymbolReference) {
                Symbol symbol = Symbol.from(node.getValue());
                SymbolStatsEstimate symbolStatsEstimate = input.getSymbolStatistics(symbol);
                return input.mapOutputRowCount(rowCount -> rowCount * (1.0 - symbolStatsEstimate.getNullsFraction()))
                        .mapSymbolColumnStatistics(symbol, statsEstimate -> statsEstimate.mapNullsFraction(x -> 0.0));
            }
            return visitExpression(node, context);
        }

        /**
         * {@code symbol IS NULL}: scales the row count by the symbol's nulls fraction and replaces the symbol's
         * statistics with all-null statistics (nulls fraction 1, NaN range, 0 distinct values). Non-symbol
         * operands use the unknown-filter estimate.
         */
        @Override
        protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void context) {
            if (node.getValue() instanceof SymbolReference) {
                Symbol symbol = Symbol.from(node.getValue());
                SymbolStatsEstimate symbolStatsEstimate = input.getSymbolStatistics(symbol);
                return input.mapOutputRowCount(rowCount -> rowCount * symbolStatsEstimate.getNullsFraction())
                        .mapSymbolColumnStatistics(symbol, statsEstimate -> SymbolStatsEstimate.builder()
                                .setNullsFraction(1.0)
                                .setLowValue(Double.NaN)
                                .setHighValue(Double.NaN)
                                .setDistinctValuesCount(0.0)
                                .build());
            }
            return visitExpression(node, context);
        }

        /**
         * {@code symbol BETWEEN literal AND literal}: rewritten as {@code symbol >= min AND symbol <= max}.
         * Other operand shapes use the unknown-filter estimate.
         */
        @Override
        protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)
                    || !(node.getMin() instanceof Literal)
                    || !(node.getMax() instanceof Literal)) {
                return visitExpression(node, context);
            }
            SymbolStatsEstimate valueStats = input.getSymbolStatistics(Symbol.from(node.getValue()));
            Expression lowerBound = new ComparisonExpression(ComparisonExpressionType.GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin());
            Expression upperBound = new ComparisonExpression(ComparisonExpressionType.LESS_THAN_OR_EQUAL, node.getValue(), node.getMax());
            // The order is significant: visitLogicalBinaryExpression applies the left conjunct first and
            // the right one to its output. When the symbol's low value is infinite, the lower bound is
            // applied first (presumably to turn the infinite range into a finite one before the upper
            // bound is estimated).
            Expression transformed = Double.isInfinite(valueStats.getLowValue())
                    ? ExpressionUtils.and(lowerBound, upperBound)
                    : ExpressionUtils.and(upperBound, lowerBound);
            return process(transformed);
        }

        /**
         * {@code symbol IN (v1, v2, ...)}: sums the estimates of the individual {@code symbol = vi}
         * comparisons, starting from the all-rejecting estimate. The row count is capped at the number of
         * non-null input rows, and the symbol's distinct-value count at its input value. A NaN summed row
         * count, or any other operand shape, uses the unknown-filter estimate.
         */
        @Override
        protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference) || !(node.getValueList() instanceof InListExpression)) {
                return visitExpression(node, context);
            }
            InListExpression inList = (InListExpression) node.getValueList();
            PlanNodeStatsEstimate statsSum = inList.getValues().stream()
                    .map(inValue -> process(new ComparisonExpression(ComparisonExpressionType.EQUAL, node.getValue(), inValue)))
                    .reduce(filterForFalseExpression(), PlanNodeStatsEstimateMath::addStatsAndSumDistinctValues);
            if (Double.isNaN(statsSum.getOutputRowCount())) {
                return visitExpression(node, context);
            }
            Symbol inValueSymbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(inValueSymbol);
            double notNullValuesBeforeIn = input.getOutputRowCount() * (1.0 - symbolStats.getNullsFraction());
            SymbolStatsEstimate newSymbolStats = statsSum.getSymbolStatistics(inValueSymbol)
                    .mapDistinctValuesCount(newDistinctValuesCount -> Double.min(newDistinctValuesCount, symbolStats.getDistinctValuesCount()));
            return input.mapOutputRowCount(rowCount -> Double.min(statsSum.getOutputRowCount(), notNullValuesBeforeIn))
                    .mapSymbolColumnStatistics(inValueSymbol, oldSymbolStats -> newSymbolStats);
        }

        /**
         * Comparison of a symbol with a literal or with another symbol.
         *
         * <p>A comparison with the symbol only on the right is first flipped so that the symbol is on the
         * left. Any other shape uses the unknown-filter estimate.
         */
        @Override
        protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression node, Void context) {
            ComparisonExpressionType type = node.getType();
            Expression left = node.getLeft();
            Expression right = node.getRight();
            if (!(left instanceof SymbolReference) && right instanceof SymbolReference) {
                // Normalize, e.g. "5 < x" becomes "x > 5".
                return process(new ComparisonExpression(type.flip(), right, left));
            }
            if (left instanceof SymbolReference && right instanceof Literal) {
                Symbol symbol = Symbol.from(left);
                OptionalDouble literal = doubleValueFromLiteral(types.get(symbol), (Literal) right);
                return ComparisonStatsCalculator.comparisonSymbolToLiteralStats(input, symbol, literal, type);
            }
            if (left instanceof SymbolReference && right instanceof SymbolReference) {
                return ComparisonStatsCalculator.comparisonSymbolToSymbolStats(input, Symbol.from(left), Symbol.from(right), type);
            }
            return filterForUnknownExpression();
        }

        /**
         * Evaluates {@code literal} and converts the result to the numeric representation used by statistics
         * for {@code type}. The conversion is delegated to {@code StatsUtil.toStatsRepresentation}; presumably
         * the result is empty when the type has no numeric stats representation (TODO confirm).
         */
        private OptionalDouble doubleValueFromLiteral(Type type, Literal literal) {
            Object literalValue = LiteralInterpreter.evaluate(metadata, session.toConnectorSession(), literal);
            return StatsUtil.toStatsRepresentation(metadata, session, type, literalValue);
        }
    }
}

