package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;

/* loaded from: input_file:org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRows.class */
/**
 * Rewrite rule that drops a {@link LogicalAssertNumRows} node when the plan beneath it
 * already bounds the number of output rows tightly enough for the assertion to hold.
 *
 * <p>The rule walks down from the assert node through nodes that are transparent for this
 * check (see {@link #skipPlan(Plan)}) and then inspects the node it stops on
 * (see {@link #canEliminate(LogicalAssertNumRows, Plan)}).
 */
public class EliminateAssertNumRows extends OneRewriteRuleFactory {

    /**
     * Builds the rule matching any {@code LogicalAssertNumRows}.
     *
     * @return the rule registered under {@link RuleType#ELIMINATE_ASSERT_NUM_ROWS}
     */
    @Override
    public Rule build() {
        return logicalAssertNumRows().then(assertNode -> {
            // Descend until skipPlan reaches a fixed point, i.e. a node it will not look through.
            Plan bottom = (Plan) assertNode.child();
            Plan next = skipPlan(bottom);
            while (next != bottom) {
                bottom = next;
                next = skipPlan(bottom);
            }
            // On success the assert node is replaced by its direct child (not by the node we
            // descended to). NOTE(review): null presumably means "rule not applied" -- confirm
            // against the contract of then().
            return canEliminate(assertNode, bottom) ? (Plan) assertNode.child() : null;
        }).toRule(RuleType.ELIMINATE_ASSERT_NUM_ROWS);
    }

    /**
     * Steps one level down through a node that this rule looks through.
     *
     * <p>Project, filter and sort nodes yield their first child. A left semi/anti join yields its
     * left child and a right semi/anti join yields its right child. Every other node, including
     * joins of any other type, is returned unchanged.
     *
     * @param plan the node to inspect
     * @return the child to continue with, or {@code plan} itself when it is not skipped
     */
    private Plan skipPlan(Plan plan) {
        if (plan instanceof LogicalJoin) {
            LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) plan;
            if (join.getJoinType().isLeftSemiOrAntiJoin()) {
                return plan.child(0);
            }
            if (join.getJoinType().isRightSemiOrAntiJoin()) {
                return plan.child(1);
            }
            return plan;
        }
        boolean lookThrough = plan instanceof LogicalProject
                || plan instanceof LogicalFilter
                || plan instanceof LogicalSort;
        return lookThrough ? plan.child(0) : plan;
    }

    /**
     * Decides whether the assertion is already implied by the given plan node.
     *
     * <p>A row bound is taken from a {@link LogicalLimit} (its limit) or from a
     * {@link LogicalAggregate} without group-by expressions (bound 1); any other node gives
     * {@code false}. The bound is then compared with the desired row count: strictly smaller for
     * {@code NE} and {@code LT}, smaller or equal for {@code LE}. All other assertions give
     * {@code false}.
     *
     * @param logicalAssertNumRows the assert node whose assertion is checked
     * @param plan the node reached after skipping transparent nodes
     * @return {@code true} if the assert node can be removed
     */
    private boolean canEliminate(LogicalAssertNumRows<?> logicalAssertNumRows, Plan plan) {
        final long maxRows;
        if (plan instanceof LogicalLimit) {
            // NOTE(review): assumes getLimit() is a real non-negative bound here -- confirm that
            // no "unlimited" sentinel value can reach this rule.
            maxRows = ((LogicalLimit<?>) plan).getLimit();
        } else if (plan instanceof LogicalAggregate
                && ((LogicalAggregate<?>) plan).getGroupByExpressions().isEmpty()) {
            maxRows = 1L;
        } else {
            return false;
        }

        AssertNumRowsElement element = logicalAssertNumRows.getAssertNumRowsElement();
        AssertNumRowsElement.Assertion kind = element.getAssertion();
        long expectedRows = element.getDesiredNumOfRows();
        switch (kind) {
            case NE:
            case LT:
                return maxRows < expectedRows;
            case LE:
                return maxRows <= expectedRows;
            default:
                return false;
        }
    }
}
