/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.cost.FilterStatsCalculator;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimateMath;
import com.facebook.presto.cost.StatisticRange;
import com.facebook.presto.cost.SymbolStatsEstimate;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import java.util.OptionalDouble;

/**
 * Estimates plan-node statistics after applying a comparison filter, either between a symbol and a
 * literal (represented as an {@link OptionalDouble}) or between two symbols.
 *
 * <p>Static utility class; it cannot be instantiated.
 */
public final class ComparisonStatsCalculator {
    private ComparisonStatsCalculator() {
    }

    /**
     * Estimates the statistics produced by filtering {@code inputStatistics} with
     * {@code symbol <type> literal}.
     *
     * @param inputStatistics statistics of the filter's input
     * @param symbol the symbol compared against the literal
     * @param doubleLiteral the literal as a double; when empty, the literal side of the comparison is
     *        treated as unbounded (NOTE(review): presumably used for literals with no double
     *        representation — confirm with callers)
     * @param type the comparison operator; {@code LESS_THAN}/{@code LESS_THAN_OR_EQUAL} and
     *        {@code GREATER_THAN}/{@code GREATER_THAN_OR_EQUAL} are estimated identically
     * @return the estimated output statistics; any other operator falls back to
     *         {@link FilterStatsCalculator#filterStatsForUnknownExpression}
     */
    public static PlanNodeStatsEstimate comparisonSymbolToLiteralStats(PlanNodeStatsEstimate inputStatistics, Symbol symbol, OptionalDouble doubleLiteral, ComparisonExpressionType type) {
        switch (type) {
            case EQUAL:
                return symbolToLiteralEquality(inputStatistics, symbol, doubleLiteral);
            case NOT_EQUAL:
                return symbolToLiteralNonEquality(inputStatistics, symbol, doubleLiteral);
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
                return symbolToLiteralLessThan(inputStatistics, symbol, doubleLiteral);
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
                return symbolToLiteralGreaterThan(inputStatistics, symbol, doubleLiteral);
            default:
                return FilterStatsCalculator.filterStatsForUnknownExpression(inputStatistics);
        }
    }

    /**
     * Shared estimator for "symbol falls within {@code literalRange}" predicates.
     *
     * <p>The symbol's new range is its intersection with {@code literalRange}. The row count is
     * scaled by the overlap fraction reported by {@link StatisticRange#overlapPercentWith} and by
     * the symbol's non-null fraction. The output nulls fraction is 0.
     */
    private static PlanNodeStatsEstimate symbolToLiteralRangeComparison(PlanNodeStatsEstimate inputStatistics, Symbol symbol, StatisticRange literalRange) {
        SymbolStatsEstimate symbolStats = inputStatistics.getSymbolStatistics(symbol);
        StatisticRange range = StatisticRange.from(symbolStats);
        StatisticRange intersectRange = range.intersect(literalRange);
        double filterFactor = range.overlapPercentWith(intersectRange);
        // Null rows are excluded from the result, so only the non-null share of rows can survive.
        double nonNullFraction = 1.0 - symbolStats.getNullsFraction();
        SymbolStatsEstimate symbolNewEstimate = SymbolStatsEstimate.builder()
                .setAverageRowSize(symbolStats.getAverageRowSize())
                .setStatisticsRange(intersectRange)
                .setNullsFraction(0.0)
                .build();
        return inputStatistics
                .mapOutputRowCount(rowCount -> filterFactor * nonNullFraction * rowCount)
                .mapSymbolColumnStatistics(symbol, oldStats -> symbolNewEstimate);
    }

    /**
     * Builds the range used for (in)equality against a literal. When the literal is present, this
     * is the single point {@code [v, v]}; when it is absent, it spans all doubles. In both cases the
     * third constructor argument is 1.0 (presumably the distinct-values count — confirm in
     * {@link StatisticRange}).
     */
    private static StatisticRange literalEqualityRange(OptionalDouble literal) {
        if (literal.isPresent()) {
            double value = literal.getAsDouble();
            return new StatisticRange(value, value, 1.0);
        }
        return new StatisticRange(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0);
    }

    /** Estimates {@code symbol = literal}. */
    private static PlanNodeStatsEstimate symbolToLiteralEquality(PlanNodeStatsEstimate inputStatistics, Symbol symbol, OptionalDouble literal) {
        return symbolToLiteralRangeComparison(inputStatistics, symbol, literalEqualityRange(literal));
    }

    /**
     * Estimates {@code symbol <> literal}.
     *
     * <p>The row-count factor is the complement of the equality overlap, scaled by the non-null
     * fraction. The symbol keeps its previous statistics except that the nulls fraction becomes 0
     * and the distinct-values count drops by one, never going below 0.
     */
    private static PlanNodeStatsEstimate symbolToLiteralNonEquality(PlanNodeStatsEstimate inputStatistics, Symbol symbol, OptionalDouble literal) {
        SymbolStatsEstimate symbolStats = inputStatistics.getSymbolStatistics(symbol);
        StatisticRange range = StatisticRange.from(symbolStats);
        StatisticRange intersectRange = range.intersect(literalEqualityRange(literal));
        double filterFactor = 1.0 - range.overlapPercentWith(intersectRange);
        double nonNullFraction = 1.0 - symbolStats.getNullsFraction();
        return inputStatistics
                .mapOutputRowCount(rowCount -> filterFactor * nonNullFraction * rowCount)
                .mapSymbolColumnStatistics(symbol, oldStats -> SymbolStatsEstimate.buildFrom(oldStats)
                        .setNullsFraction(0.0)
                        .setDistinctValuesCount(Math.max(oldStats.getDistinctValuesCount() - 1.0, 0.0))
                        .build());
    }

