package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.stream.Collectors;

/**
 * Pushes a {@code LIMIT} that sits on top of a {@code DISTINCT} aggregate down into one input of a join.
 *
 * <p>Matched shapes:
 * <pre>
 *   limit -> distinct agg -> join
 *   limit -> distinct agg -> project(all slots) -> join
 * </pre>
 *
 * <p>When every group-by slot comes from a single join input, and that input's output set equals the
 * aggregate's output set, a copy of the distinct aggregate topped by {@code limit(limit + offset, offset 0)}
 * is inserted above that input. Eligible inputs are:
 * <ul>
 *   <li>the left input of a LEFT OUTER join;</li>
 *   <li>the right input of a RIGHT OUTER join;</li>
 *   <li>either input of a CROSS join, with left tried first.</li>
 * </ul>
 * The original limit and aggregate are kept above the join, so the offset is still applied once, at the top.
 */
public class PushdownLimitDistinctThroughJoin implements RewriteRuleFactory {

    /**
     * Builds the two rule variants: the aggregate directly over the join, and the aggregate over an
     * all-slots project over the join.
     *
     * @return the rules; each action returns {@code null} (no rewrite) when nothing can be pushed
     */
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                // limit -> distinct agg -> join
                logicalLimit(logicalAggregate(logicalJoin()).when(LogicalAggregate::isDistinct))
                        .then(limit -> {
                            LogicalAggregate<?> agg = (LogicalAggregate<?>) limit.child();
                            LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) agg.child();
                            Plan newJoin = pushLimitThroughJoin(limit, agg, join);
                            // NOTE(review): relies on Plan#equals; presumably this stops the rule from
                            // re-firing once the limit has already been pushed into the join input.
                            if (newJoin == null || join.children().equals(newJoin.children())) {
                                return null;
                            }
                            return limit.withChildren(agg.withChildren(newJoin));
                        })
                        .toRule(RuleType.PUSH_LIMIT_DISTINCT_THROUGH_JOIN),

                // limit -> distinct agg -> project(all slots) -> join
                logicalLimit(logicalAggregate(logicalProject(logicalJoin()).when(LogicalProject::isAllSlots))
                        .when(LogicalAggregate::isDistinct))
                        .then(limit -> {
                            LogicalAggregate<?> agg = (LogicalAggregate<?>) limit.child();
                            LogicalProject<?> project = (LogicalProject<?>) agg.child();
                            LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) project.child();
                            Plan newJoin = pushLimitThroughJoin(limit, agg, join);
                            // NOTE(review): same re-fire guard as the rule above.
                            if (newJoin == null || join.children().equals(newJoin.children())) {
                                return null;
                            }
                            return limit.withChildren(agg.withChildren(project.withChildren(newJoin)));
                        })
                        .toRule(RuleType.PUSH_LIMIT_DISTINCT_THROUGH_JOIN));
    }

    /**
     * Returns a copy of {@code join} with a limited distinct inserted above the eligible input.
     *
     * @param limit the matched top limit
     * @param agg   the matched distinct aggregate under {@code limit}
     * @param join  the join under the aggregate (possibly through an all-slots project)
     * @return the rewritten join, or {@code null} if the join type or its inputs do not qualify
     */
    private static Plan pushLimitThroughJoin(LogicalLimit<?> limit, LogicalAggregate<?> agg,
            LogicalJoin<?, ?> join) {
        // Parameterized (not raw) aggregate, so getGroupByExpressions() keeps its element type.
        List<Slot> groupBySlots = agg.getGroupByExpressions().stream()
                .flatMap(expr -> expr.getInputSlots().stream())
                .collect(Collectors.toList());
        switch (join.getJoinType()) {
            case LEFT_OUTER_JOIN:
                if (canPushTo(join.left(), agg, groupBySlots)) {
                    return join.withChildren(limitedDistinct(limit, agg, join.left()), join.right());
                }
                return null;
            case RIGHT_OUTER_JOIN:
                if (canPushTo(join.right(), agg, groupBySlots)) {
                    return join.withChildren(join.left(), limitedDistinct(limit, agg, join.right()));
                }
                return null;
            case CROSS_JOIN:
                if (canPushTo(join.left(), agg, groupBySlots)) {
                    return join.withChildren(limitedDistinct(limit, agg, join.left()), join.right());
                }
                if (canPushTo(join.right(), agg, groupBySlots)) {
                    return join.withChildren(join.left(), limitedDistinct(limit, agg, join.right()));
                }
                return null;
            default:
                return null;
        }
    }

    /**
     * Checks whether the distinct can be evaluated on {@code side} alone. That holds when {@code side}
     * produces every group-by slot and its output set equals the aggregate's output set.
     */
    private static boolean canPushTo(Plan side, LogicalAggregate<?> agg, List<Slot> groupBySlots) {
        return side.getOutputSet().containsAll(groupBySlots)
                && side.getOutputSet().equals(agg.getOutputSet());
    }

    /**
     * Builds {@code limit(limit + offset, 0) -> agg -> side}.
     *
     * <p>The inner limit must keep enough rows to satisfy both the offset and the limit of the outer
     * limit, which remains in the plan. So it uses {@code limit + offset} and applies no offset itself.
     */
    private static Plan limitedDistinct(LogicalLimit<?> limit, LogicalAggregate<?> agg, Plan side) {
        long pushedLimit = saturatedAdd(limit.getLimit(), limit.getOffset());
        return limit.withLimitChild(pushedLimit, 0, agg.withChildren(side));
    }

    /** Adds two longs, clamping to {@code Long.MAX_VALUE}/{@code MIN_VALUE} instead of wrapping around. */
    private static long saturatedAdd(long a, long b) {
        long sum = a + b;
        // Signed overflow happened iff both operands share a sign that differs from the sum's sign.
        if (((a ^ sum) & (b ^ sum)) < 0) {
            return sum < 0 ? Long.MAX_VALUE : Long.MIN_VALUE;
        }
        return sum;
    }
}
