package org.apache.doris.nereids.rules.rewrite;

import java.util.HashSet;
import java.util.Set;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

/* loaded from: input_file:org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.class */
/**
 * Rewrite rule that derives not-null predicates from a join's conditions and places them as
 * filters on the join's children.
 *
 * <p>Only inner joins and semi joins are matched. For a matched join, the hash-join and
 * other-join conjuncts are collected and passed to {@link ExpressionUtils#inferNotNull} together
 * with a child's output slots. Which children are considered depends on the join type:
 * <ul>
 *   <li>inner join: both children;</li>
 *   <li>{@link JoinType#LEFT_SEMI_JOIN}: only the left child;</li>
 *   <li>any other semi join: only the right child.</li>
 * </ul>
 * Each considered child is replaced by {@link PlanUtils#filterOrSelf} applied to the inferred
 * predicates. If neither child changes, the action returns {@code null}.
 */
public class InferJoinNotNull extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalJoin(any(), any())
                .when(join -> join.getJoinType().isInnerJoin() || join.getJoinType().isSemiJoin())
                .thenApply(ctx -> {
                    LogicalJoin<Plan, Plan> join = ctx.root;
                    Plan originalLeft = join.left();
                    Plan originalRight = join.right();

                    // Inference draws on every join predicate, hash-based or not.
                    Set<Expression> conjuncts = new HashSet<>();
                    conjuncts.addAll(join.getHashJoinConjuncts());
                    conjuncts.addAll(join.getOtherJoinConjuncts());

                    Plan left = originalLeft;
                    Plan right = originalRight;
                    JoinType joinType = join.getJoinType();
                    if (joinType.isInnerJoin()) {
                        Set<Expression> leftNotNull = ExpressionUtils.inferNotNull(
                                conjuncts, originalLeft.getOutputSet(), ctx.cascadesContext);
                        Set<Expression> rightNotNull = ExpressionUtils.inferNotNull(
                                conjuncts, originalRight.getOutputSet(), ctx.cascadesContext);
                        left = PlanUtils.filterOrSelf(leftNotNull, originalLeft);
                        right = PlanUtils.filterOrSelf(rightNotNull, originalRight);
                    } else if (joinType == JoinType.LEFT_SEMI_JOIN) {
                        Set<Expression> leftNotNull = ExpressionUtils.inferNotNull(
                                conjuncts, originalLeft.getOutputSet(), ctx.cascadesContext);
                        left = PlanUtils.filterOrSelf(leftNotNull, originalLeft);
                    } else {
                        // The `when` guard admits only inner and semi joins, so this is a semi join
                        // other than LEFT_SEMI_JOIN (presumably RIGHT_SEMI_JOIN): only the right
                        // child is filtered.
                        Set<Expression> rightNotNull = ExpressionUtils.inferNotNull(
                                conjuncts, originalRight.getOutputSet(), ctx.cascadesContext);
                        right = PlanUtils.filterOrSelf(rightNotNull, originalRight);
                    }

                    if (left.equals(originalLeft) && right.equals(originalRight)) {
                        // Neither child gained a filter; null presumably signals "no rewrite"
                        // to the rule framework.
                        return null;
                    }
                    return join.withChildren(left, right);
                }).toRule(RuleType.INFER_JOIN_NOT_NULL);
    }
}
