package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Optional;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;

/* loaded from: input_file:org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughWindow.class */
/**
 * Rewrite rules that push the limit of a {@link LogicalTopN} down into a {@link LogicalWindow} as a
 * per-partition limit.
 *
 * <p>Two plan shapes are matched: {@code TopN -> Window} and {@code TopN -> Project -> Window}. A
 * pushdown is attempted only when the window node computes exactly one window expression and the
 * TopN orders by that expression's output slot alone, ascending. The actual rewrite is delegated to
 * {@link LogicalWindow#pushPartitionLimitThroughWindow}; if it returns an empty {@code Optional},
 * the plan is left unchanged.
 */
public class PushdownTopNThroughWindow implements RewriteRuleFactory {
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                // TopN -> Window
                logicalTopN(logicalWindow()).then(topN -> {
                    LogicalWindow<?> window = (LogicalWindow<?>) topN.child();
                    return pushPartitionLimit(topN, window)
                            .map(newWindow -> (Plan) topN.withChildren(newWindow))
                            .orElse(topN);
                }).toRule(RuleType.PUSH_TOP_N_THROUGH_WINDOW),
                // TopN -> Project -> Window: the project is kept and re-attached on top of the new window.
                logicalTopN(logicalProject(logicalWindow())).then(topN -> {
                    LogicalProject<?> project = (LogicalProject<?>) topN.child();
                    LogicalWindow<?> window = (LogicalWindow<?>) project.child();
                    return pushPartitionLimit(topN, window)
                            .map(newWindow -> (Plan) topN.withChildren((Plan) project.withChildren(newWindow)))
                            .orElse(topN);
                }).toRule(RuleType.PUSH_TOP_N_THROUGH_PROJECT_WINDOW));
    }

    /**
     * Attempts to push {@code topN}'s limit (limit + offset) into {@code window} as a partition limit.
     *
     * @param topN the TopN sitting above the window (directly or through a project)
     * @param window the window node to rewrite
     * @return the rewritten window; empty if the window/TopN shape does not qualify, if
     *         limit + offset overflows, or if the window declines the pushdown
     */
    private Optional<Plan> pushPartitionLimit(LogicalTopN<?> topN, LogicalWindow<?> window) {
        ExprId windowExprId = getExprID4WindowFunc(window);
        if (windowExprId == null || !checkTopNForPartitionLimitPushDown(topN, windowExprId)) {
            return Optional.empty();
        }
        long partitionLimit;
        try {
            // Rows skipped by OFFSET must still be produced below the TopN, hence limit + offset.
            partitionLimit = Math.addExact(topN.getLimit(), topN.getOffset());
        } catch (ArithmeticException overflow) {
            // A wrapped (negative) limit would be wrong; skipping an optimization is always safe.
            return Optional.empty();
        }
        // NOTE(review): the meaning of the boolean flag is defined by LogicalWindow -- confirm there.
        return window.pushPartitionLimitThroughWindow(partitionLimit, true);
    }

    /**
     * Returns the expression id of the window node's only window expression.
     *
     * @param logicalWindow the window node to inspect
     * @return the id of the single window expression, or {@code null} if the node has more or fewer
     *         than one window expression, or that expression does not directly wrap exactly one
     *         {@link WindowExpression} child
     */
    private ExprId getExprID4WindowFunc(LogicalWindow<?> logicalWindow) {
        List<NamedExpression> windowExpressions = logicalWindow.getWindowExpressions();
        if (windowExpressions.size() != 1) {
            return null;
        }
        NamedExpression namedExpression = windowExpressions.get(0);
        if (namedExpression.children().size() == 1 && namedExpression.child(0) instanceof WindowExpression) {
            return namedExpression.getExprId();
        }
        return null;
    }

    /**
     * Checks that the TopN orders by exactly one key, ascending, and that the key is the slot
     * produced by the window expression identified by {@code exprId}.
     *
     * <p>NOTE(review): the ascending-only restriction presumably exists because the window values
     * (e.g. ranking functions) grow within a partition, so the first N rows are the smallest
     * values. Confirm against {@code pushPartitionLimitThroughWindow}.
     *
     * @param logicalTopN the TopN to check
     * @param exprId the window expression's id; must be non-null
     * @return {@code true} if the TopN's ordering permits the partition-limit pushdown
     */
    private boolean checkTopNForPartitionLimitPushDown(LogicalTopN<?> logicalTopN, ExprId exprId) {
        List<OrderKey> orderKeys = logicalTopN.getOrderKeys();
        if (orderKeys.size() != 1) {
            return false;
        }
        OrderKey orderKey = orderKeys.get(0);
        if (!orderKey.isAsc()) {
            return false;
        }
        Expression expr = orderKey.getExpr();
        // Compare ids by value: the slot and the window expression may hold distinct ExprId instances.
        return expr instanceof SlotReference && exprId.equals(((SlotReference) expr).getExprId());
    }
}
