package org.apache.doris.nereids.rules.exploration.join;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.CBOUtils;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.Utils;

/**
 * Exploration rule applying the l-asscom reorder {@code (A ⋈ B) ⋈ C  ->  (A ⋈ C) ⋈ B} to two inner
 * joins that are separated by a slot-only project.
 *
 * <pre>
 *        topJoin                   newTopJoin
 *        /     \                   /        \
 *    project    C          newLeftProject newRightProject
 *      /            -->          /            \
 * bottomJoin                newBottomJoin      B
 *    /   \                     /   \
 *   A     B                   A     C
 * </pre>
 */
public class InnerJoinLAsscomProject extends OneExplorationRuleFactory {
    public static final InnerJoinLAsscomProject INSTANCE = new InnerJoinLAsscomProject();

    /**
     * Builds the rule. It matches {@code topJoin(project(bottomJoin), C)} when:
     * <ul>
     * <li>both joins are inner joins;</li>
     * <li>{@code InnerJoinLAsscom.checkReorder} accepts the pair;</li>
     * <li>neither join carries a join hint or is a mark join;</li>
     * <li>the project only forwards slots.</li>
     * </ul>
     */
    @Override
    public Rule build() {
        return innerLogicalJoin(logicalProject(innerLogicalJoin()), group())
                // Raw cast kept as in the original call; checkReorder's parameter type is declared elsewhere.
                .when(topJoin -> InnerJoinLAsscom.checkReorder(topJoin, (LogicalJoin) bottomJoinOf(topJoin)))
                .whenNot(topJoin -> topJoin.hasJoinHint() || bottomJoinOf(topJoin).hasJoinHint())
                .whenNot(topJoin -> topJoin.isMarkJoin() || bottomJoinOf(topJoin).isMarkJoin())
                .when(topJoin -> ((LogicalProject<?>) topJoin.left()).isAllSlots())
                .then(topJoin -> {
                    LogicalJoin<?, ?> bottomJoin = bottomJoinOf(topJoin);
                    GroupPlan a = (GroupPlan) bottomJoin.left();
                    GroupPlan b = (GroupPlan) bottomJoin.right();
                    GroupPlan c = (GroupPlan) topJoin.right();
                    Set<ExprId> bExprIdSet = b.getOutputExprIdSet();

                    // Top conjuncts referencing B (plus all bottom conjuncts) go to the new top join;
                    // top conjuncts not referencing B are pushed into the new bottom join A ⋈ C.
                    Map<Boolean, List<Expression>> splitHashConjuncts = splitConjuncts(
                            topJoin.getHashJoinConjuncts(), bottomJoin.getHashJoinConjuncts(), bExprIdSet);
                    List<Expression> newTopHashConjuncts = splitHashConjuncts.get(true);
                    List<Expression> newBottomHashConjuncts = splitHashConjuncts.get(false);

                    Map<Boolean, List<Expression>> splitOtherConjuncts = splitConjuncts(
                            topJoin.getOtherJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), bExprIdSet);
                    List<Expression> newTopOtherConjuncts = splitOtherConjuncts.get(true);
                    List<Expression> newBottomOtherConjuncts = splitOtherConjuncts.get(false);

                    if (newBottomOtherConjuncts.isEmpty() && newBottomHashConjuncts.isEmpty()) {
                        // The new A-C join would have no join condition at all; produce no alternative.
                        return null;
                    }

                    // New bottom join: A ⋈ C, built from the top join and inheriting the old bottom's context.
                    LogicalJoin<Plan, Plan> newBottomJoin = topJoin.withConjunctsChildren(
                            newBottomHashConjuncts, newBottomOtherConjuncts, a, c);
                    newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext());
                    newBottomJoin.getJoinReorderContext().setHasLAsscom(false);
                    newBottomJoin.getJoinReorderContext().setHasCommute(false);

                    // Both inputs of the new top join must expose the original output slots plus every
                    // slot referenced by the conjuncts evaluated at the new top join.
                    Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
                    newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
                    newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));

                    // New top join: (A ⋈ C) ⋈ B, built from the old bottom join and inheriting the top's context.
                    LogicalJoin<Plan, Plan> newTopJoin = bottomJoin.withConjunctsChildren(
                            newTopHashConjuncts, newTopOtherConjuncts,
                            CBOUtils.newProject(topUsedExprIds, newBottomJoin),
                            CBOUtils.newProject(topUsedExprIds, b));
                    newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext());
                    newTopJoin.getJoinReorderContext().setHasLAsscom(true);

                    // Restore the original top join's output.
                    return CBOUtils.projectOrSelf(ImmutableList.copyOf(topJoin.getOutput()), newTopJoin);
                }).toRule(RuleType.LOGICAL_INNER_JOIN_LASSCOM_PROJECT);
    }

    /** Returns the join sitting under the project that is the left child of {@code topJoin}. */
    private static LogicalJoin<?, ?> bottomJoinOf(LogicalJoin<?, ?> topJoin) {
        return (LogicalJoin<?, ?>) ((LogicalProject<?>) topJoin.left()).child();
    }

    /**
     * Splits conjuncts between the new top and new bottom join.
     *
     * @param topConjuncts conjuncts of the original top join
     * @param bottomConjuncts conjuncts of the original bottom join; all of them go to the new top join
     * @param bExprIdSet output expression ids of B
     * @return a map whose {@code true} entry holds top conjuncts referencing B followed by every bottom
     *     conjunct (the new top join's conjuncts), and whose {@code false} entry holds the top conjuncts
     *     that do not reference B (the new bottom join's conjuncts); both lists are mutable
     */
    private static Map<Boolean, List<Expression>> splitConjuncts(List<Expression> topConjuncts,
            List<Expression> bottomConjuncts, Set<ExprId> bExprIdSet) {
        // Collect into ArrayList explicitly: partitioningBy(pred) uses toList(), whose mutability is
        // unspecified, and the true-partition is appended to below.
        Map<Boolean, List<Expression>> split = topConjuncts.stream()
                .collect(Collectors.partitioningBy(
                        conjunct -> Utils.isIntersecting(conjunct.getInputSlotExprIds(), bExprIdSet),
                        Collectors.<Expression, List<Expression>>toCollection(ArrayList::new)));
        split.get(true).addAll(bottomConjuncts);
        return split;
    }
}
