package org.apache.doris.statistics;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.doris.analysis.BinaryPredicate;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.JoinOperator;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.catalog.ColumnStats;
import org.apache.doris.common.CheckedMath;
import org.apache.doris.common.UserException;
import org.apache.doris.planner.HashJoinNode;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/* loaded from: input_file:org/apache/doris/statistics/HashJoinStatsDerive.class */
/**
 * Derives the output row count of a hash join from the row counts of its two
 * children and the NDV (number of distinct values) of the equi-join columns.
 *
 * <p>Child 0 of {@code childrenStatsResult} is used as the left input and
 * child 1 as the right input. A row count or NDV of {@code -1} is handled
 * throughout as "not available" and short-circuits the estimation.
 */
public class HashJoinStatsDerive extends BaseStatsDerive {
    private static final Logger LOG = LogManager.getLogger(HashJoinStatsDerive.class);

    /** Join operator of the node passed to {@link #init(PlanStats)}. */
    private JoinOperator joinOp;

    /** Equi-join conjuncts copied from the node passed to {@link #init(PlanStats)}. */
    private final List<BinaryPredicate> eqJoinConjuncts = Lists.newArrayList();

    /**
     * Captures the join operator and equi-join conjuncts of the given node
     * after running the base-class initialisation.
     *
     * @param planStats the plan node to derive statistics for; must be a {@link HashJoinNode}
     * @throws UserException propagated from {@code BaseStatsDerive.init}
     * @throws IllegalStateException if {@code planStats} is not a {@link HashJoinNode}
     */
    @Override // org.apache.doris.statistics.BaseStatsDerive
    public void init(PlanStats planStats) throws UserException {
        Preconditions.checkState(planStats instanceof HashJoinNode);
        super.init(planStats);
        HashJoinNode joinNode = (HashJoinNode) planStats;
        this.joinOp = joinNode.getJoinOp();
        this.eqJoinConjuncts.addAll(joinNode.getEqJoinConjuncts());
    }

    /**
     * Computes the row count for the join, stores it in {@code rowCount},
     * applies {@code capRowCountAtLimit()} and returns the result.
     *
     * <p>Join operators that are neither semi/anti, inner nor outer joins leave
     * {@code rowCount} untouched (only a debug message is logged).
     *
     * @return the value of {@code rowCount} after capping
     */
    @Override // org.apache.doris.statistics.BaseStatsDerive
    protected long deriveRowCount() {
        if (this.joinOp.isSemiAntiJoin()) {
            this.rowCount = getSemiJoinrowCount();
        } else if (this.joinOp.isInnerJoin() || this.joinOp.isOuterJoin()) {
            this.rowCount = getJoinrowCount();
        } else if (LOG.isDebugEnabled()) {
            LOG.debug("joinOp:{} is not supported for HashJoinStatsDerive", this.joinOp);
        }
        capRowCountAtLimit();
        return this.rowCount;
    }

    /**
     * Estimates the row count of a semi or anti join.
     *
     * <p>The estimate is the row count of the returned side (child 1 for
     * RIGHT_SEMI/RIGHT_ANTI, child 0 otherwise) multiplied by the smallest
     * selectivity computed over all equi-join conjuncts that have NDV
     * information on both sides. With no usable conjunct the selectivity is 1.
     *
     * @return the rounded estimate, or {@code -1} if the returned side's row count is {@code -1}
     */
    private long getSemiJoinrowCount() {
        Preconditions.checkState(this.joinOp.isSemiJoin());
        boolean returnsRightSide = this.joinOp == JoinOperator.RIGHT_SEMI_JOIN
                || this.joinOp == JoinOperator.RIGHT_ANTI_JOIN;
        double inputRowCount = this.childrenStatsResult.get(returnsRightSide ? 1 : 0).getRowCount();
        if (inputRowCount == -1.0d) {
            return -1L;
        }

        double minSelectivity = 1.0d;
        for (BinaryPredicate conjunct : this.eqJoinConjuncts) {
            // Cap each NDV at its child's row count. A -1 on either operand of
            // Math.min survives as -1, so a missing NDV or a missing row count
            // both cause the conjunct to be skipped below.
            double lhsNdv = Math.min(getNdv(conjunct.getChild(0)), this.childrenStatsResult.get(0).getRowCount());
            double rhsNdv = Math.min(getNdv(conjunct.getChild(1)), this.childrenStatsResult.get(1).getRowCount());
            if (lhsNdv == -1.0d || rhsNdv == -1.0d) {
                continue;
            }

            // NOTE(review): when the divisor NDV is 0 the expressions below yield
            // 0.0 / 0.0 = NaN; Math.min propagates NaN and Math.round(NaN) is 0,
            // so the final estimate becomes 0. Confirm this is the desired outcome.
            double selectivity = 1.0d;
            switch (this.joinOp) {
                case LEFT_SEMI_JOIN:
                    selectivity = Math.min(lhsNdv, rhsNdv) / lhsNdv;
                    break;
                case RIGHT_SEMI_JOIN:
                    selectivity = Math.min(lhsNdv, rhsNdv) / rhsNdv;
                    break;
                case LEFT_ANTI_JOIN:
                case NULL_AWARE_LEFT_ANTI_JOIN:
                    selectivity = (lhsNdv > rhsNdv ? lhsNdv - rhsNdv : lhsNdv) / lhsNdv;
                    break;
                case RIGHT_ANTI_JOIN:
                    selectivity = (rhsNdv > lhsNdv ? rhsNdv - lhsNdv : rhsNdv) / rhsNdv;
                    break;
                default:
                    Preconditions.checkState(false);
                    break;
            }
            minSelectivity = Math.min(minSelectivity, selectivity);
        }
        return Math.round(inputRowCount * minSelectivity);
    }