    /**
     * Estimates {@code symbol < literal} and {@code symbol <= literal}. An absent literal leaves
     * the upper bound at +infinity.
     */
    private static PlanNodeStatsEstimate symbolToLiteralLessThan(PlanNodeStatsEstimate inputStatistics, Symbol symbol, OptionalDouble literal) {
        StatisticRange literalRange = new StatisticRange(Double.NEGATIVE_INFINITY, literal.orElse(Double.POSITIVE_INFINITY), Double.NaN);
        return symbolToLiteralRangeComparison(inputStatistics, symbol, literalRange);
    }

    /**
     * Estimates {@code symbol > literal} and {@code symbol >= literal}. An absent literal leaves
     * the lower bound at -infinity.
     */
    private static PlanNodeStatsEstimate symbolToLiteralGreaterThan(PlanNodeStatsEstimate inputStatistics, Symbol symbol, OptionalDouble literal) {
        StatisticRange literalRange = new StatisticRange(literal.orElse(Double.NEGATIVE_INFINITY), Double.POSITIVE_INFINITY, Double.NaN);
        return symbolToLiteralRangeComparison(inputStatistics, symbol, literalRange);
    }

    /**
     * Estimates the statistics produced by filtering {@code inputStatistics} with
     * {@code left <type> right}.
     *
     * @param inputStatistics statistics of the filter's input
     * @param left the left-hand symbol
     * @param right the right-hand symbol
     * @param type the comparison operator; only {@code EQUAL} and {@code NOT_EQUAL} are modeled
     * @return the estimated output statistics; other operators fall back to
     *         {@link FilterStatsCalculator#filterStatsForUnknownExpression}
     */
    public static PlanNodeStatsEstimate comparisonSymbolToSymbolStats(PlanNodeStatsEstimate inputStatistics, Symbol left, Symbol right, ComparisonExpressionType type) {
        switch (type) {
            case EQUAL:
                return symbolToSymbolEquality(inputStatistics, left, right);
            case NOT_EQUAL:
                return symbolToSymbolNonEquality(inputStatistics, left, right);
            default:
                return FilterStatsCalculator.filterStatsForUnknownExpression(inputStatistics);
        }
    }

    /**
     * Estimates {@code left = right}.
     *
     * <p>Both symbols are narrowed to the intersection of their ranges and get a nulls fraction of
     * 0. The row count is scaled by {@code 1 / max(NDV(left), NDV(right))} and by the product of
     * both non-null fractions. If either side's distinct-values count is NaN, the method returns
     * the unknown-expression fallback instead.
     */
    private static PlanNodeStatsEstimate symbolToSymbolEquality(PlanNodeStatsEstimate inputStatistics, Symbol left, Symbol right) {
        SymbolStatsEstimate leftStats = inputStatistics.getSymbolStatistics(left);
        SymbolStatsEstimate rightStats = inputStatistics.getSymbolStatistics(right);
        if (Double.isNaN(leftStats.getDistinctValuesCount()) || Double.isNaN(rightStats.getDistinctValuesCount())) {
            // The fallback must be returned. Otherwise Math.max below propagates NaN, which likely
            // turns the selectivity (and the estimated row count) into NaN.
            return FilterStatsCalculator.filterStatsForUnknownExpression(inputStatistics);
        }
        StatisticRange leftRange = StatisticRange.from(leftStats);
        StatisticRange rightRange = StatisticRange.from(rightStats);
        StatisticRange intersect = leftRange.intersect(rightRange);
        SymbolStatsEstimate newRightStats = SymbolStatsEstimate.buildFrom(rightStats).setNullsFraction(0.0).setStatisticsRange(intersect).build();
        SymbolStatsEstimate newLeftStats = SymbolStatsEstimate.buildFrom(leftStats).setNullsFraction(0.0).setStatisticsRange(intersect).build();
        double nullsFilterFactor = (1.0 - leftStats.getNullsFraction()) * (1.0 - rightStats.getNullsFraction());
        double filterFactor = 1.0 / Math.max(leftRange.getDistinctValuesCount(), rightRange.getDistinctValuesCount());
        return inputStatistics
                .mapOutputRowCount(rowCount -> rowCount * filterFactor * nullsFilterFactor)
                .mapSymbolColumnStatistics(left, oldLeftStats -> newLeftStats)
                .mapSymbolColumnStatistics(right, oldRightStats -> newRightStats);
    }

    /**
     * Estimates {@code left <> right} as the input statistics minus the equality estimate,
     * computed via {@link PlanNodeStatsEstimateMath#differenceInStats}.
     */
    private static PlanNodeStatsEstimate symbolToSymbolNonEquality(PlanNodeStatsEstimate inputStatistics, Symbol left, Symbol right) {
        return PlanNodeStatsEstimateMath.differenceInStats(inputStatistics, symbolToSymbolEquality(inputStatistics, left, right));
    }
}

