package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.util.PlanUtils;

/* loaded from: input_file:org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughRepeat.class */
/**
 * Rewrite rule that moves filter conjuncts below a {@link LogicalRepeat}.
 *
 * <p>Matches a {@link LogicalFilter} whose child is a {@link LogicalRepeat}. Every conjunct whose
 * input slots are all contained in the repeat's common grouping-set expressions is placed in a new
 * filter under the repeat. All other conjuncts stay above it.
 *
 * <p>NOTE(review): presumably this is sound because expressions shared by every grouping set keep
 * their input values in every repeated row. TODO confirm against {@code LogicalRepeat} semantics.
 */
public class PushdownFilterThroughRepeat extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalFilter(logicalRepeat()).then(logicalFilter -> {
            LogicalRepeat logicalRepeat = (LogicalRepeat) logicalFilter.child();
            Set<Expression> commonGroupingSetExpressions = logicalRepeat.getCommonGroupingSetExpressions();
            if (commonGroupingSetExpressions.isEmpty()) {
                // No expression is shared by all grouping sets, so no conjunct qualifies.
                return logicalFilter;
            }
            Set<Expression> pushedConjuncts = Sets.newHashSet();
            Set<Expression> remainingConjuncts = Sets.newHashSet();
            for (Expression conjunct : logicalFilter.getConjuncts()) {
                // NOTE(review): a conjunct with no input slots (e.g. a constant) is always pushed,
                // since containsAll of an empty collection is true. Confirm that this is intended
                // for non-deterministic expressions.
                if (commonGroupingSetExpressions.containsAll(conjunct.getInputSlots())) {
                    pushedConjuncts.add(conjunct);
                } else {
                    remainingConjuncts.add(conjunct);
                }
            }
            return pushDownPredicate(logicalFilter, logicalRepeat, pushedConjuncts, remainingConjuncts);
        }).toRule(RuleType.PUSHDOWN_PREDICATE_THROUGH_REPEAT);
    }

    /**
     * Builds the rewritten plan from the already-split conjuncts.
     *
     * @param logicalFilter the matched filter; returned unchanged when nothing can be pushed
     * @param logicalRepeat the filter's repeat child
     * @param pushedConjuncts conjuncts to place in a new filter directly under the repeat
     * @param remainingConjuncts conjuncts that must stay above the repeat
     * @return {@code logicalFilter} if {@code pushedConjuncts} is empty. Otherwise, the repeat
     *     re-parented onto a new filter over its original child, passed through
     *     {@code PlanUtils.filterOrSelf} together with {@code remainingConjuncts}. Judging by its
     *     name, that call presumably adds a filter only when conjuncts remain; verify.
     */
    private Plan pushDownPredicate(LogicalFilter logicalFilter, LogicalRepeat logicalRepeat,
            Set<Expression> pushedConjuncts, Set<Expression> remainingConjuncts) {
        if (pushedConjuncts.isEmpty()) {
            // Nothing can move below the repeat; keep the original plan.
            return logicalFilter;
        }
        Plan pushedFilter = new LogicalFilter(pushedConjuncts, logicalRepeat.child(0));
        // The list is declared with an explicit element type rather than cast: a cast from
        // ImmutableList<LogicalFilter> to List<Plan> is rejected because the parameterizations
        // are provably distinct.
        List<Plan> newChildren = ImmutableList.of(pushedFilter);
        return PlanUtils.filterOrSelf(remainingConjuncts, logicalRepeat.withChildren2(newChildren));
    }
}