    /**
     * Looks up the NDV of the slot that {@code expr} unwraps to.
     *
     * @param expr one side of an equi-join conjunct
     * @return the slot's number of distinct values, or {@code -1} if the expression does not
     *         unwrap to a slot reference, the slot has no descriptor, or its column stats
     *         carry no NDV
     */
    private long getNdv(Expr expr) {
        SlotRef slotRef = expr.unwrapSlotRef(false);
        if (slotRef == null) {
            return -1L;
        }
        SlotDescriptor desc = slotRef.getDesc();
        if (desc == null) {
            return -1L;
        }
        ColumnStats stats = desc.getStats();
        return stats.hasNumDistinctValues() ? stats.getNumDistinctValues() : -1L;
    }

    /**
     * Estimates the row count of an inner or outer join.
     *
     * <p>If either child's row count is {@code -1}, or none of the equi-join
     * conjuncts can be turned into an {@code EqJoinConjunctScanSlots}, the left
     * child's row count is returned unchanged.
     *
     * @return the estimated row count
     */
    private long getJoinrowCount() {
        Preconditions.checkState(this.joinOp.isInnerJoin() || this.joinOp.isOuterJoin());
        Preconditions.checkState(this.childrenStatsResult.size() == 2);
        long lhsCard = (long) this.childrenStatsResult.get(0).getRowCount();
        long rhsCard = (long) this.childrenStatsResult.get(1).getRowCount();
        if (lhsCard == -1 || rhsCard == -1) {
            return lhsCard;
        }

        // Keep only the conjuncts for which scan-slot information could be created.
        List<HashJoinNode.EqJoinConjunctScanSlots> scanSlots = new ArrayList<>();
        for (BinaryPredicate conjunct : this.eqJoinConjuncts) {
            HashJoinNode.EqJoinConjunctScanSlots slots = HashJoinNode.EqJoinConjunctScanSlots.create(conjunct);
            if (slots != null) {
                scanSlots.add(slots);
            }
        }
        return scanSlots.isEmpty() ? lhsCard : getGenericJoinrowCount(scanSlots, lhsCard, rhsCard);
    }

    /**
     * Estimates the join row count as the minimum, over all given conjuncts, of
     * {@code lhsCard / max(1, max(lhsNdv, rhsNdv)) * rhsCard}, where each NDV is
     * scaled down when its slot reports more rows than the corresponding child.
     *
     * @param list non-empty list of scan-slot descriptions of the equi-join conjuncts
     * @param lhsCard row count of the left child; must be non-negative
     * @param rhsCard row count of the right child; must be non-negative
     * @return the smallest per-conjunct estimate
     * @throws IllegalStateException if a precondition is violated or the result is negative
     */
    private long getGenericJoinrowCount(List<HashJoinNode.EqJoinConjunctScanSlots> list, long lhsCard, long rhsCard) {
        Preconditions.checkState(this.joinOp.isInnerJoin() || this.joinOp.isOuterJoin());
        Preconditions.checkState(!list.isEmpty());
        Preconditions.checkState(lhsCard >= 0 && rhsCard >= 0);
        long result = -1;
        for (HashJoinNode.EqJoinConjunctScanSlots slots : list) {
            // Only ever shrink an NDV: scale it by the fraction of the slot's
            // rows that the child is estimated to produce.
            double lhsNdv = slots.lhsNdv();
            if (slots.lhsNumRows() > lhsCard) {
                lhsNdv *= (double) lhsCard / slots.lhsNumRows();
            }
            double rhsNdv = slots.rhsNdv();
            if (slots.rhsNumRows() > rhsCard) {
                rhsNdv *= (double) rhsCard / slots.rhsNumRows();
            }

            // Lower bound of 1 keeps the division below from inflating the estimate.
            double maxNdv = Math.max(1.0d, Math.max(lhsNdv, rhsNdv));
            long joinCard;
            if (maxNdv == (double) rhsCard) {
                // Compare numerically (not the IEEE-754 bit pattern): when the NDV
                // equals the right row count the formula reduces to lhsCard, and
                // returning it directly avoids the rounding error of
                // round(lhsCard / rhsCard) * rhsCard.
                joinCard = lhsCard;
            } else {
                joinCard = CheckedMath.checkedMultiply(Math.round(lhsCard / maxNdv), rhsCard);
            }
            result = result == -1 ? joinCard : Math.min(result, joinCard);
        }
        Preconditions.checkState(result >= 0);
        return result;
    }
}
