/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.cost.DisjointRangeDomainHistogram;
import com.facebook.presto.cost.FilterStatsCalculator;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.SimpleStatsRule;
import com.facebook.presto.cost.StatisticRange;
import com.facebook.presto.cost.StatsNormalizer;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.cost.VariableStatsEstimate;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.ExpressionTreeUtils;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.util.MoreMath;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Estimates the output statistics of a {@link JoinNode} (INNER, LEFT, RIGHT or FULL) from the
 * statistics of its two sources.
 *
 * <p>The inner-join estimate starts from the cross product of both sides. It is then narrowed by the
 * equi-join criteria and by the optional join filter. An outer join is estimated as the inner-join
 * estimate plus a "join complement": the rows of the preserved side that are expected to find no
 * match on the other side.
 */
public class JoinStatsRule extends SimpleStatsRule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join();

    /** Default fraction of the other side's NDVs assumed to match when estimating unmatched rows. */
    private static final double DEFAULT_UNMATCHED_JOIN_COMPLEMENT_NDVS_COEFFICIENT = 0.5;

    /** Session default-join-selectivity value meaning "no fallback": unknown estimates stay unknown. */
    private static final double DEFAULT_JOIN_SELECTIVITY_DISABLED = 0.0;

    /** Heuristic selectivity applied for a predicate whose effect could not be estimated precisely. */
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9;

    private final FilterStatsCalculator filterStatsCalculator;
    private final StatsNormalizer normalizer;
    // Multiplier applied to the right-side NDV to obtain the number of left NDVs assumed to be matched.
    private final double unmatchedJoinComplementNdvsCoefficient;

    public JoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer normalizer) {
        this(filterStatsCalculator, normalizer, DEFAULT_UNMATCHED_JOIN_COMPLEMENT_NDVS_COEFFICIENT);
    }

    @VisibleForTesting
    JoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer normalizer, double unmatchedJoinComplementNdvsCoefficient) {
        super(normalizer);
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator is null");
        this.normalizer = Objects.requireNonNull(normalizer, "normalizer is null");
        this.unmatchedJoinComplementNdvsCoefficient = unmatchedJoinComplementNdvsCoefficient;
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(JoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) {
        PlanNodeStatsEstimate leftStats = sourceStats.getStats(node.getLeft());
        PlanNodeStatsEstimate rightStats = sourceStats.getStats(node.getRight());
        PlanNodeStatsEstimate crossJoinStats = crossJoinStats(node, leftStats, rightStats);

        switch (node.getType()) {
            case INNER:
                return Optional.of(computeInnerJoinStats(node, crossJoinStats, session, types));
            case LEFT:
                return Optional.of(computeLeftJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            case RIGHT:
                return Optional.of(computeRightJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            case FULL:
                return Optional.of(computeFullJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            default:
                throw new IllegalStateException("Unknown join type: " + node.getType());
        }
    }

    /** FULL join = LEFT join estimate plus the right-side rows that match nothing on the left. */
    private PlanNodeStatsEstimate computeFullJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate rightJoinComplementStats = calculateJoinComplementStats(node.getFilter(), flippedCriteria(node), rightStats, leftStats);
        PlanNodeStatsEstimate leftJoinStats = computeLeftJoinStats(node, leftStats, rightStats, crossJoinStats, session, types);
        return addJoinComplementStats(rightStats, leftJoinStats, rightJoinComplementStats);
    }

    /** LEFT join = inner join estimate plus the left-side rows that match nothing on the right. */
    private PlanNodeStatsEstimate computeLeftJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(node, crossJoinStats, session, types);
        PlanNodeStatsEstimate leftJoinComplementStats = calculateJoinComplementStats(node.getFilter(), node.getCriteria(), leftStats, rightStats);
        return addJoinComplementStats(leftStats, innerJoinStats, leftJoinComplementStats);
    }

    /** RIGHT join = inner join estimate plus the right-side rows that match nothing on the left. */
    private PlanNodeStatsEstimate computeRightJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(node, crossJoinStats, session, types);
        PlanNodeStatsEstimate rightJoinComplementStats = calculateJoinComplementStats(node.getFilter(), flippedCriteria(node), rightStats, leftStats);
        return addJoinComplementStats(rightStats, innerJoinStats, rightJoinComplementStats);
    }

    /**
     * Estimates the inner join.
     *
     * <p>With no equi-join criteria, the cross-join stats are returned, filtered by the join filter
     * when one is present. Otherwise the equi-join criteria are applied first. If that yields an
     * unknown row count, the session's default join selectivity is applied to the cross-join row
     * count, unless that setting is disabled, in which case the result is unknown. The join filter
     * is applied last. If the filter cannot be estimated, the row count is instead scaled by
     * {@link #UNKNOWN_FILTER_COEFFICIENT}.
     */
    private PlanNodeStatsEstimate computeInnerJoinStats(JoinNode node, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        List<EquiJoinClause> equiJoinCriteria = node.getCriteria();
        Optional<RowExpression> filter = node.getFilter();

        if (equiJoinCriteria.isEmpty()) {
            if (!filter.isPresent()) {
                return crossJoinStats;
            }
            return filterStatsCalculator.filterStats(crossJoinStats, filter.get(), session);
        }

        PlanNodeStatsEstimate equiJoinEstimate = filterByEquiJoinClauses(crossJoinStats, equiJoinCriteria, session, types);
        if (equiJoinEstimate.isOutputRowCountUnknown()) {
            double defaultJoinSelectivityFactor = SystemSessionProperties.getDefaultJoinSelectivityCoefficient(session);
            if (Double.compare(defaultJoinSelectivityFactor, DEFAULT_JOIN_SELECTIVITY_DISABLED) == 0) {
                return PlanNodeStatsEstimate.unknown();
            }
            equiJoinEstimate = crossJoinStats.mapOutputRowCount(ignored -> crossJoinStats.getOutputRowCount() * defaultJoinSelectivityFactor);
        }

        if (!filter.isPresent()) {
            return equiJoinEstimate;
        }
        PlanNodeStatsEstimate filteredEquiJoinEstimate = filterStatsCalculator.filterStats(equiJoinEstimate, filter.get(), session);
        if (filteredEquiJoinEstimate.isOutputRowCountUnknown()) {
            // The filter could not be estimated: fall back to a fixed heuristic reduction.
            return normalizer.normalize(equiJoinEstimate.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT));
        }
        return filteredEquiJoinEstimate;
    }

    /**
     * Applies all equi-join clauses to {@code stats}, trying each clause in turn as the "driving"
     * clause and the rest as auxiliary clauses.
     *
     * <p>The estimate with the smallest known row count is returned. If every candidate is unknown,
     * the last candidate is returned.
     *
     * @throws IllegalArgumentException if {@code clauses} is empty
     */
    private PlanNodeStatsEstimate filterByEquiJoinClauses(PlanNodeStatsEstimate stats, Collection<EquiJoinClause> clauses, Session session, TypeProvider types) {
        Preconditions.checkArgument(!clauses.isEmpty(), "clauses is empty");

        PlanNodeStatsEstimate result = PlanNodeStatsEstimate.unknown();
        // Rotating queue: at iteration i the head is clause i and the queue holds the other clauses in
        // rotation order (i+1, ..., n-1, 0, ..., i-1).
        LinkedList<EquiJoinClause> remainingClauses = new LinkedList<>(clauses);
        for (int i = 0; i < clauses.size(); i++) {
            EquiJoinClause drivingClause = remainingClauses.poll();
            PlanNodeStatsEstimate estimate = filterByEquiJoinClauses(stats, drivingClause, remainingClauses, session, types);
            boolean estimateIsSmaller = !estimate.isOutputRowCountUnknown() && estimate.getOutputRowCount() < result.getOutputRowCount();
            if (result.isOutputRowCountUnknown() || estimateIsSmaller) {
                result = estimate;
            }
            remainingClauses.add(drivingClause);
        }
        return result;
    }

    /**
     * Filters {@code stats} by {@code drivingClause}, using the filter stats calculator on an
     * equality predicate. The remaining clauses are then applied one by one as auxiliary clauses.
     */
    private PlanNodeStatsEstimate filterByEquiJoinClauses(PlanNodeStatsEstimate stats, EquiJoinClause drivingClause, Collection<EquiJoinClause> remainingClauses, Session session, TypeProvider types) {
        Expression drivingPredicate = new ComparisonExpression(
                ComparisonExpression.Operator.EQUAL,
                toSymbolReference(drivingClause.getLeft()),
                toSymbolReference(drivingClause.getRight()));
        PlanNodeStatsEstimate filteredStats = filterStatsCalculator.filterStats(stats, drivingPredicate, session, types);

        boolean useHistograms = SystemSessionProperties.shouldOptimizerUseHistograms(session);
        for (EquiJoinClause clause : remainingClauses) {
            filteredStats = filterByAuxiliaryClause(filteredStats, clause, useHistograms);
        }
        return filteredStats;
    }

    /** Builds a {@link SymbolReference} for {@code variable}, carrying over its source location. */
    private static Expression toSymbolReference(VariableReferenceExpression variable) {
        return new SymbolReference(ExpressionTreeUtils.getNodeLocation(variable.getSourceLocation()), variable.getName());
    }

    /**
     * Applies one auxiliary equi-join clause heuristically.
     *
     * <p>Both sides get a nulls fraction of 0 and have their ranges narrowed to the intersection of
     * the two ranges. Both sides are also given the smaller of the two in-range NDV estimates.
     * Optionally, their histograms are intersected with that range. The row count is scaled by
     * {@link #UNKNOWN_FILTER_COEFFICIENT}.
     */
    private PlanNodeStatsEstimate filterByAuxiliaryClause(PlanNodeStatsEstimate stats, EquiJoinClause clause, boolean useHistograms) {
        VariableStatsEstimate leftStats = stats.getVariableStatistics(clause.getLeft());
        VariableStatsEstimate rightStats = stats.getVariableStatistics(clause.getRight());
        StatisticRange leftRange = StatisticRange.from(leftStats);
        StatisticRange rightRange = StatisticRange.from(rightStats);
        StatisticRange intersect = leftRange.intersect(rightRange);

        // A NaN overlap (nothing known) is treated as full overlap.
        double leftFilterValue = firstNonNaN(leftRange.overlapPercentWith(intersect), 1.0);
        double rightFilterValue = firstNonNaN(rightRange.overlapPercentWith(intersect), 1.0);
        double leftNdvInRange = leftFilterValue * leftRange.getDistinctValuesCount();
        double rightNdvInRange = rightFilterValue * rightRange.getDistinctValuesCount();
        double retainedNdv = MoreMath.min(leftNdvInRange, rightNdvInRange);

        VariableStatsEstimate newLeftStats = auxiliaryClauseSideStats(leftStats, intersect, retainedNdv, useHistograms);
        VariableStatsEstimate newRightStats = auxiliaryClauseSideStats(rightStats, intersect, retainedNdv, useHistograms);

        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(stats)
                .setOutputRowCount(stats.getOutputRowCount() * UNKNOWN_FILTER_COEFFICIENT)
                .addVariableStatistics(clause.getLeft(), newLeftStats)
                .addVariableStatistics(clause.getRight(), newRightStats);
        return normalizer.normalize(result.build());
    }

    /** Narrowed statistics for one side of an auxiliary clause; see {@link #filterByAuxiliaryClause}. */
    private static VariableStatsEstimate auxiliaryClauseSideStats(VariableStatsEstimate sideStats, StatisticRange intersect, double retainedNdv, boolean useHistograms) {
        VariableStatsEstimate.Builder builder = VariableStatsEstimate.buildFrom(sideStats)
                .setNullsFraction(0.0)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv);
        if (useHistograms) {
            builder.setHistogram(sideStats.getHistogram().map(histogram -> DisjointRangeDomainHistogram.addConjunction(histogram, intersect)));
        }
        return builder.build();
    }

    /**
     * Returns the first argument that is not NaN.
     *
     * @throws IllegalArgumentException if every value is NaN
     */
    private static double firstNonNaN(double... values) {
        for (double value : values) {
            if (!Double.isNaN(value)) {
                return value;
            }
        }
        throw new IllegalArgumentException("All values are NaN");
    }

    /**
     * Estimates the "join complement": the rows of {@code leftStats} that have no match in
     * {@code rightStats} under {@code criteria} and {@code filter}.
     *
     * <ul>
     *   <li>If the right side has zero rows, every left row is unmatched and {@code leftStats} is
     *       returned.</li>
     *   <li>With no criteria, the result is unknown if a filter is present. Otherwise it is zero
     *       rows.</li>
     *   <li>Otherwise each clause is tried as the driving clause. The largest known estimate is
     *       kept, or the result is unknown if none is known.</li>
     * </ul>
     */
    @VisibleForTesting
    PlanNodeStatsEstimate calculateJoinComplementStats(Optional<RowExpression> filter, List<EquiJoinClause> criteria, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats) {
        if (rightStats.getOutputRowCount() == 0.0) {
            return leftStats;
        }
        if (criteria.isEmpty()) {
            if (filter.isPresent()) {
                return PlanNodeStatsEstimate.unknown();
            }
            return normalizer.normalize(leftStats.mapOutputRowCount(rowCount -> 0.0));
        }

        int numberOfFilterClauses = filter.map(expression -> LogicalRowExpressions.extractConjuncts(expression).size()).orElse(0);
        // Every clause other than the driving one: the other equi-join clauses plus each filter conjunct.
        int numberOfRemainingClauses = criteria.size() - 1 + numberOfFilterClauses;
        return criteria.stream()
                .map(drivingClause -> calculateJoinComplementStats(leftStats, rightStats, drivingClause, numberOfRemainingClauses))
                .filter(estimate -> !estimate.isOutputRowCountUnknown())
                .max(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
                .map(estimate -> normalizer.normalize(estimate))
                .orElse(PlanNodeStatsEstimate.unknown());
    }

    /**
     * Estimates the unmatched left rows for a single driving clause.
     *
     * <p>{@code rightNdv * unmatchedJoinComplementNdvsCoefficient} left NDVs are assumed to be matched.
     * If the left side has more NDVs than that, the excess values plus the left nulls are treated as
     * unmatched. Otherwise only the left null rows are treated as unmatched. The row count is then
     * divided by {@link #UNKNOWN_FILTER_COEFFICIENT} once per remaining clause, since extra join
     * conditions can only leave more rows unmatched. It is capped at the left row count.
     */
    private PlanNodeStatsEstimate calculateJoinComplementStats(PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, EquiJoinClause drivingClause, int numberOfRemainingClauses) {
        VariableStatsEstimate leftColumnStats = leftStats.getVariableStatistics(drivingClause.getLeft());
        VariableStatsEstimate rightColumnStats = rightStats.getVariableStatistics(drivingClause.getRight());
        double leftNdv = leftColumnStats.getDistinctValuesCount();
        double matchingRightNdv = rightColumnStats.getDistinctValuesCount() * unmatchedJoinComplementNdvsCoefficient;

        PlanNodeStatsEstimate result;
        if (leftNdv > matchingRightNdv) {
            double unmatchedLeftNdv = leftNdv - matchingRightNdv;
            double nonMatchingLeftValuesFraction = leftColumnStats.getValuesFraction() * unmatchedLeftNdv / leftNdv;
            double scaleFactor = nonMatchingLeftValuesFraction + leftColumnStats.getNullsFraction();
            double newLeftNullsFraction = leftColumnStats.getNullsFraction() / scaleFactor;
            result = leftStats
                    .mapVariableColumnStatistics(drivingClause.getLeft(), columnStats -> VariableStatsEstimate.buildFrom(columnStats)
                            .setLowValue(leftColumnStats.getLowValue())
                            .setHighValue(leftColumnStats.getHighValue())
                            .setNullsFraction(newLeftNullsFraction)
                            .setDistinctValuesCount(unmatchedLeftNdv)
                            .build())
                    .mapOutputRowCount(rowCount -> rowCount * scaleFactor);
        } else if (leftNdv <= matchingRightNdv) {
            // Every non-null left value is assumed matched, so only null rows survive.
            result = leftStats
                    .mapVariableColumnStatistics(drivingClause.getLeft(), columnStats -> VariableStatsEstimate.buildFrom(columnStats)
                            .setLowValue(Double.NaN)
                            .setHighValue(Double.NaN)
                            .setNullsFraction(1.0)
                            .setDistinctValuesCount(0.0)
                            .build())
                    .mapOutputRowCount(rowCount -> rowCount * leftColumnStats.getNullsFraction());
        } else {
            // Neither comparison held, so one of the compared NDV values is NaN.
            return PlanNodeStatsEstimate.unknown();
        }
        return result.mapOutputRowCount(rowCount -> Math.min(leftStats.getOutputRowCount(), rowCount / Math.pow(UNKNOWN_FILTER_COEFFICIENT, numberOfRemainingClauses)));
    }

    /**
     * Combines the inner-join estimate with the join complement of the preserved side.
     *
     * <p>The output row count is the sum of the two row counts. For each variable with known
     * statistics in the complement, the nulls fraction is the row-weighted average of both parts.
     * Low value, high value and NDV are taken from {@code sourceStats}. Every other variable of the
     * inner join is null in all complement rows. If the complement has zero rows or equals
     * {@link PlanNodeStatsEstimate#unknown()}, {@code innerJoinStats} is returned unchanged.
     */
    @VisibleForTesting
    PlanNodeStatsEstimate addJoinComplementStats(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate innerJoinStats, PlanNodeStatsEstimate joinComplementStats) {
        double innerJoinRowCount = innerJoinStats.getOutputRowCount();
        double joinComplementRowCount = joinComplementStats.getOutputRowCount();
        if (joinComplementRowCount == 0.0 || joinComplementStats.equals(PlanNodeStatsEstimate.unknown())) {
            return innerJoinStats;
        }

        double outputRowCount = innerJoinRowCount + joinComplementRowCount;
        PlanNodeStatsEstimate.Builder outputStats = PlanNodeStatsEstimate.buildFrom(innerJoinStats);
        outputStats.setOutputRowCount(outputRowCount);

        for (VariableReferenceExpression variable : joinComplementStats.getVariablesWithKnownStatistics()) {
            VariableStatsEstimate sourceVariableStats = sourceStats.getVariableStatistics(variable);
            VariableStatsEstimate innerJoinVariableStats = innerJoinStats.getVariableStatistics(variable);
            VariableStatsEstimate complementVariableStats = joinComplementStats.getVariableStatistics(variable);
            double newNullsFraction = (innerJoinVariableStats.getNullsFraction() * innerJoinRowCount
                    + complementVariableStats.getNullsFraction() * joinComplementRowCount) / outputRowCount;
            outputStats.addVariableStatistics(variable, VariableStatsEstimate.buildFrom(innerJoinVariableStats)
                    .setLowValue(sourceVariableStats.getLowValue())
                    .setHighValue(sourceVariableStats.getHighValue())
                    .setDistinctValuesCount(sourceVariableStats.getDistinctValuesCount())
                    .setNullsFraction(newNullsFraction)
                    .build());
        }

        // Variables from the other side: every complement row contributes a null.
        for (VariableReferenceExpression variable : Sets.difference(innerJoinStats.getVariablesWithKnownStatistics(), joinComplementStats.getVariablesWithKnownStatistics())) {
            VariableStatsEstimate innerJoinVariableStats = innerJoinStats.getVariableStatistics(variable);
            double newNullsFraction = (innerJoinVariableStats.getNullsFraction() * innerJoinRowCount + joinComplementRowCount) / outputRowCount;
            outputStats.addVariableStatistics(variable, innerJoinVariableStats.mapNullsFraction(nullsFraction -> newNullsFraction));
        }
        return outputStats.build();
    }

    /** Cross-product estimate: row counts multiplied, per-variable stats copied from each side. */
    private PlanNodeStatsEstimate crossJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats) {
        PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(leftStats.getOutputRowCount() * rightStats.getOutputRowCount());
        for (VariableReferenceExpression variable : node.getLeft().getOutputVariables()) {
            builder.addVariableStatistics(variable, leftStats.getVariableStatistics(variable));
        }
        for (VariableReferenceExpression variable : node.getRight().getOutputVariables()) {
            builder.addVariableStatistics(variable, rightStats.getVariableStatistics(variable));
        }
        return normalizer.normalize(builder.build());
    }

    /** The join criteria with left and right swapped, for computing the right-side complement. */
    private List<EquiJoinClause> flippedCriteria(JoinNode node) {
        return node.getCriteria().stream()
                .map(EquiJoinClause::flip)
                .collect(ImmutableList.toImmutableList());
    }
}

