package org.apache.doris.nereids.rules.exploration.join;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.JoinUtils;

/**
 * Exploration rule that exchanges the grandchildren of three stacked inner joins.
 *
 * <pre>
 *        topJoin                      newTopJoin
 *        /      \                      /      \
 *   leftJoin  rightJoin   -->   newLeftJoin newRightJoin
 *    /    \    /    \            /    \        /    \
 *   A      B  C      D          A      C      B      D
 * </pre>
 *
 * <p>Each condition of the original top join moves into the new child join whose two inputs cover
 * every slot the condition references. All remaining top conditions, together with all conditions
 * of the original child joins, go to the new top join.
 *
 * <p>The rule applies only when:
 * <ul>
 *   <li>all three joins are inner joins;</li>
 *   <li>the top join passes {@link #checkReorder};</li>
 *   <li>none of the three carries a join hint or is a mark join.</li>
 * </ul>
 */
public class JoinExchange extends OneExplorationRuleFactory {
    public static final JoinExchange INSTANCE = new JoinExchange();

    @Override
    public Rule build() {
        return innerLogicalJoin(innerLogicalJoin(), innerLogicalJoin())
                .when(JoinExchange::checkReorder)
                .whenNot(join -> join.hasJoinHint()
                        || ((LogicalJoin) join.left()).hasJoinHint()
                        || ((LogicalJoin) join.right()).hasJoinHint())
                .whenNot(join -> join.isMarkJoin()
                        || ((LogicalJoin) join.left()).isMarkJoin()
                        || ((LogicalJoin) join.right()).isMarkJoin())
                .then(topJoin -> {
                    LogicalJoin<GroupPlan, GroupPlan> leftJoin =
                            (LogicalJoin<GroupPlan, GroupPlan>) topJoin.left();
                    LogicalJoin<GroupPlan, GroupPlan> rightJoin =
                            (LogicalJoin<GroupPlan, GroupPlan>) topJoin.right();
                    GroupPlan a = (GroupPlan) leftJoin.left();
                    GroupPlan b = (GroupPlan) leftJoin.right();
                    GroupPlan c = (GroupPlan) rightJoin.left();
                    GroupPlan d = (GroupPlan) rightJoin.right();

                    // Slots available to the new child joins (A with C, B with D). Presumably the
                    // union of both inputs' output expr ids — TODO confirm against JoinUtils.
                    Set<ExprId> acOutputExprIds = JoinUtils.getJoinOutputExprIdSet(a, c);
                    Set<ExprId> bdOutputExprIds = JoinUtils.getJoinOutputExprIdSet(b, d);

                    // Hash conditions of the old child joins always land on the new top join; the
                    // old top join's hash conditions are pushed down wherever their slots allow.
                    List<Expression> newLeftHashConjuncts = Lists.newArrayList();
                    List<Expression> newRightHashConjuncts = Lists.newArrayList();
                    List<Expression> newTopHashConjuncts =
                            new ArrayList<>(leftJoin.getHashJoinConjuncts());
                    newTopHashConjuncts.addAll(rightJoin.getHashJoinConjuncts());
                    splitTopCondition(topJoin.getHashJoinConjuncts(), acOutputExprIds, bdOutputExprIds,
                            newLeftHashConjuncts, newRightHashConjuncts, newTopHashConjuncts);

                    // Give up when either new child join would end up without any hash condition.
                    // NOTE(review): returning null presumably tells the rule framework that no
                    // alternative was produced — confirm.
                    if (newLeftHashConjuncts.isEmpty() || newRightHashConjuncts.isEmpty()) {
                        return null;
                    }

                    // Same distribution for the non-hash ("other") join conditions.
                    List<Expression> newLeftOtherConjuncts = Lists.newArrayList();
                    List<Expression> newRightOtherConjuncts = Lists.newArrayList();
                    List<Expression> newTopOtherConjuncts =
                            new ArrayList<>(leftJoin.getOtherJoinConjuncts());
                    newTopOtherConjuncts.addAll(rightJoin.getOtherJoinConjuncts());
                    splitTopCondition(topJoin.getOtherJoinConjuncts(), acOutputExprIds, bdOutputExprIds,
                            newLeftOtherConjuncts, newRightOtherConjuncts, newTopOtherConjuncts);

                    LogicalJoin<GroupPlan, GroupPlan> newLeftJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
                            newLeftHashConjuncts, newLeftOtherConjuncts, JoinHint.NONE, a, c);
                    LogicalJoin<GroupPlan, GroupPlan> newRightJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
                            newRightHashConjuncts, newRightOtherConjuncts, JoinHint.NONE, b, d);
                    LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, LogicalJoin<GroupPlan, GroupPlan>> newTopJoin =
                            new LogicalJoin<>(JoinType.INNER_JOIN, newTopHashConjuncts, newTopOtherConjuncts,
                                    JoinHint.NONE, newLeftJoin, newRightJoin);

                    setNewLeftJoinReorder(newLeftJoin, leftJoin);
                    // NOTE(review): the new right join copies its reorder context from leftJoin, not
                    // rightJoin. The four reorder flags are reset by the helper either way, but any
                    // other state copyFrom carries over comes from leftJoin — confirm this is intended.
                    setNewRightJoinReorder(newRightJoin, leftJoin);
                    setNewTopJoinReorder(newTopJoin, topJoin);
                    return newTopJoin;
                }).toRule(RuleType.LOGICAL_JOIN_EXCHANGE);
    }

    /**
     * Returns whether the exchange rule may be applied to {@code topJoin}.
     *
     * @param topJoin the join the rule is matched against
     * @return {@code false} if the join's reorder context already records a commute, left-associate,
     *     right-associate or exchange; {@code true} otherwise
     */
    public static boolean checkReorder(LogicalJoin<? extends Plan, ? extends Plan> topJoin) {
        return !(topJoin.getJoinReorderContext().hasCommute()
                || topJoin.getJoinReorderContext().hasLeftAssociate()
                || topJoin.getJoinReorderContext().hasRightAssociate()
                || topJoin.getJoinReorderContext().hasExchange());
    }

    /**
     * Initializes the reorder context of a freshly built top join.
     *
     * <p>It copies {@code sourceJoin}'s context and marks the join as exchanged, so that
     * {@link #checkReorder} rejects exchanging it again.
     *
     * @param newJoin the new top join whose context is written
     * @param sourceJoin the join whose context is copied
     */
    public static void setNewTopJoinReorder(LogicalJoin newJoin, LogicalJoin sourceJoin) {
        newJoin.getJoinReorderContext().copyFrom(sourceJoin.getJoinReorderContext());
        newJoin.getJoinReorderContext().setHasExchange(true);
    }

    /**
     * Initializes the reorder context of a freshly built left child join.
     *
     * <p>It copies {@code sourceJoin}'s context, then clears the commute, left-associate,
     * right-associate and exchange flags.
     *
     * @param newJoin the new left child join whose context is written
     * @param sourceJoin the join whose context is copied
     */
    public static void setNewLeftJoinReorder(LogicalJoin newJoin, LogicalJoin sourceJoin) {
        resetChildJoinReorder(newJoin, sourceJoin);
    }

    /**
     * Initializes the reorder context of a freshly built right child join.
     *
     * <p>It copies {@code sourceJoin}'s context, then clears the commute, left-associate,
     * right-associate and exchange flags.
     *
     * @param newJoin the new right child join whose context is written
     * @param sourceJoin the join whose context is copied
     */
    public static void setNewRightJoinReorder(LogicalJoin newJoin, LogicalJoin sourceJoin) {
        resetChildJoinReorder(newJoin, sourceJoin);
    }

    /** Shared body of the left/right child initializers: copy the context, then clear four flags. */
    private static void resetChildJoinReorder(LogicalJoin newJoin, LogicalJoin sourceJoin) {
        newJoin.getJoinReorderContext().copyFrom(sourceJoin.getJoinReorderContext());
        newJoin.getJoinReorderContext().setHasCommute(false);
        newJoin.getJoinReorderContext().setHasLeftAssociate(false);
        newJoin.getJoinReorderContext().setHasRightAssociate(false);
        newJoin.getJoinReorderContext().setHasExchange(false);
    }

    /**
     * Distributes the top join's conditions among the new left, right and top joins.
     *
     * <p>Each condition is appended to exactly one output list, in iteration order:
     * <ul>
     *   <li>to {@code newLeftConditions} if {@code leftOutputExprIds} contains all of its input
     *       slots. This check runs first, so a condition with no input slots also goes left;</li>
     *   <li>otherwise to {@code newRightConditions} if {@code rightOutputExprIds} contains them;</li>
     *   <li>otherwise to {@code newTopConditions}.</li>
     * </ul>
     *
     * @param topConditions conditions of the original top join; not modified
     * @param leftOutputExprIds expr ids available to the new left join
     * @param rightOutputExprIds expr ids available to the new right join
     * @param newLeftConditions receives conditions pushed into the new left join
     * @param newRightConditions receives conditions pushed into the new right join
     * @param newTopConditions receives conditions that must stay on the new top join
     */
    public static void splitTopCondition(List<Expression> topConditions, Set<ExprId> leftOutputExprIds,
            Set<ExprId> rightOutputExprIds, List<Expression> newLeftConditions,
            List<Expression> newRightConditions, List<Expression> newTopConditions) {
        for (Expression condition : topConditions) {
            Set<ExprId> inputSlotExprIds = condition.getInputSlotExprIds();
            if (leftOutputExprIds.containsAll(inputSlotExprIds)) {
                newLeftConditions.add(condition);
            } else if (rightOutputExprIds.containsAll(inputSlotExprIds)) {
                newRightConditions.add(condition);
            } else {
                newTopConditions.add(condition);
            }
        }
    }
}
