package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.PlanUtils;

/* loaded from: input_file:org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughJoin.class */
public class PushdownFilterThroughJoin extends OneRewriteRuleFactory {
    public static final PushdownFilterThroughJoin INSTANCE = new PushdownFilterThroughJoin();
    private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_LEFT = ImmutableList.of(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN, JoinType.LEFT_SEMI_JOIN, JoinType.LEFT_ANTI_JOIN, JoinType.NULL_AWARE_LEFT_ANTI_JOIN, JoinType.CROSS_JOIN);
    private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_RIGHT = ImmutableList.of(JoinType.INNER_JOIN, JoinType.RIGHT_OUTER_JOIN, JoinType.RIGHT_SEMI_JOIN, JoinType.RIGHT_ANTI_JOIN, JoinType.CROSS_JOIN);
    private static final ImmutableList<JoinType> COULD_PUSH_INSIDE = ImmutableList.of(JoinType.INNER_JOIN);

    @Override // org.apache.doris.nereids.rules.OneRuleFactory
    public Rule build() {
        return logicalFilter(logicalJoin()).then(logicalFilter -> {
            LogicalJoin logicalJoin = (LogicalJoin) logicalFilter.child();
            Set<Expression> conjuncts = logicalFilter.getConjuncts();
            ArrayList<Expression> newArrayList = Lists.newArrayList();
            ArrayList newArrayList2 = Lists.newArrayList();
            Set<Slot> outputSet = logicalJoin.left().getOutputSet();
            Set<Slot> outputSet2 = logicalJoin.right().getOutputSet();
            for (Expression expression : conjuncts) {
                if (convertJoinCondition(expression, outputSet, outputSet2, logicalJoin.getJoinType())) {
                    newArrayList2.add(expression);
                } else {
                    newArrayList.add(expression);
                }
            }
            LinkedHashSet newLinkedHashSet = Sets.newLinkedHashSet();
            LinkedHashSet newLinkedHashSet2 = Sets.newLinkedHashSet();
            LinkedHashSet newLinkedHashSet3 = Sets.newLinkedHashSet();
            for (Expression expression2 : newArrayList) {
                Class<SlotReference> cls = SlotReference.class;
                SlotReference.class.getClass();
                Set set = (Set) expression2.collect((v1) -> {
                    return r1.isInstance(v1);
                });
                if (set.isEmpty()) {
                    newLinkedHashSet.add(expression2);
                    newLinkedHashSet2.add(expression2);
                } else if (outputSet.containsAll(set) && COULD_PUSH_THROUGH_LEFT.contains(logicalJoin.getJoinType())) {
                    newLinkedHashSet.add(expression2);
                } else if (outputSet2.containsAll(set) && COULD_PUSH_THROUGH_RIGHT.contains(logicalJoin.getJoinType())) {
                    newLinkedHashSet2.add(expression2);
                } else {
                    newLinkedHashSet3.add(expression2);
                }
            }
            newArrayList2.addAll(logicalJoin.getOtherJoinConjuncts());
            return PlanUtils.filterOrSelf(newLinkedHashSet3, new LogicalJoin(logicalJoin.getJoinType(), logicalJoin.getHashJoinConjuncts(), newArrayList2, logicalJoin.getHint(), logicalJoin.getMarkJoinSlotReference(), PlanUtils.filterOrSelf(newLinkedHashSet, logicalJoin.left()), PlanUtils.filterOrSelf(newLinkedHashSet2, logicalJoin.right())));
        }).toRule(RuleType.PUSHDOWN_FILTER_THROUGH_JOIN);
    }

    private boolean convertJoinCondition(Expression expression, Set<Slot> set, Set<Slot> set2, JoinType joinType) {
        if (!COULD_PUSH_INSIDE.contains(joinType) || !(expression instanceof EqualTo)) {
            return false;
        }
        EqualTo equalTo = (EqualTo) expression;
        Expression left = equalTo.left();
        Class<SlotReference> cls = SlotReference.class;
        SlotReference.class.getClass();
        Set set3 = (Set) left.collect((v1) -> {
            return r1.isInstance(v1);
        });
        Expression right = equalTo.right();
        Class<SlotReference> cls2 = SlotReference.class;
        SlotReference.class.getClass();
        Set set4 = (Set) right.collect((v1) -> {
            return r1.isInstance(v1);
        });
        if (set3.size() == 0 || set4.size() == 0) {
            return false;
        }
        if (set.containsAll(set3) && set2.containsAll(set4)) {
            return true;
        }
        return set.containsAll(set4) && set2.containsAll(set3);
    }
}
