package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;

/* loaded from: input_file:org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoin.class */
/**
 * Pushes a TopN below a join when the join cannot drop rows from the side that produces all
 * of the order-by slots (LEFT OUTER: left side, RIGHT OUTER: right side, CROSS: either side).
 *
 * <p>The original TopN is kept above the join; only an additional, possibly weaker TopN is
 * inserted on the chosen join input, so the final result is unchanged.
 */
public class PushdownTopNThroughJoin implements RewriteRuleFactory {

    /**
     * Builds the two rules handling the shapes {@code topN -> join} and
     * {@code topN -> project -> join}.
     *
     * <p>Both rules only fire when every order key expression is a plain {@link Slot}. A rule
     * returns {@code null} from its transformation when nothing could be pushed down.
     *
     * @return the immutable list of rewrite rules
     */
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                // topN -> join
                logicalTopN(logicalJoin())
                        // TODO: support order keys that are complex expressions
                        .when(PushdownTopNThroughJoin::orderKeysAreSlots)
                        .then(topN -> {
                            LogicalJoin<Plan, Plan> join = topN.child();
                            Plan newJoin = pushLimitThroughJoin(topN, join);
                            if (newJoin == null || join.children().equals(newJoin.children())) {
                                // Nothing was pushed below the join: report "no rewrite".
                                return null;
                            }
                            return topN.withChildren(newJoin);
                        })
                        .toRule(RuleType.PUSH_TOP_N_THROUGH_JOIN),

                // topN -> project -> join
                logicalTopN(logicalProject(logicalJoin()))
                        // TODO: support order keys that are complex expressions
                        .when(PushdownTopNThroughJoin::orderKeysAreSlots)
                        .then(topN -> {
                            LogicalProject<LogicalJoin<Plan, Plan>> project = topN.child();
                            LogicalJoin<Plan, Plan> join = project.child();

                            // The order keys may reference slots created by the project (e.g. an
                            // alias of `a + 1`); those cannot be evaluated below the project, so
                            // only continue when the join itself outputs every order-by slot.
                            Set<Slot> joinOutput = join.getOutputSet();
                            if (!joinOutput.containsAll(orderBySlots(topN))) {
                                return null;
                            }

                            Plan newJoin = pushLimitThroughJoin(topN, join);
                            if (newJoin == null || join.children().equals(newJoin.children())) {
                                return null;
                            }
                            return topN.withChildren(project.withChildren(newJoin));
                        })
                        .toRule(RuleType.PUSH_TOP_N_THROUGH_PROJECT_JOIN));
    }

    /**
     * Tries to insert a copy of {@code logicalTopN} below {@code logicalJoin}.
     *
     * @param logicalTopN the TopN currently sitting above the join (directly or via a project)
     * @param logicalJoin the join to push through
     * @return the join with a TopN inserted on one input, or {@code null} when the join type is
     *     not LEFT OUTER / RIGHT OUTER / CROSS or no eligible input outputs all order-by slots
     */
    private Plan pushLimitThroughJoin(LogicalTopN<? extends Plan> logicalTopN, LogicalJoin<Plan, Plan> logicalJoin) {
        List<Slot> orderBySlots = orderBySlots(logicalTopN);
        Plan left = logicalJoin.left();
        Plan right = logicalJoin.right();
        switch (logicalJoin.getJoinType()) {
            case LEFT_OUTER_JOIN:
                // Every left row survives a left outer join, so ordering by left slots only needs
                // the first (limit + offset) left rows.
                if (left.getOutputSet().containsAll(orderBySlots)) {
                    return logicalJoin.withChildren(pushedTopN(logicalTopN, left), right);
                }
                return null;
            case RIGHT_OUTER_JOIN:
                // Mirror image of the left outer join case.
                if (right.getOutputSet().containsAll(orderBySlots)) {
                    return logicalJoin.withChildren(left, pushedTopN(logicalTopN, right));
                }
                return null;
            case CROSS_JOIN:
                // Every row of either side pairs with all rows of the other side, so either side
                // may take the TopN; the left side is preferred.
                if (left.getOutputSet().containsAll(orderBySlots)) {
                    return logicalJoin.withChildren(pushedTopN(logicalTopN, left), right);
                }
                if (right.getOutputSet().containsAll(orderBySlots)) {
                    return logicalJoin.withChildren(left, pushedTopN(logicalTopN, right));
                }
                return null;
            default:
                // Other join types may filter rows of either input; pushing a limit is unsafe.
                return null;
        }
    }

    /** Returns whether every order key of {@code topN} is a bare {@link Slot}. */
    private static boolean orderKeysAreSlots(LogicalTopN<? extends Plan> topN) {
        return topN.getOrderKeys().stream()
                .map(orderKey -> orderKey.getExpr())
                .allMatch(Slot.class::isInstance);
    }

    /** Collects the input slots of all order key expressions of {@code topN}. */
    private static List<Slot> orderBySlots(LogicalTopN<? extends Plan> topN) {
        return topN.getOrderKeys().stream()
                .map(orderKey -> orderKey.getExpr())
                .flatMap(expr -> expr.getInputSlots().stream())
                .collect(Collectors.toList());
    }

    /**
     * Builds the TopN to place on {@code child}. It keeps {@code limit + offset} rows with a zero
     * offset because the original TopN, which stays above the join, still applies the offset.
     */
    private static Plan pushedTopN(LogicalTopN<? extends Plan> topN, Plan child) {
        return topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0L, child);
    }
}
