package org.apache.doris.nereids.rules.exploration.join;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.HashSet;
import java.util.Set;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.CBOUtils;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

/**
 * Exploration rule that transposes two stacked joins separated by a slot-only project.
 *
 * <pre>
 *        topSemi                       newTopSemi (rebuilt from bottomSemi)
 *        /     \                        /        \
 *   abProject   C          -->     acProject      B
 *       |                              |
 *   bottomSemi                     newBottomSemi (rebuilt from topSemi)
 *    /     \                        /     \
 *   A       B                      A       C
 * </pre>
 *
 * <p>The pair of join types must be in {@code SemiJoinSemiJoinTranspose.VALID_TYPE_PAIR_SET}
 * (presumably semi/anti join combinations, per the rule name — confirm there). The rule does not
 * fire when either join carries a join hint or is a mark join, when the middle project is not
 * all-slots, or when {@code InnerJoinLAsscom.checkReorder} rejects the pair.
 */
public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory {
    public static final SemiJoinSemiJoinTransposeProject INSTANCE = new SemiJoinSemiJoinTransposeProject();

    @Override
    public Rule build() {
        return logicalJoin(logicalProject(logicalJoin()), group())
                .when(this::typeChecker)
                .when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child()))
                // Do not reorder joins that carry an explicit join hint.
                .whenNot(topSemi -> topSemi.hasJoinHint() || topSemi.left().child().hasJoinHint())
                .whenNot(topSemi -> topSemi.isMarkJoin() || topSemi.left().child().isMarkJoin())
                // Only a pure slot projection is re-created between the swapped joins.
                .when(topSemi -> topSemi.left().isAllSlots())
                .then(topSemi -> {
                    LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> abProject = topSemi.left();
                    LogicalJoin<GroupPlan, GroupPlan> bottomSemi = abProject.child();
                    GroupPlan a = bottomSemi.left();
                    GroupPlan b = bottomSemi.right();
                    GroupPlan c = topSemi.right();

                    // The old bottom join ends up above the new project, so any of A's columns that
                    // its conditions reference must be kept in the new project's output.
                    // Elements are the project's NamedExpressions plus condition slots.
                    Set<ExprId> aOutputExprIds = a.getOutputExprIdSet();
                    HashSet acProjects = new HashSet(abProject.getProjects());
                    bottomSemi.getConditionSlot().forEach(slot -> {
                        if (aOutputExprIds.contains(slot.getExprId())) {
                            acProjects.add(slot);
                        }
                    });

                    // New lower join: the old top join rebuilt over (A, C). It takes over the old bottom
                    // join's reorder state, with the commute/LAsscom flags cleared.
                    LogicalJoin<Plan, Plan> newBottomSemi = topSemi.withChildrenNoContext(a, c);
                    newBottomSemi.getJoinReorderContext().copyFrom(bottomSemi.getJoinReorderContext());
                    newBottomSemi.getJoinReorderContext().setHasCommute(false);
                    newBottomSemi.getJoinReorderContext().setHasLAsscom(false);

                    LogicalProject acProject = new LogicalProject(Lists.newArrayList(acProjects), newBottomSemi);

                    // New upper join: the old bottom join rebuilt over (acProject, B). It takes over the
                    // old top join's reorder state and is flagged as LAsscom-produced (presumably so
                    // the same transpose is not immediately re-applied — confirm in JoinReorderContext).
                    LogicalJoin<Plan, Plan> newTopSemi = bottomSemi.withChildrenNoContext(acProject, b);
                    newTopSemi.getJoinReorderContext().copyFrom(topSemi.getJoinReorderContext());
                    newTopSemi.getJoinReorderContext().setHasLAsscom(true);

                    // Re-project to the matched top join's output so parents see the same columns.
                    return CBOUtils.projectOrSelf(ImmutableList.copyOf(topSemi.getOutput()), newTopSemi);
                }).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT);
    }

    /**
     * Checks whether the (top join type, bottom join type) pair may be transposed.
     *
     * @param topSemi the matched top join, whose left child is a project over the bottom join
     * @return true if the pair is listed in {@code SemiJoinSemiJoinTranspose.VALID_TYPE_PAIR_SET}
     */
    public boolean typeChecker(LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topSemi) {
        return SemiJoinSemiJoinTranspose.VALID_TYPE_PAIR_SET.contains(
                Pair.of(topSemi.getJoinType(), topSemi.left().child().getJoinType()));
    }
}
