/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatisticRange;
import com.facebook.presto.cost.VariableStatsEstimate;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.util.MoreMath;
import java.util.Optional;
import java.util.OptionalDouble;

/**
 * Static helpers that estimate the statistics produced by filtering rows with a single comparison,
 * either {@code expression <op> literal} or {@code expression <op> expression}.
 *
 * <p>Combinations that are not modelled yield {@code PlanNodeStatsEstimate.unknown()}.
 */
public final class ComparisonStatsCalculator {
    private ComparisonStatsCalculator() {
    }

    /**
     * Estimates the statistics remaining after applying {@code expression <operator> literal}.
     *
     * @param inputStatistics statistics of the rows entering the filter
     * @param expressionStatistics statistics of the compared expression
     * @param expressionVariable variable backing the expression, if any; when present, its column
     *     statistics are replaced in the result
     * @param literalValue the literal as a double, or empty when none is available (presumably a
     *     literal without a numeric representation — TODO confirm with callers)
     * @param operator the comparison operator
     * @return the estimated statistics; {@code PlanNodeStatsEstimate.unknown()} for
     *     {@code IS_DISTINCT_FROM}
     * @throws IllegalArgumentException if {@code operator} is not one of the handled operators
     */
    public static PlanNodeStatsEstimate estimateExpressionToLiteralComparison(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate expressionStatistics,
            Optional<VariableReferenceExpression> expressionVariable,
            OptionalDouble literalValue,
            ComparisonExpression.Operator operator) {
        switch (operator) {
            case EQUAL:
                return estimateExpressionEqualToLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue);
            case NOT_EQUAL:
                return estimateExpressionNotEqualToLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue);
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
                // Strict and non-strict bounds share a single estimate.
                return estimateExpressionLessThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue);
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
                return estimateExpressionGreaterThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue);
            case IS_DISTINCT_FROM:
                return PlanNodeStatsEstimate.unknown();
            default:
                throw new IllegalArgumentException("Unexpected comparison operator: " + operator);
        }
    }

    /**
     * Builds the range matched by {@code expression = literal}.
     *
     * <p>The range is the single point {@code literal} with one distinct value. When the literal
     * is absent, the range spans negative to positive infinity, still with one distinct value.
     */
    private static StatisticRange literalEqualityRange(OptionalDouble literalValue) {
        if (!literalValue.isPresent()) {
            return new StatisticRange(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0);
        }
        double literal = literalValue.getAsDouble();
        return new StatisticRange(literal, literal, 1.0);
    }

    /** Estimates {@code expression = literal} by filtering on the literal's equality range. */
    private static PlanNodeStatsEstimate estimateExpressionEqualToLiteral(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate expressionStatistics,
            Optional<VariableReferenceExpression> expressionVariable,
            OptionalDouble literalValue) {
        return estimateFilterRange(inputStatistics, expressionStatistics, expressionVariable, literalEqualityRange(literalValue));
    }

    /**
     * Estimates {@code expression <> literal}.
     *
     * <p>The estimate keeps the complement of the fraction that {@code = literal} would match,
     * restricted to non-null rows.
     */
    private static PlanNodeStatsEstimate estimateExpressionNotEqualToLiteral(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate expressionStatistics,
            Optional<VariableReferenceExpression> expressionVariable,
            OptionalDouble literalValue) {
        StatisticRange expressionRange = StatisticRange.from(expressionStatistics);
        StatisticRange intersectRange = expressionRange.intersect(literalEqualityRange(literalValue));
        double filterFactor = 1.0 - expressionRange.overlapPercentWith(intersectRange);
        // Null values never satisfy the comparison, so only the non-null share of rows can pass.
        double nonNullFraction = 1.0 - expressionStatistics.getNullsFraction();

        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics);
        estimate.setOutputRowCount(filterFactor * nonNullFraction * inputStatistics.getOutputRowCount());
        if (expressionVariable.isPresent()) {
            // Nulls are gone and one distinct value is dropped, floored at zero.
            // NOTE(review): the decrement also happens when the literal lies outside the
            // expression's range.
            VariableStatsEstimate symbolNewEstimate = VariableStatsEstimate.buildFrom(expressionStatistics)
                    .setNullsFraction(0.0)
                    .setDistinctValuesCount(MoreMath.max(expressionStatistics.getDistinctValuesCount() - 1.0, 0.0))
                    .build();
            estimate = estimate.addVariableStatistics(expressionVariable.get(), symbolNewEstimate);
        }
        return estimate.build();
    }

    /**
     * Estimates {@code expression < literal} (and {@code <=}).
     *
     * <p>The filter range runs from negative infinity up to the literal, or to positive infinity
     * when the literal is absent.
     */
    private static PlanNodeStatsEstimate estimateExpressionLessThanLiteral(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate expressionStatistics,
            Optional<VariableReferenceExpression> expressionVariable,
            OptionalDouble literalValue) {
        double upperBound = literalValue.orElse(Double.POSITIVE_INFINITY);
        // NaN distinct-values count: presumably "unknown" for StatisticRange — TODO confirm.
        StatisticRange filterRange = new StatisticRange(Double.NEGATIVE_INFINITY, upperBound, Double.NaN);
        return estimateFilterRange(inputStatistics, expressionStatistics, expressionVariable, filterRange);
    }

    /**
     * Estimates {@code expression > literal} (and {@code >=}).
     *
     * <p>The filter range runs from the literal (or negative infinity when the literal is absent)
     * up to positive infinity.
     */
    private static PlanNodeStatsEstimate estimateExpressionGreaterThanLiteral(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate expressionStatistics,
            Optional<VariableReferenceExpression> expressionVariable,
            OptionalDouble literalValue) {
        double lowerBound = literalValue.orElse(Double.NEGATIVE_INFINITY);
        // NaN distinct-values count: presumably "unknown" for StatisticRange — TODO confirm.
        StatisticRange filterRange = new StatisticRange(lowerBound, Double.POSITIVE_INFINITY, Double.NaN);
        return estimateFilterRange(inputStatistics, expressionStatistics, expressionVariable, filterRange);
    }

    /**
     * Scales the input row count by the overlap between the expression's range and
     * {@code filterRange}, then by the non-null fraction of the expression.
     *
     * <p>When {@code expressionVariable} is present, its statistics become the intersected range
     * with no nulls. The previous average row size is kept.
     */
    private static PlanNodeStatsEstimate estimateFilterRange(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate expressionStatistics,
            Optional<VariableReferenceExpression> expressionVariable,
            StatisticRange filterRange) {
        StatisticRange expressionRange = StatisticRange.from(expressionStatistics);
        StatisticRange intersectRange = expressionRange.intersect(filterRange);
        double filterFactor = expressionRange.overlapPercentWith(intersectRange);

        PlanNodeStatsEstimate estimate = inputStatistics.mapOutputRowCount(
                rowCount -> filterFactor * (1.0 - expressionStatistics.getNullsFraction()) * rowCount);
        if (expressionVariable.isPresent()) {
            VariableStatsEstimate symbolNewEstimate = VariableStatsEstimate.builder()
                    .setAverageRowSize(expressionStatistics.getAverageRowSize())
                    .setStatisticsRange(intersectRange)
                    .setNullsFraction(0.0)
                    .build();
            estimate = estimate.mapVariableColumnStatistics(expressionVariable.get(), oldStats -> symbolNewEstimate);
        }
        return estimate;
    }

    /**
     * Estimates the statistics remaining after applying {@code left <operator> right}.
     *
     * <p>Only {@code EQUAL} and {@code NOT_EQUAL} are modelled. Ordering comparisons and
     * {@code IS_DISTINCT_FROM} yield {@code PlanNodeStatsEstimate.unknown()}.
     *
     * @throws IllegalArgumentException if {@code operator} is not one of the handled operators
     */
    public static PlanNodeStatsEstimate estimateExpressionToExpressionComparison(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate leftExpressionStatistics,
            Optional<VariableReferenceExpression> leftExpressionVariable,
            VariableStatsEstimate rightExpressionStatistics,
            Optional<VariableReferenceExpression> rightExpressionVariable,
            ComparisonExpression.Operator operator) {
        switch (operator) {
            case EQUAL:
                return estimateExpressionEqualToExpression(inputStatistics, leftExpressionStatistics, leftExpressionVariable, rightExpressionStatistics, rightExpressionVariable);
            case NOT_EQUAL:
                return estimateExpressionNotEqualToExpression(inputStatistics, leftExpressionStatistics, leftExpressionVariable, rightExpressionStatistics, rightExpressionVariable);
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
            case IS_DISTINCT_FROM:
                return PlanNodeStatsEstimate.unknown();
            default:
                throw new IllegalArgumentException("Unexpected comparison operator: " + operator);
        }
    }

    /**
     * Estimates {@code left = right}.
     *
     * <p>The selectivity is {@code 1 / max(leftNdv, rightNdv, 1)} applied to rows where neither
     * side is null. Both variables, when present, receive the same statistics: the intersected
     * range, the smaller NDV, no nulls, and the NaN-aware average of the two row sizes. If either
     * side's distinct-values count is NaN, the result is {@code PlanNodeStatsEstimate.unknown()}.
     */
    private static PlanNodeStatsEstimate estimateExpressionEqualToExpression(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate leftExpressionStatistics,
            Optional<VariableReferenceExpression> leftExpressionVariable,
            VariableStatsEstimate rightExpressionStatistics,
            Optional<VariableReferenceExpression> rightExpressionVariable) {
        if (Double.isNaN(leftExpressionStatistics.getDistinctValuesCount()) || Double.isNaN(rightExpressionStatistics.getDistinctValuesCount())) {
            return PlanNodeStatsEstimate.unknown();
        }
        StatisticRange leftExpressionRange = StatisticRange.from(leftExpressionStatistics);
        StatisticRange rightExpressionRange = StatisticRange.from(rightExpressionStatistics);
        StatisticRange intersect = leftExpressionRange.intersect(rightExpressionRange);

        // A row can only match when both sides are non-null.
        double nullsFilterFactor = (1.0 - leftExpressionStatistics.getNullsFraction()) * (1.0 - rightExpressionStatistics.getNullsFraction());
        double leftNdv = leftExpressionRange.getDistinctValuesCount();
        double rightNdv = rightExpressionRange.getDistinctValuesCount();
        double filterFactor = 1.0 / MoreMath.max(leftNdv, rightNdv, 1.0);
        double retainedNdv = MoreMath.min(leftNdv, rightNdv);

        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics)
                .setOutputRowCount(inputStatistics.getOutputRowCount() * nullsFilterFactor * filterFactor);
        VariableStatsEstimate equalityStats = VariableStatsEstimate.builder()
                .setAverageRowSize(averageExcludingNaNs(leftExpressionStatistics.getAverageRowSize(), rightExpressionStatistics.getAverageRowSize()))
                .setNullsFraction(0.0)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();
        leftExpressionVariable.ifPresent(variable -> estimate.addVariableStatistics(variable, equalityStats));
        rightExpressionVariable.ifPresent(variable -> estimate.addVariableStatistics(variable, equalityStats));
        return estimate.build();
    }

    /**
     * Estimates {@code left <> right}.
     *
     * <p>Rows where either side is null are dropped first. The equality estimate is then computed
     * on the remaining rows, and its complement is kept. If the equality row count is unknown,
     * the result is {@code PlanNodeStatsEstimate.unknown()}.
     */
    private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate leftExpressionStatistics,
            Optional<VariableReferenceExpression> leftExpressionVariable,
            VariableStatsEstimate rightExpressionStatistics,
            Optional<VariableReferenceExpression> rightExpressionVariable) {
        double nullsFilterFactor = (1.0 - leftExpressionStatistics.getNullsFraction()) * (1.0 - rightExpressionStatistics.getNullsFraction());
        PlanNodeStatsEstimate inputNullsFiltered = inputStatistics.mapOutputRowCount(size -> size * nullsFilterFactor);
        VariableStatsEstimate leftNullsFiltered = leftExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0);
        VariableStatsEstimate rightNullsFiltered = rightExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0);

        PlanNodeStatsEstimate equalityStats = estimateExpressionEqualToExpression(
                inputNullsFiltered, leftNullsFiltered, leftExpressionVariable, rightNullsFiltered, rightExpressionVariable);
        if (equalityStats.isOutputRowCountUnknown()) {
            return PlanNodeStatsEstimate.unknown();
        }

        double inputRowCount = inputNullsFiltered.getOutputRowCount();
        double equalityFilterFactor = equalityStats.getOutputRowCount() / inputRowCount;
        if (!Double.isFinite(equalityFilterFactor)) {
            // Guards against a zero input row count (division yields NaN or infinity).
            equalityFilterFactor = 0.0;
        }

        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(inputNullsFiltered);
        result.setOutputRowCount(inputRowCount * (1.0 - equalityFilterFactor));
        leftExpressionVariable.ifPresent(variable -> result.addVariableStatistics(variable, leftNullsFiltered));
        rightExpressionVariable.ifPresent(variable -> result.addVariableStatistics(variable, rightNullsFiltered));
        return result.build();
    }

    /**
     * Averages two values while ignoring NaNs.
     *
     * @return the mean when both values are numbers, the non-NaN value when exactly one is NaN,
     *     and NaN when both are NaN
     */
    private static double averageExcludingNaNs(double first, double second) {
        if (Double.isNaN(first)) {
            // Covers the both-NaN case too: second is then NaN as well.
            return second;
        }
        if (Double.isNaN(second)) {
            return first;
        }
        return (first + second) / 2.0;
    }
}

