package org.apache.doris.nereids.rules.exploration.join;

import com.google.common.collect.ImmutableSet;
import java.util.Set;
import java.util.stream.Stream;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;

/**
 * Exploration rule implementing (outer) join associativity.
 *
 * <p>Rewrites {@code (A bottomJoin B) topJoin C} into {@code A bottomJoin (B topJoin C)}. The rewrite is only
 * attempted when all of the following hold:
 * <ul>
 *   <li>the (bottom join type, top join type) pair is in {@link #VALID_TYPE_PAIR_SET};</li>
 *   <li>{@code OuterJoinLAsscom.checkReorder(topJoin, bottomJoin)} accepts the pair;</li>
 *   <li>no conjunct of the top join references an output slot of {@code A} (see {@link #checkCondition});</li>
 *   <li>neither join is a mark join.</li>
 * </ul>
 */
public class OuterJoinAssoc extends OneExplorationRuleFactory {
    public static final OuterJoinAssoc INSTANCE = new OuterJoinAssoc();

    /**
     * Accepted (bottom join type, top join type) pairs: INNER under LEFT OUTER, and LEFT OUTER under LEFT OUTER.
     */
    public static final Set<Pair<JoinType, JoinType>> VALID_TYPE_PAIR_SET = ImmutableSet.of(
            Pair.of(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN),
            Pair.of(JoinType.LEFT_OUTER_JOIN, JoinType.LEFT_OUTER_JOIN));

    @Override
    public Rule build() {
        return logicalJoin(logicalJoin(), group())
                .when(topJoin -> VALID_TYPE_PAIR_SET.contains(
                        Pair.of(topJoin.left().getJoinType(), topJoin.getJoinType())))
                .when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left()))
                .when(topJoin -> checkCondition(topJoin, topJoin.left().left().getOutputSet()))
                .whenNot(topJoin -> topJoin.isMarkJoin() || topJoin.left().isMarkJoin())
                .thenApply(ctx -> {
                    LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin = ctx.root;
                    LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
                    GroupPlan a = bottomJoin.left();
                    GroupPlan b = bottomJoin.right();
                    GroupPlan c = topJoin.right();

                    // For LEFT OUTER under LEFT OUTER, only rewrite when every slot referenced by the top
                    // join's condition is inferred not-null from the top join's own conjuncts, i.e. the
                    // condition rejects nulls. NOTE(review): presumably this enforces the known requirement
                    // that (A LOJ B) LOJ C == A LOJ (B LOJ C) only when the upper predicate is null-rejecting
                    // on B's columns.
                    if (bottomJoin.getJoinType().isLeftOuterJoin() && topJoin.getJoinType().isLeftOuterJoin()) {
                        Set<Slot> notNullSlots = ExpressionUtils.inferNotNullSlots(
                                Stream.concat(topJoin.getHashJoinConjuncts().stream(),
                                                topJoin.getOtherJoinConjuncts().stream())
                                        .collect(ImmutableSet.toImmutableSet()),
                                ctx.cascadesContext);
                        if (!topJoin.getConditionSlot().equals(notNullSlots)) {
                            // No alternative is produced for this match.
                            return null;
                        }
                    }

                    // The new bottom join is the old top join re-parented onto (B, C); it inherits the old
                    // bottom join's reorder context.
                    LogicalJoin<Plan, Plan> newBottomJoin = topJoin.withChildrenNoContext(b, c);
                    newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext());
                    // The new top join is the old bottom join re-parented onto (A, newBottomJoin); it
                    // inherits the old top join's reorder context.
                    LogicalJoin<Plan, Plan> newTopJoin = bottomJoin.withChildrenNoContext(a, newBottomJoin);
                    newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext());
                    setReorderContext(newTopJoin, newBottomJoin);
                    return newTopJoin;
                }).toRule(RuleType.LOGICAL_OUTER_JOIN_ASSOC);
    }

    /**
     * Checks that no hash or other conjunct of {@code topJoin} references a slot in {@code aOutputSet}.
     *
     * <p>After the rewrite, the top join's conjuncts sit on a join over (B, C) only, so they must not
     * reference any output slot of A.
     *
     * @param topJoin the upper join whose conjuncts are inspected
     * @param aOutputSet output slots of the bottom join's left child (A)
     * @return {@code true} if no conjunct's {@link SlotReference}s intersect {@code aOutputSet}
     *         (also {@code true} when the join has no conjuncts)
     */
    public static boolean checkCondition(LogicalJoin<? extends Plan, GroupPlan> topJoin, Set<Slot> aOutputSet) {
        return Stream.concat(topJoin.getHashJoinConjuncts().stream(), topJoin.getOtherJoinConjuncts().stream())
                .allMatch(conjunct -> {
                    Set<Slot> usedSlots = conjunct.collect(SlotReference.class::isInstance);
                    return !Utils.isIntersecting(usedSlots, aOutputSet);
                });
    }

    /**
     * Updates the reorder flags of the two joins produced by this rule.
     *
     * <p>The new bottom join has its commute, left/right-associate and exchange flags cleared. The new top
     * join is marked as having been right-associated, and its commute flag is cleared.
     * NOTE(review): presumably the right-associate mark keeps the inverse rule from immediately undoing this
     * rewrite — confirm against the rules that read these flags.
     *
     * @param newTopJoin the join now at the top of the rewritten pair
     * @param newBottomJoin the join now at the bottom of the rewritten pair
     */
    public static void setReorderContext(LogicalJoin<?, ?> newTopJoin, LogicalJoin<?, ?> newBottomJoin) {
        newBottomJoin.getJoinReorderContext().setHasCommute(false);
        newBottomJoin.getJoinReorderContext().setHasRightAssociate(false);
        newBottomJoin.getJoinReorderContext().setHasLeftAssociate(false);
        newBottomJoin.getJoinReorderContext().setHasExchange(false);
        newTopJoin.getJoinReorderContext().setHasRightAssociate(true);
        newTopJoin.getJoinReorderContext().setHasCommute(false);
    }
}
