package org.apache.doris.nereids.rules.exploration.join;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.CBOUtils;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.JoinUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:org/apache/doris/nereids/rules/exploration/join/JoinExchangeBothProject.class */
/**
 * Exploration rule that exchanges the inputs of two inner joins sitting under projects on both
 * sides of a top inner join.
 *
 * <pre>
 *          topJoin                                newTopJoin
 *          /     \                                /        \
 *     project   project      -->    project(newLeftJoin)  project(newRightJoin)
 *        |         |                      /    \                /    \
 *    leftJoin  rightJoin                 A      C              B      D
 *     /   \     /    \
 *    A     B   C      D
 * </pre>
 *
 * <p>The rewritten plan is wrapped in a projection onto the original top join's output.
 */
public class JoinExchangeBothProject extends OneExplorationRuleFactory {
    public static final JoinExchangeBothProject INSTANCE = new JoinExchangeBothProject();

    /**
     * Builds the rule.
     *
     * <p>The rule matches an inner join whose two children are projects over inner joins. It
     * applies only when:
     * <ul>
     *   <li>{@code JoinExchange.checkReorder} accepts the top join;</li>
     *   <li>both projects are all-slot projections;</li>
     *   <li>none of the three joins carries a join hint;</li>
     *   <li>none of the three joins is a mark join.</li>
     * </ul>
     */
    @Override
    public Rule build() {
        return innerLogicalJoin(logicalProject(innerLogicalJoin()), logicalProject(innerLogicalJoin()))
                .when(JoinExchange::checkReorder)
                .when(join -> join.left().isAllSlots() && join.right().isAllSlots())
                .whenNot(join -> join.hasJoinHint()
                        || join.left().child().hasJoinHint()
                        || join.right().child().hasJoinHint())
                .whenNot(join -> join.isMarkJoin()
                        || join.left().child().isMarkJoin()
                        || join.right().child().isMarkJoin())
                .then(topJoin -> {
                    LogicalJoin<GroupPlan, GroupPlan> leftJoin = topJoin.left().child();
                    LogicalJoin<GroupPlan, GroupPlan> rightJoin = topJoin.right().child();
                    GroupPlan a = leftJoin.left();
                    GroupPlan b = leftJoin.right();
                    GroupPlan c = rightJoin.left();
                    GroupPlan d = rightJoin.right();

                    Set<ExprId> acOutputExprIds = JoinUtils.getJoinOutputExprIdSet(a, c);
                    Set<ExprId> bdOutputExprIds = JoinUtils.getJoinOutputExprIdSet(b, d);

                    // The new top join starts with both lower joins' conditions. splitTopCondition then
                    // distributes the top join's own conditions between the new left join (A, C), the
                    // new right join (B, D) and the new top join — presumably by which side's output
                    // ExprIds they reference; see JoinExchange.splitTopCondition.
                    List<Expression> newLeftHashConjuncts = Lists.newArrayList();
                    List<Expression> newRightHashConjuncts = Lists.newArrayList();
                    List<Expression> newTopHashConjuncts = new ArrayList<>(leftJoin.getHashJoinConjuncts());
                    newTopHashConjuncts.addAll(rightJoin.getHashJoinConjuncts());
                    JoinExchange.splitTopCondition(topJoin.getHashJoinConjuncts(), acOutputExprIds,
                            bdOutputExprIds, newLeftHashConjuncts, newRightHashConjuncts, newTopHashConjuncts);

                    List<Expression> newLeftOtherConjuncts = Lists.newArrayList();
                    List<Expression> newRightOtherConjuncts = Lists.newArrayList();
                    List<Expression> newTopOtherConjuncts = new ArrayList<>(leftJoin.getOtherJoinConjuncts());
                    newTopOtherConjuncts.addAll(rightJoin.getOtherJoinConjuncts());
                    JoinExchange.splitTopCondition(topJoin.getOtherJoinConjuncts(), acOutputExprIds,
                            bdOutputExprIds, newLeftOtherConjuncts, newRightOtherConjuncts, newTopOtherConjuncts);

                    // Give up when either new join would have no hash-join conjunct.
                    if (newLeftHashConjuncts.isEmpty() || newRightHashConjuncts.isEmpty()) {
                        return null;
                    }

                    LogicalJoin<GroupPlan, GroupPlan> newLeftJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
                            newLeftHashConjuncts, newLeftOtherConjuncts, JoinHint.NONE, a, c);
                    LogicalJoin<GroupPlan, GroupPlan> newRightJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
                            newRightHashConjuncts, newRightOtherConjuncts, JoinHint.NONE, b, d);

                    // The projects over the new joins must keep the slots the original top join exposed
                    // plus every slot the remaining top-level conditions read.
                    Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
                    newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
                    newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
                    Plan newLeft = CBOUtils.newProject(topUsedExprIds, newLeftJoin);
                    Plan newRight = CBOUtils.newProject(topUsedExprIds, newRightJoin);

                    LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
                            newTopHashConjuncts, newTopOtherConjuncts, JoinHint.NONE, newLeft, newRight);
                    JoinExchange.setNewLeftJoinReorder(newLeftJoin, leftJoin);
                    // NOTE(review): leftJoin (not rightJoin) is passed here to mirror the compiled code;
                    // confirm against JoinExchange.setNewRightJoinReorder whether rightJoin was intended.
                    JoinExchange.setNewRightJoinReorder(newRightJoin, leftJoin);
                    JoinExchange.setNewTopJoinReorder(newTopJoin, topJoin);

                    // Restore the original top join's output schema.
                    return CBOUtils.projectOrSelf(ImmutableList.copyOf(topJoin.getOutput()), newTopJoin);
                }).toRule(RuleType.LOGICAL_JOIN_EXCHANGE_BOTH_PROJECT);
    }

    /**
     * Checks the join's reorder context.
     *
     * <p>Note: {@link #build()} uses {@code JoinExchange::checkReorder}, not this method.
     *
     * @param logicalJoin the join to inspect
     * @return {@code true} only when the join's reorder context reports none of commute,
     *         left-associate, right-associate or exchange
     */
    public static boolean checkReorder(LogicalJoin<? extends Plan, ? extends Plan> logicalJoin) {
        return !(logicalJoin.getJoinReorderContext().hasCommute()
                || logicalJoin.getJoinReorderContext().hasLeftAssociate()
                || logicalJoin.getJoinReorderContext().hasRightAssociate()
                || logicalJoin.getJoinReorderContext().hasExchange());
    }
}
