package org.apache.doris.nereids.stats;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.ColumnStatisticBuilder;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatisticsBuilder;

/* loaded from: input_file:org/apache/doris/nereids/stats/JoinEstimation.class */
/**
 * Derives the output {@link Statistics} of a join from the statistics of its two inputs.
 *
 * <p>The single public entry point is {@link #estimate(Statistics, Statistics, Join)}; every other
 * method is a private helper for one join flavour. Throughout this class the first
 * {@code Statistics} argument describes the left input of the join and the second one the right
 * input (the left/right semi-join checks below select the first or second argument accordingly).
 */
public class JoinEstimation {
    /** Lower bound, as a fraction of the preserved side's row count, for an anti join result. */
    private static final double DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT = 0.3d;

    /**
     * An equality condition is used for selectivity estimation only when ndv / rowCount of at least
     * one of its sides exceeds this value.
     */
    private static final double ALMOST_UNIQUE_THRESHOLD = 0.9d;

    /** Factor applied once per condition that did not pass {@link #ALMOST_UNIQUE_THRESHOLD}. */
    private static final double UNTRUSTED_CONDITION_FACTOR = 0.9d;

    /**
     * Builds statistics with the given row count that carry the column statistics of both inputs.
     * Column statistics of the right input are put after those of the left input.
     */
    private static Statistics mergeColumnStats(double rowCount, Statistics leftStats, Statistics rightStats) {
        return new StatisticsBuilder()
                .setRowCount(rowCount)
                .putColumnStatistics(leftStats.columnStatistics())
                .putColumnStatistics(rightStats.columnStatistics())
                .build();
    }

    /**
     * Builds the statistics of the cartesian product of both inputs; each input row count is
     * clamped to at least 1 before multiplying, so the result row count is never below 1
     * for non-negative inputs.
     */
    private static Statistics crossJoinStats(Statistics leftStats, Statistics rightStats) {
        double rowCount = Math.max(1.0d, leftStats.getRowCount()) * Math.max(1.0d, rightStats.getRowCount());
        return mergeColumnStats(rowCount, leftStats, rightStats);
    }

    /**
     * Returns the predicate commuted when any input slot of its left operand has column statistics
     * in {@code rightStats}; otherwise returns the predicate unchanged.
     *
     * @param equalPredicate the equality condition to normalize
     * @param leftStats unused; kept so the signature mirrors the other helpers
     * @param rightStats statistics of the right join input
     */
    private static EqualPredicate normalizeHashJoinCondition(EqualPredicate equalPredicate, Statistics leftStats,
            Statistics rightStats) {
        boolean leftOperandUsesRightInput = equalPredicate.left().getInputSlots().stream()
                .anyMatch(slot -> rightStats.findColumnStatistics(slot) != null);
        return leftOperandUsesRightInput ? equalPredicate.commute() : equalPredicate;
    }

