package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

/* loaded from: input_file:org/apache/doris/nereids/rules/rewrite/CollectProjectAboveConsumer.class */
/**
 * Rule factory that records which slots of a CTE consumer are referenced by a
 * project sitting above it.
 *
 * <p>Two plan shapes are matched:
 * <ul>
 *   <li>{@code LogicalProject -> LogicalCTEConsumer}</li>
 *   <li>{@code LogicalProject -> LogicalFilter -> LogicalCTEConsumer}</li>
 * </ul>
 * Neither rule changes the plan: each one only registers information on the
 * {@link CascadesContext} and returns the matched root as-is.
 */
public class CollectProjectAboveConsumer implements RewriteRuleFactory {

    /**
     * Builds the two collecting rules.
     *
     * @return an immutable list holding the project-over-consumer rule followed by
     *         the project-over-filter-over-consumer rule
     */
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                RuleType.COLLECT_PROJECT_ABOVE_CONSUMER.build(
                        logicalProject(logicalCTEConsumer()).thenApply(matchingContext -> {
                            LogicalProject<?> project = (LogicalProject<?>) matchingContext.root;
                            LogicalCTEConsumer cteConsumer = (LogicalCTEConsumer) project.child();
                            collectProject(matchingContext.cascadesContext, project.getProjects(), cteConsumer);
                            // Collect-only rule: hand the matched plan back untouched.
                            return matchingContext.root;
                        })),
                RuleType.COLLECT_PROJECT_ABOVE_FILTER_CONSUMER.build(
                        logicalProject(logicalFilter(logicalCTEConsumer())).thenApply(matchingContext -> {
                            LogicalProject<?> project = (LogicalProject<?>) matchingContext.root;
                            LogicalFilter<?> filter = (LogicalFilter<?>) project.child();
                            LogicalCTEConsumer cteConsumer = (LogicalCTEConsumer) filter.child();

                            // Start from the project's own expressions, then add every slot the
                            // filter reads that the project does not already output.
                            List<NamedExpression> referenced = new ArrayList<>(project.getProjects());
                            Set<Slot> filterInputSlots = filter.getInputSlots();
                            // Loop-invariant: look the project output up once instead of per slot.
                            List<Slot> projectOutput = project.getOutput();
                            for (Slot filterSlot : filterInputSlots) {
                                if (!projectOutput.contains(filterSlot)) {
                                    referenced.add(filterSlot);
                                }
                            }

                            collectProject(matchingContext.cascadesContext, referenced, cteConsumer);
                            // Collect-only rule: hand the matched plan back untouched.
                            return matchingContext.root;
                        })));
    }

    /**
     * Walks every node of every given expression and, for each {@link Slot} found,
     * registers the corresponding producer slot under the consumer's CTE id and
     * marks the consumer as being under a project.
     *
     * @param cascadesContext    context that receives the collected information
     * @param list               expressions whose slots should be collected
     * @param logicalCTEConsumer the CTE consumer the expressions read from
     */
    private static void collectProject(CascadesContext cascadesContext, List<NamedExpression> list,
            LogicalCTEConsumer logicalCTEConsumer) {
        for (NamedExpression expression : list) {
            expression.foreach(node -> {
                if (!(node instanceof Slot)) {
                    return;
                }
                cascadesContext.putCTEIdToProject(logicalCTEConsumer.getCteId(),
                        logicalCTEConsumer.getProducerSlot((Slot) node));
                // Only marked when at least one slot is actually found.
                cascadesContext.markConsumerUnderProject(logicalCTEConsumer);
            });
        }
    }
}
