/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.FilterStatsCalculator;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.SimpleStatsRule;
import com.facebook.presto.cost.StatisticRange;
import com.facebook.presto.cost.StatsNormalizer;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.cost.SymbolStatsEstimate;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.relation.LogicalRowExpressions;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.util.MoreMath;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Stats rule that estimates the output statistics of a {@link JoinNode}.
 *
 * <p>Every join type starts from the statistics of the cross product of the two sources.
 *
 * <ul>
 *   <li>Inner-join estimates then apply the equi-join criteria and the optional join filter.</li>
 *   <li>Outer-join estimates add a "join complement" estimate on top of the inner-join estimate.
 *   The complement is the rows of the preserved side that are assumed to find no match.</li>
 * </ul>
 */
public class JoinStatsRule
extends SimpleStatsRule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join();

    /**
     * Default fraction of the other side's distinct values that is assumed to match a value of the
     * preserved side when estimating the join complement.
     */
    private static final double DEFAULT_UNMATCHED_JOIN_COMPLEMENT_NDVS_COEFFICIENT = 0.5;

    /**
     * Heuristic selectivity used for a conjunct whose effect is not estimated precisely.
     * It applies to:
     * <ul>
     *   <li>auxiliary equi-join clauses;</li>
     *   <li>join filters whose estimate comes back unknown;</li>
     *   <li>the remaining clauses when estimating the join complement.</li>
     * </ul>
     */
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9;

    private final FilterStatsCalculator filterStatsCalculator;
    private final StatsNormalizer normalizer;
    private final double unmatchedJoinComplementNdvsCoefficient;

    public JoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer normalizer) {
        this(filterStatsCalculator, normalizer, DEFAULT_UNMATCHED_JOIN_COMPLEMENT_NDVS_COEFFICIENT);
    }

    /**
     * @param unmatchedJoinComplementNdvsCoefficient multiplier applied to the other side's NDV to
     *        obtain the number of distinct values assumed to be matched (see the private
     *        {@code calculateJoinComplementStats} overload)
     */
    @VisibleForTesting
    JoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer normalizer, double unmatchedJoinComplementNdvsCoefficient) {
        super(normalizer);
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator is null");
        this.normalizer = Objects.requireNonNull(normalizer, "normalizer is null");
        this.unmatchedJoinComplementNdvsCoefficient = unmatchedJoinComplementNdvsCoefficient;
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Dispatches on the join type. All variants share the cross-join estimate of the two sources.
     *
     * @throws IllegalStateException if the join type is not INNER, LEFT, RIGHT or FULL
     */
    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(JoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) {
        PlanNodeStatsEstimate leftStats = sourceStats.getStats(node.getLeft());
        PlanNodeStatsEstimate rightStats = sourceStats.getStats(node.getRight());
        PlanNodeStatsEstimate crossJoinStats = this.crossJoinStats(node, leftStats, rightStats, types);
        switch (node.getType()) {
            case INNER:
                return Optional.of(this.computeInnerJoinStats(node, crossJoinStats, session, types));
            case LEFT:
                return Optional.of(this.computeLeftJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            case RIGHT:
                return Optional.of(this.computeRightJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            case FULL:
                return Optional.of(this.computeFullJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
        }
        throw new IllegalStateException("Unknown join type: " + node.getType());
    }

    /**
     * Computes FULL join stats as the LEFT join estimate plus the complement of the right side,
     * i.e. right rows that match nothing on the left.
     */
    private PlanNodeStatsEstimate computeFullJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate rightJoinComplementStats = this.calculateJoinComplementStats(node.getFilter(), this.flippedCriteria(node), rightStats, leftStats, types);
        PlanNodeStatsEstimate leftJoinStats = this.computeLeftJoinStats(node, leftStats, rightStats, crossJoinStats, session, types);
        return this.addJoinComplementStats(rightStats, leftJoinStats, rightJoinComplementStats);
    }

    /** Computes LEFT join stats as the inner-join estimate plus the complement of the left side. */
    private PlanNodeStatsEstimate computeLeftJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate innerJoinStats = this.computeInnerJoinStats(node, crossJoinStats, session, types);
        PlanNodeStatsEstimate leftJoinComplementStats = this.calculateJoinComplementStats(node.getFilter(), node.getCriteria(), leftStats, rightStats, types);
        return this.addJoinComplementStats(leftStats, innerJoinStats, leftJoinComplementStats);
    }

    /**
     * Computes RIGHT join stats as the inner-join estimate plus the complement of the right side.
     * The criteria are flipped so that the right side plays the "left" role in the complement
     * computation.
     */
    private PlanNodeStatsEstimate computeRightJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate innerJoinStats = this.computeInnerJoinStats(node, crossJoinStats, session, types);
        PlanNodeStatsEstimate rightJoinComplementStats = this.calculateJoinComplementStats(node.getFilter(), this.flippedCriteria(node), rightStats, leftStats, types);
        return this.addJoinComplementStats(rightStats, innerJoinStats, rightJoinComplementStats);
    }

    /**
     * Estimates inner-join output by filtering the cross-join estimate.
     *
     * <p>If there are no equi-join criteria, only the optional join filter is applied (or nothing,
     * for a pure cross join). Otherwise the equi-join clauses are applied first. The filter is then
     * applied on top of that result.
     *
     * <p>If the filter estimate is unknown, the equi-join estimate is scaled by
     * {@link #UNKNOWN_FILTER_COEFFICIENT} instead.
     */
    private PlanNodeStatsEstimate computeInnerJoinStats(JoinNode node, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        List<JoinNode.EquiJoinClause> equiJoinCriteria = node.getCriteria();
        if (equiJoinCriteria.isEmpty()) {
            if (!node.getFilter().isPresent()) {
                return crossJoinStats;
            }
            return this.applyJoinFilter(crossJoinStats, node.getFilter().get(), session, types);
        }
        PlanNodeStatsEstimate equiJoinEstimate = this.filterByEquiJoinClauses(crossJoinStats, equiJoinCriteria, session, types);
        if (equiJoinEstimate.isOutputRowCountUnknown()) {
            return PlanNodeStatsEstimate.unknown();
        }
        if (!node.getFilter().isPresent()) {
            return equiJoinEstimate;
        }
        PlanNodeStatsEstimate filteredEquiJoinEstimate = this.applyJoinFilter(equiJoinEstimate, node.getFilter().get(), session, types);
        if (filteredEquiJoinEstimate.isOutputRowCountUnknown()) {
            // Fall back to a fixed heuristic selectivity rather than losing the equi-join estimate.
            return this.normalizer.normalize(equiJoinEstimate.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT), types);
        }
        return filteredEquiJoinEstimate;
    }

    /**
     * Applies a join filter, which may be either an original SQL-tree {@code Expression} wrapped in
     * a {@code RowExpression} or a native {@code RowExpression}.
     */
    private PlanNodeStatsEstimate applyJoinFilter(PlanNodeStatsEstimate stats, RowExpression filter, Session session, TypeProvider types) {
        if (OriginalExpressionUtils.isExpression(filter)) {
            return this.filterStatsCalculator.filterStats(stats, OriginalExpressionUtils.castToExpression(filter), session, types);
        }
        return this.filterStatsCalculator.filterStats(stats, filter, session, types);
    }

    /**
     * Applies all equi-join clauses, trying each clause once as the "driving" clause.
     *
     * <p>The clauses are rotated through a queue. In each round the head of the queue drives the
     * estimate and the rest of the queue (in its current rotated order) is applied as auxiliary
     * clauses.
     *
     * @return the known estimate with the lowest output row count, or an unknown estimate if every
     *         round was unknown
     * @throws IllegalArgumentException if {@code clauses} is empty
     */
    private PlanNodeStatsEstimate filterByEquiJoinClauses(PlanNodeStatsEstimate stats, Collection<JoinNode.EquiJoinClause> clauses, Session session, TypeProvider types) {
        Preconditions.checkArgument(!clauses.isEmpty(), "clauses is empty");
        PlanNodeStatsEstimate result = PlanNodeStatsEstimate.unknown();
        LinkedList<JoinNode.EquiJoinClause> remainingClauses = new LinkedList<>(clauses);
        JoinNode.EquiJoinClause drivingClause = remainingClauses.poll();
        for (int i = 0; i < clauses.size(); ++i) {
            PlanNodeStatsEstimate estimate = this.filterByEquiJoinClauses(stats, drivingClause, remainingClauses, session, types);
            boolean isMoreSelective = !estimate.isOutputRowCountUnknown() && estimate.getOutputRowCount() < result.getOutputRowCount();
            if (result.isOutputRowCountUnknown() || isMoreSelective) {
                result = estimate;
            }
            // Rotate: the current driver moves to the back, the next clause becomes the driver.
            remainingClauses.add(drivingClause);
            drivingClause = remainingClauses.poll();
        }
        return result;
    }

    /**
     * Filters {@code stats} by {@code left = right} for the driving clause using the filter stats
     * calculator. It then applies each remaining clause with the cheaper auxiliary heuristic.
     */
    private PlanNodeStatsEstimate filterByEquiJoinClauses(PlanNodeStatsEstimate stats, JoinNode.EquiJoinClause drivingClause, Collection<JoinNode.EquiJoinClause> remainingClauses, Session session, TypeProvider types) {
        Expression drivingPredicate = new ComparisonExpression(
                ComparisonExpression.Operator.EQUAL,
                drivingClause.getLeft().toSymbolReference(),
                drivingClause.getRight().toSymbolReference());
        PlanNodeStatsEstimate filteredStats = this.filterStatsCalculator.filterStats(stats, drivingPredicate, session, types);
        for (JoinNode.EquiJoinClause clause : remainingClauses) {
            filteredStats = this.filterByAuxiliaryClause(filteredStats, clause, types);
        }
        return filteredStats;
    }

    /**
     * Applies an auxiliary equi-join clause heuristically.
     *
     * <p>For both columns of the clause, this:
     * <ul>
     *   <li>clears the nulls fraction;</li>
     *   <li>narrows the range to the intersection of the two ranges;</li>
     *   <li>keeps the smaller of the in-range NDVs.</li>
     * </ul>
     * The row count is scaled by {@link #UNKNOWN_FILTER_COEFFICIENT}.
     */
    private PlanNodeStatsEstimate filterByAuxiliaryClause(PlanNodeStatsEstimate stats, JoinNode.EquiJoinClause clause, TypeProvider types) {
        SymbolStatsEstimate leftStats = stats.getSymbolStatistics(clause.getLeft());
        SymbolStatsEstimate rightStats = stats.getSymbolStatistics(clause.getRight());
        StatisticRange leftRange = StatisticRange.from(leftStats);
        StatisticRange rightRange = StatisticRange.from(rightStats);
        StatisticRange intersect = leftRange.intersect(rightRange);
        // An undefined overlap (NaN) is treated as a full overlap.
        double leftFilterValue = JoinStatsRule.firstNonNaN(leftRange.overlapPercentWith(intersect), 1.0);
        double rightFilterValue = JoinStatsRule.firstNonNaN(rightRange.overlapPercentWith(intersect), 1.0);
        double leftNdvInRange = leftFilterValue * leftRange.getDistinctValuesCount();
        double rightNdvInRange = rightFilterValue * rightRange.getDistinctValuesCount();
        double retainedNdv = MoreMath.min(leftNdvInRange, rightNdvInRange);
        SymbolStatsEstimate newLeftStats = SymbolStatsEstimate.buildFrom(leftStats)
                .setNullsFraction(0.0)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();
        SymbolStatsEstimate newRightStats = SymbolStatsEstimate.buildFrom(rightStats)
                .setNullsFraction(0.0)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(stats)
                .setOutputRowCount(stats.getOutputRowCount() * UNKNOWN_FILTER_COEFFICIENT)
                .addSymbolStatistics(clause.getLeft(), newLeftStats)
                .addSymbolStatistics(clause.getRight(), newRightStats);
        return this.normalizer.normalize(result.build(), types);
    }

    /**
     * Returns the first argument that is not NaN.
     *
     * @throws IllegalArgumentException if every argument is NaN (or none are given)
     */
    private static double firstNonNaN(double ... values) {
        for (double value : values) {
            if (!Double.isNaN(value)) {
                return value;
            }
        }
        throw new IllegalArgumentException("All values are NaN");
    }

    /**
     * Estimates the rows of {@code leftStats} that find no match in {@code rightStats}.
     *
     * <ul>
     *   <li>If the right side is empty, every left row is unmatched.</li>
     *   <li>If there are no equi-join criteria:
     *     <ul>
     *       <li>with a filter, the result is unknown;</li>
     *       <li>without a filter, it is a cross join, so no left row is unmatched.</li>
     *     </ul>
     *   </li>
     *   <li>Otherwise, each equi-join clause is tried as the driving clause. The known estimate
     *   with the largest output row count is kept, i.e. the clause assumed to leave the most rows
     *   unmatched.</li>
     * </ul>
     */
    @VisibleForTesting
    PlanNodeStatsEstimate calculateJoinComplementStats(Optional<RowExpression> filter, List<JoinNode.EquiJoinClause> criteria, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, TypeProvider types) {
        if (rightStats.getOutputRowCount() == 0.0) {
            return leftStats;
        }
        if (criteria.isEmpty()) {
            if (filter.isPresent()) {
                return PlanNodeStatsEstimate.unknown();
            }
            return this.normalizer.normalize(leftStats.mapOutputRowCount(rowCount -> 0.0), types);
        }
        // Each filter conjunct counts as one more clause that further restricts matches.
        int numberOfFilterClauses = filter
                .map(expression -> OriginalExpressionUtils.isExpression(expression)
                        ? ExpressionUtils.extractConjuncts(OriginalExpressionUtils.castToExpression(expression)).size()
                        : LogicalRowExpressions.extractConjuncts(expression).size())
                .orElse(0);
        int numberOfRemainingClauses = criteria.size() - 1 + numberOfFilterClauses;
        return criteria.stream()
                .map(drivingClause -> this.calculateJoinComplementStats(leftStats, rightStats, drivingClause, numberOfRemainingClauses))
                .filter(estimate -> !estimate.isOutputRowCountUnknown())
                .max(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
                .map(estimate -> this.normalizer.normalize(estimate, types))
                .orElse(PlanNodeStatsEstimate.unknown());
    }

    /**
     * Estimates the unmatched left rows for a single driving clause by comparing NDVs.
     *
     * <p>The number of matching distinct values is taken as the right NDV times
     * {@code unmatchedJoinComplementNdvsCoefficient}. Then:
     * <ul>
     *   <li>If the left NDV exceeds that number, the excess values plus the left nulls are
     *   unmatched.</li>
     *   <li>Otherwise, only the left nulls are unmatched.</li>
     *   <li>If either NDV is NaN, neither comparison holds and the result is unknown.</li>
     * </ul>
     *
     * <p>The row count is finally divided by
     * {@code UNKNOWN_FILTER_COEFFICIENT ^ numberOfRemainingClauses}, because each further clause
     * makes matches rarer. It is capped at the left row count.
     */
    private PlanNodeStatsEstimate calculateJoinComplementStats(PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, JoinNode.EquiJoinClause drivingClause, int numberOfRemainingClauses) {
        PlanNodeStatsEstimate result = leftStats;
        SymbolStatsEstimate leftColumnStats = leftStats.getSymbolStatistics(drivingClause.getLeft());
        SymbolStatsEstimate rightColumnStats = rightStats.getSymbolStatistics(drivingClause.getRight());
        double leftNDV = leftColumnStats.getDistinctValuesCount();
        double matchingRightNDV = rightColumnStats.getDistinctValuesCount() * this.unmatchedJoinComplementNdvsCoefficient;
        if (leftNDV > matchingRightNDV) {
            double unmatchedNDV = leftNDV - matchingRightNDV;
            double nonMatchingLeftValuesFraction = leftColumnStats.getValuesFraction() * unmatchedNDV / leftNDV;
            double scaleFactor = nonMatchingLeftValuesFraction + leftColumnStats.getNullsFraction();
            double newLeftNullsFraction = leftColumnStats.getNullsFraction() / scaleFactor;
            result = result.mapSymbolColumnStatistics(drivingClause.getLeft(), columnStats -> SymbolStatsEstimate.buildFrom(columnStats)
                    .setLowValue(leftColumnStats.getLowValue())
                    .setHighValue(leftColumnStats.getHighValue())
                    .setNullsFraction(newLeftNullsFraction)
                    .setDistinctValuesCount(unmatchedNDV)
                    .build());
            result = result.mapOutputRowCount(rowCount -> rowCount * scaleFactor);
        }
        else if (leftNDV <= matchingRightNDV) {
            // Every non-null left value is assumed matched: only null rows remain.
            result = result.mapSymbolColumnStatistics(drivingClause.getLeft(), columnStats -> SymbolStatsEstimate.buildFrom(columnStats)
                    .setLowValue(Double.NaN)
                    .setHighValue(Double.NaN)
                    .setNullsFraction(1.0)
                    .setDistinctValuesCount(0.0)
                    .build());
            result = result.mapOutputRowCount(rowCount -> rowCount * leftColumnStats.getNullsFraction());
        }
        else {
            // Reached only when an NDV is NaN.
            return PlanNodeStatsEstimate.unknown();
        }
        return result.mapOutputRowCount(rowCount -> Math.min(leftStats.getOutputRowCount(), rowCount / Math.pow(UNKNOWN_FILTER_COEFFICIENT, numberOfRemainingClauses)));
    }

    /**
     * Combines an inner-join estimate with a join-complement estimate into an outer-join estimate.
     *
     * <p>The output row count is the sum of both row counts. Handling per symbol:
     * <ul>
     *   <li>Symbols with known complement stats take low/high/NDV from {@code sourceStats}. Their
     *   nulls fraction is the row-weighted average of the two estimates.</li>
     *   <li>Symbols known only in the inner-join estimate are treated as all-null in the complement
     *   rows.</li>
     * </ul>
     * If the complement is empty, {@code innerJoinStats} is returned unchanged.
     */
    @VisibleForTesting
    PlanNodeStatsEstimate addJoinComplementStats(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate innerJoinStats, PlanNodeStatsEstimate joinComplementStats) {
        double innerJoinRowCount = innerJoinStats.getOutputRowCount();
        double joinComplementRowCount = joinComplementStats.getOutputRowCount();
        if (joinComplementRowCount == 0.0) {
            return innerJoinStats;
        }
        double outputRowCount = innerJoinRowCount + joinComplementRowCount;
        PlanNodeStatsEstimate.Builder outputStats = PlanNodeStatsEstimate.buildFrom(innerJoinStats);
        outputStats.setOutputRowCount(outputRowCount);
        for (Symbol symbol : joinComplementStats.getSymbolsWithKnownStatistics()) {
            SymbolStatsEstimate leftSymbolStats = sourceStats.getSymbolStatistics(symbol);
            SymbolStatsEstimate innerJoinSymbolStats = innerJoinStats.getSymbolStatistics(symbol);
            SymbolStatsEstimate joinComplementSymbolStats = joinComplementStats.getSymbolStatistics(symbol);
            double newNullsFraction = (innerJoinSymbolStats.getNullsFraction() * innerJoinRowCount
                    + joinComplementSymbolStats.getNullsFraction() * joinComplementRowCount) / outputRowCount;
            outputStats.addSymbolStatistics(symbol, SymbolStatsEstimate.buildFrom(innerJoinSymbolStats)
                    .setLowValue(leftSymbolStats.getLowValue())
                    .setHighValue(leftSymbolStats.getHighValue())
                    .setDistinctValuesCount(leftSymbolStats.getDistinctValuesCount())
                    .setNullsFraction(newNullsFraction)
                    .build());
        }
        for (Symbol symbol : Sets.difference(innerJoinStats.getSymbolsWithKnownStatistics(), joinComplementStats.getSymbolsWithKnownStatistics())) {
            SymbolStatsEstimate innerJoinSymbolStats = innerJoinStats.getSymbolStatistics(symbol);
            // Complement rows contribute nulls (fraction 1.0) for symbols absent from the complement.
            double newNullsFraction = (innerJoinSymbolStats.getNullsFraction() * innerJoinRowCount + joinComplementRowCount) / outputRowCount;
            outputStats.addSymbolStatistics(symbol, innerJoinSymbolStats.mapNullsFraction(nullsFraction -> newNullsFraction));
        }
        return outputStats.build();
    }

    /**
     * Builds the cross-product estimate for the join.
     *
     * <p>The row count is the product of the source row counts. Each output symbol keeps the
     * statistics of the source that produces it.
     */
    private PlanNodeStatsEstimate crossJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, TypeProvider types) {
        PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(leftStats.getOutputRowCount() * rightStats.getOutputRowCount());
        node.getLeft().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, leftStats.getSymbolStatistics(symbol)));
        node.getRight().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, rightStats.getSymbolStatistics(symbol)));
        return this.normalizer.normalize(builder.build(), types);
    }

    /** Returns the join criteria with the left and right side of every clause swapped. */
    private List<JoinNode.EquiJoinClause> flippedCriteria(JoinNode node) {
        return node.getCriteria().stream()
                .map(JoinNode.EquiJoinClause::flip)
                .collect(ImmutableList.toImmutableList());
    }
}