    /**
     * Tells whether any slot referenced by a hash join conjunct has no column statistics in either
     * input, or has statistics flagged as unknown.
     */
    private static boolean hashJoinConditionContainsUnknownColumnStats(Statistics leftStats, Statistics rightStats,
            Join join) {
        for (Expression conjunct : join.getHashJoinConjuncts()) {
            for (Slot slot : conjunct.getInputSlots()) {
                ColumnStatistic slotStats = leftStats.findColumnStatistics(slot);
                if (slotStats == null) {
                    // Not a left-side column: fall back to the right input.
                    slotStats = rightStats.findColumnStatistics(slot);
                }
                if (slotStats == null || slotStats.isUnKnown) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Estimates an inner join that has at least one hash (equality) conjunct.
     *
     * <p>Each conjunct is classified by {@link #ALMOST_UNIQUE_THRESHOLD}. If at least one conjunct
     * passes, the row count is the cross join row count multiplied by a damped product of the
     * passing conjuncts' selectivities (sorted ascending, the i-th raised to 1 / 2^i), clamped to at
     * least 1 and then scaled by {@link #UNTRUSTED_CONDITION_FACTOR} per failing conjunct. If none
     * passes, the row count is the larger input row count scaled by the smallest ratio computed for
     * the failing conjuncts, clamped to at least 1.
     *
     * <p>Every hash conjunct is cast to {@link EqualPredicate}; a conjunct of another type raises a
     * {@link ClassCastException}.
     */
    private static Statistics estimateHashJoin(Statistics leftStats, Statistics rightStats, Join join) {
        List<EqualPredicate> trustedConditions = new ArrayList<>();
        List<EqualPredicate> untrustedConditions = new ArrayList<>();
        List<Double> untrustedRatios = new ArrayList<>();

        boolean leftBigger = leftStats.getRowCount() > rightStats.getRowCount();
        double rightRowCount = StatsMathUtil.nonZeroDivisor(rightStats.getRowCount());
        double leftRowCount = StatsMathUtil.nonZeroDivisor(leftStats.getRowCount());

        for (Expression conjunct : join.getHashJoinConjuncts()) {
            EqualPredicate condition = (EqualPredicate) conjunct;
            EqualPredicate normalized = normalizeHashJoinCondition(condition, leftStats, rightStats);
            ColumnStatistic leftColumn = ExpressionEstimation.estimate(normalized.left(), leftStats);
            ColumnStatistic rightColumn = ExpressionEstimation.estimate(normalized.right(), rightStats);
            boolean trusted = rightColumn.ndv / rightRowCount > ALMOST_UNIQUE_THRESHOLD
                    || leftColumn.ndv / leftRowCount > ALMOST_UNIQUE_THRESHOLD;
            if (trusted) {
                // The trusted list keeps the conjunct as written, not its normalized form.
                trustedConditions.add(condition);
                continue;
            }
            double rightNdv = StatsMathUtil.nonZeroDivisor(rightColumn.ndv);
            double leftNdv = StatsMathUtil.nonZeroDivisor(leftColumn.ndv);
            double smallerNdv = Math.min(leftColumn.ndv, rightColumn.ndv);
            if (leftBigger) {
                untrustedRatios.add(((rightRowCount / rightNdv) * smallerNdv) / leftNdv);
            } else {
                untrustedRatios.add(((leftRowCount / leftNdv) * smallerNdv) / rightNdv);
            }
            untrustedConditions.add(normalized);
        }

        Statistics crossStats = crossJoinStats(leftStats, rightStats);
        double outputRowCount;
        if (trustedConditions.isEmpty()) {
            outputRowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount());
            Optional<Double> smallestRatio = untrustedRatios.stream().min(Double::compareTo);
            if (smallestRatio.isPresent()) {
                outputRowCount = Math.max(1.0d, outputRowCount * smallestRatio.get());
            }
        } else {
            List<Double> sortedSelectivities = trustedConditions.stream()
                    .map(condition -> estimateJoinConditionSel(crossStats, condition))
                    .sorted()
                    .collect(Collectors.toList());
            double selectivity = 1.0d;
            double denominator = 1.0d;
            for (double conditionSelectivity : sortedSelectivities) {
                // Each further condition contributes with a halved exponent.
                selectivity *= Math.pow(conditionSelectivity, 1.0d / denominator);
                denominator *= 2.0d;
            }
            outputRowCount = Math.max(1.0d, crossStats.getRowCount() * selectivity)
                    * Math.pow(UNTRUSTED_CONDITION_FACTOR, untrustedConditions.size());
        }
        return crossStats.withRowCountAndEnforceValid(outputRowCount);
    }

    /**
     * Estimates a join without hash conjuncts as the product of both row counts, clamped to at
     * least 1. The {@code join} argument is not inspected.
     */
    private static Statistics estimateNestLoopJoin(Statistics leftStats, Statistics rightStats, Join join) {
        double rowCount = Math.max(1.0d, leftStats.getRowCount() * rightStats.getRowCount());
        return mergeColumnStats(rowCount, leftStats, rightStats);
    }

    /**
     * Estimates an inner join.
     *
     * <p>When a hash conjunct references a column with missing or unknown statistics, the result
     * row count is the larger input row count (at least 1). Otherwise the hash or nested-loop
     * estimate is used, further filtered by the conjunction of the other join conjuncts if there
     * are any; a filtered row count that is not positive is replaced by 1.
     */
    private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rightStats, Join join) {
        if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
            double rowCount = Math.max(1.0d, Math.max(leftStats.getRowCount(), rightStats.getRowCount()));
            return mergeColumnStats(rowCount, leftStats, rightStats);
        }

        Statistics innerStats;
        if (join.getHashJoinConjuncts().isEmpty()) {
            innerStats = estimateNestLoopJoin(leftStats, rightStats, join);
        } else {
            innerStats = estimateHashJoin(leftStats, rightStats, join);
        }

        if (join.getOtherJoinConjuncts().isEmpty()) {
            return innerStats;
        }
        innerStats = new FilterEstimation().estimate(ExpressionUtils.and(join.getOtherJoinConjuncts()), innerStats);
        if (innerStats.getRowCount() <= 0.0d) {
            innerStats = new StatisticsBuilder(innerStats).setRowCount(1.0d).build();
        }
        return innerStats;
    }

    /**
     * Returns the fraction of {@code stats} rows that remain after filtering by {@code condition},
     * computed as filtered row count divided by {@code stats.getRowCount()}.
     */
    private static double estimateJoinConditionSel(Statistics stats, Expression condition) {
        Statistics filtered = new FilterEstimation().estimate(condition, stats);
        return filtered.getRowCount() / stats.getRowCount();
    }

    /**
     * Estimates the row count of a semi or anti join from one equality condition.
     *
     * <p>The operands are looked up directly in the input statistics, trying the condition as
     * written first and swapped second. For a left semi join the left row count is scaled by
     * ndv / originalNdv of the right-side column; for a right semi join the right row count is
     * scaled by the same ratio of the left-side column. For an anti join the result is the
     * preserved side's row count minus that semi estimate, but not less than
     * {@link #DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT} times the preserved side's row count.
     *
     * @return the estimate clamped to at least 1, or {@link Double#POSITIVE_INFINITY} when column
     *         statistics for either side of the condition cannot be found
     */
    private static double estimateSemiOrAntiRowCountBySlotsEqual(Statistics leftStats, Statistics rightStats,
            Join join, EqualPredicate equalPredicate) {
        Expression eqLeft = equalPredicate.left();
        Expression eqRight = equalPredicate.right();

        ColumnStatistic leftSideColumn = leftStats.findColumnStatistics(eqLeft);
        ColumnStatistic rightSideColumn;
        if (leftSideColumn == null) {
            // The condition may be written as <right column> = <left column>.
            leftSideColumn = leftStats.findColumnStatistics(eqRight);
            rightSideColumn = rightStats.findColumnStatistics(eqLeft);
        } else {
            rightSideColumn = rightStats.findColumnStatistics(eqRight);
        }
        if (leftSideColumn == null || rightSideColumn == null) {
            return Double.POSITIVE_INFINITY;
        }

        boolean semiJoin;
        double preservedRowCount;
        double semiRowCount;
        if (join.getJoinType().isLeftSemiOrAntiJoin()) {
            preservedRowCount = leftStats.getRowCount();
            semiRowCount = StatsMathUtil.divide(leftStats.getRowCount() * rightSideColumn.ndv,
                    rightSideColumn.getOriginalNdv());
            semiJoin = join.getJoinType().isSemiJoin();
        } else {
            preservedRowCount = rightStats.getRowCount();
            semiRowCount = StatsMathUtil.divide(rightStats.getRowCount() * leftSideColumn.ndv,
                    leftSideColumn.getOriginalNdv());
            semiJoin = join.getJoinType().isSemiJoin();
        }

        double rowCount;
        if (semiJoin) {
            rowCount = semiRowCount;
        } else {
            rowCount = Math.max(preservedRowCount - semiRowCount,
                    preservedRowCount * DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT);
        }
        return Math.max(1.0d, rowCount);
    }

    /**
     * Estimates a semi or anti join.
     *
     * <p>With missing or unknown column statistics the preserved side's row count is returned
     * together with the column statistics of both inputs. Otherwise the smallest per-condition
     * estimate is applied to the preserved side's statistics. When no condition yields a finite
     * estimate, the inner join estimate is used, capped at the preserved side's row count.
     *
     * <p>Every hash conjunct is cast to {@link EqualPredicate}.
     */
    private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics rightStats, Join join) {
        if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
            if (join.getJoinType().isLeftSemiOrAntiJoin()) {
                return mergeColumnStats(leftStats.getRowCount(), leftStats, rightStats);
            }
            return mergeColumnStats(rightStats.getRowCount(), leftStats, rightStats);
        }

        double rowCount = Double.POSITIVE_INFINITY;
        for (Expression conjunct : join.getHashJoinConjuncts()) {
            double candidate = estimateSemiOrAntiRowCountBySlotsEqual(leftStats, rightStats, join,
                    (EqualPredicate) conjunct);
            if (rowCount > candidate) {
                rowCount = candidate;
            }
        }

        if (Double.isInfinite(rowCount)) {
            // No condition could be estimated from plain column statistics.
            Statistics innerStats = estimateInnerJoin(leftStats, rightStats, join);
            double preservedRowCount = join.getJoinType().isLeftSemiOrAntiJoin()
                    ? leftStats.getRowCount()
                    : rightStats.getRowCount();
            return innerStats.withRowCountAndEnforceValid(Math.min(innerStats.getRowCount(), preservedRowCount));
        }

        Statistics preservedStats = join.getJoinType().isLeftSemiOrAntiJoin() ? leftStats : rightStats;
        StatisticsBuilder builder = new StatisticsBuilder(preservedStats);
        builder.setRowCount(rowCount);
        Statistics result = builder.build();
        result.enforceValid();
        return result;
    }

    /**
     * Estimates the output statistics of {@code join}.
     *
     * <ul>
     *   <li>semi / anti joins: see {@link #estimateSemiOrAnti};</li>
     *   <li>inner join: the inner estimate, with the ndv of both sides of every hash conjunct
     *       lowered to the smaller of the two;</li>
     *   <li>left / right outer join: the larger of the preserved side's row count and the inner
     *       join row count;</li>
     *   <li>full outer join: left row count + right row count + inner join row count;</li>
     *   <li>cross join: the plain product of both row counts.</li>
     * </ul>
     *
     * @param statistics statistics of the left join input
     * @param statistics2 statistics of the right join input
     * @param join the join being estimated
     * @return the estimated statistics of the join output
     * @throws AnalysisException if the join type is none of the types listed above
     */
    public static Statistics estimate(Statistics statistics, Statistics statistics2, Join join) {
        JoinType joinType = join.getJoinType();
        if (joinType.isSemiOrAntiJoin()) {
            return estimateSemiOrAnti(statistics, statistics2, join);
        }
        if (joinType == JoinType.INNER_JOIN) {
            return updateJoinResultStatsByHashJoinCondition(estimateInnerJoin(statistics, statistics2, join), join);
        }
        if (joinType == JoinType.LEFT_OUTER_JOIN) {
            Statistics crossStats = crossJoinStats(statistics, statistics2);
            double innerRowCount = estimateInnerJoin(statistics, statistics2, join).getRowCount();
            return crossStats.withRowCountAndEnforceValid(Math.max(statistics.getRowCount(), innerRowCount));
        }
        if (joinType == JoinType.RIGHT_OUTER_JOIN) {
            Statistics crossStats = crossJoinStats(statistics, statistics2);
            double innerRowCount = estimateInnerJoin(statistics, statistics2, join).getRowCount();
            return crossStats.withRowCountAndEnforceValid(Math.max(statistics2.getRowCount(), innerRowCount));
        }
        if (joinType == JoinType.FULL_OUTER_JOIN) {
            Statistics crossStats = crossJoinStats(statistics, statistics2);
            double innerRowCount = estimateInnerJoin(statistics, statistics2, join).getRowCount();
            return crossStats.withRowCountAndEnforceValid(
                    statistics.getRowCount() + statistics2.getRowCount() + innerRowCount);
        }
        if (joinType == JoinType.CROSS_JOIN) {
            // Unlike the other branches, the product is not clamped to a minimum of 1 here.
            return mergeColumnStats(statistics.getRowCount() * statistics2.getRowCount(), statistics, statistics2);
        }
        throw new AnalysisException("join type not supported: " + join.getJoinType());
    }

    /** Returns the child of a {@link Cast}, or the expression itself when it is not a cast. */
    private static Expression stripCast(Expression expression) {
        return expression instanceof Cast ? expression.child(0) : expression;
    }

    /**
     * Lowers the ndv of both operands of every hash conjunct to the smaller of the two and stores
     * the adjusted column statistics in {@code statistics}, keyed by the operand with an outermost
     * {@link Cast} removed.
     *
     * <p>All adjusted statistics are computed first and written afterwards, so the estimate for one
     * conjunct never observes the update made for another. The passed-in {@code statistics} object
     * is modified in place and returned.
     */
    private static Statistics updateJoinResultStatsByHashJoinCondition(Statistics statistics, Join join) {
        HashMap<Expression, ColumnStatistic> adjustedStats = new HashMap<>();
        for (Expression conjunct : join.getHashJoinConjuncts()) {
            EqualPredicate condition = (EqualPredicate) conjunct;
            ColumnStatistic leftColumn = ExpressionEstimation.estimate(condition.left(), statistics);
            ColumnStatistic rightColumn = ExpressionEstimation.estimate(condition.right(), statistics);
            double sharedNdv = Math.min(leftColumn.ndv, rightColumn.ndv);
            ColumnStatistic adjustedLeft = new ColumnStatisticBuilder(leftColumn).setNdv(sharedNdv).build();
            ColumnStatistic adjustedRight = new ColumnStatisticBuilder(rightColumn).setNdv(sharedNdv).build();
            Expression leftKey = stripCast(condition.left());
            Expression rightKey = stripCast(condition.right());
            adjustedStats.put(leftKey, adjustedLeft);
            adjustedStats.put(rightKey, adjustedRight);
        }
        adjustedStats.forEach((expression, columnStats) -> {
            statistics.addColumnStats(expression, columnStats);
        });
        return statistics;
    }
}
