package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.TypeUtils;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

/**
 * Eliminates or weakens an outer join that sits directly below a filter whose conjuncts reject
 * NULLs coming from the join's children.
 *
 * <p>A conjunct counts as null-rejecting when {@link TypeUtils#isNotNull} returns a slot for it
 * (presumably {@code slot IS NOT NULL}-style predicates). Depending on which child's output such
 * slots belong to, the join type is tightened:
 * <ul>
 *   <li>RIGHT OUTER with a not-null slot from the left child becomes INNER;</li>
 *   <li>LEFT OUTER with a not-null slot from the right child becomes INNER;</li>
 *   <li>FULL OUTER becomes INNER (both sides), LEFT OUTER (left only) or RIGHT OUTER (right only).</li>
 * </ul>
 * Generated {@code NOT (slot IS NULL)} predicates are added to the filter for those slots. When
 * the resulting join is INNER, further {@code NOT IS NULL} predicates are inferred from its
 * null-rejecting equality join conditions.
 */
public class EliminateOuterJoin extends OneRewriteRuleFactory {
    @Override // org.apache.doris.nereids.rules.OneRuleFactory
    public Rule build() {
        return logicalFilter(logicalJoin().when(join -> join.getJoinType().isOuterJoin()))
                .then(filter -> {
                    LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) filter.child();

                    Set<Slot> notNullSlots = collectNotNullSlots(filter.getConjuncts());
                    boolean canFilterLeftNull = Utils.isIntersecting(join.left().getOutputSet(), notNullSlots);
                    boolean canFilterRightNull = Utils.isIntersecting(join.right().getOutputSet(), notNullSlots);
                    if (!canFilterLeftNull && !canFilterRightNull) {
                        // No null-rejecting conjunct touches either child; returning null leaves the
                        // plan unchanged (presumably the rule framework's "no rewrite" signal).
                        return null;
                    }

                    JoinType newJoinType = tryEliminateOuterJoin(
                            join.getJoinType(), canFilterLeftNull, canFilterRightNull);

                    Set<Expression> conjuncts = Sets.newHashSet(filter.getConjuncts());
                    boolean conjunctsChanged = false;
                    for (Slot slot : notNullSlots) {
                        conjunctsChanged |= conjuncts.add(generatedIsNotNull(slot));
                    }
                    if (newJoinType.isInnerJoin()) {
                        // e.g. (A LEFT JOIN B ON A.a = B.b) JOIN C ON B.x = C.x: the inner equality
                        // B.x = C.x implies B.x IS NOT NULL, which lets an outer join further down
                        // be eliminated when this rule is applied again.
                        conjunctsChanged |= inferIsNotNullFromJoinKeys(join, conjuncts);
                    }

                    Plan newJoin = join.withJoinType(newJoinType);
                    if (conjunctsChanged) {
                        return filter.withConjuncts(ImmutableSet.copyOf(conjuncts)).withChildren(newJoin);
                    }
                    return filter.withChildren(newJoin);
                }).toRule(RuleType.ELIMINATE_OUTER_JOIN);
    }

    /**
     * Collects the slots that the given conjuncts force to be non-null.
     *
     * @param conjuncts the filter's conjuncts
     * @return a mutable set of every slot for which {@link TypeUtils#isNotNull} returned a value
     */
    private static Set<Slot> collectNotNullSlots(Collection<Expression> conjuncts) {
        Set<Slot> notNullSlots = new HashSet<>();
        for (Expression conjunct : conjuncts) {
            Optional<Slot> notNullSlot = TypeUtils.isNotNull(conjunct);
            notNullSlot.ifPresent(notNullSlots::add);
        }
        return notNullSlots;
    }

    /**
     * Adds {@code NOT IS NULL} predicates implied by the join's equality conditions.
     *
     * <p>Every qualifying conjunct is processed. A short-circuiting {@code anyMatch} would stop at
     * the first conjunct that added something and skip the remaining join keys.
     *
     * @param join the (now inner) join whose conditions are inspected
     * @param conjuncts the filter conjuncts to extend in place
     * @return {@code true} if at least one new predicate was added to {@code conjuncts}
     */
    private boolean inferIsNotNullFromJoinKeys(LogicalJoin<?, ?> join, Set<Expression> conjuncts) {
        boolean changed = false;
        for (Expression hashConjunct : join.getHashJoinConjuncts()) {
            changed |= createIsNotNullIfNecessary((EqualPredicate) hashConjunct, conjuncts);
        }

        // Other conjuncts may also contain equalities that span both children exactly like a
        // hash condition; those are null-rejecting in the same way.
        JoinUtils.JoinSlotCoverageChecker checker = new JoinUtils.JoinSlotCoverageChecker(
                join.left().getOutput(), join.right().getOutput());
        for (Expression otherConjunct : join.getOtherJoinConjuncts()) {
            if (otherConjunct instanceof EqualPredicate
                    && checker.isHashJoinCondition((EqualPredicate) otherConjunct)) {
                changed |= createIsNotNullIfNecessary((EqualPredicate) otherConjunct, conjuncts);
            }
        }
        return changed;
    }

    /**
     * Computes the join type that remains after applying null-rejecting filters.
     *
     * @param joinType the current join type
     * @param canFilterLeftNull whether a not-null conjunct references the left child's output
     * @param canFilterRightNull whether a not-null conjunct references the right child's output
     * @return the tightened join type, or {@code joinType} unchanged if nothing applies
     */
    private JoinType tryEliminateOuterJoin(JoinType joinType, boolean canFilterLeftNull,
            boolean canFilterRightNull) {
        if (joinType.isRightOuterJoin() && canFilterLeftNull) {
            return JoinType.INNER_JOIN;
        }
        if (joinType.isLeftOuterJoin() && canFilterRightNull) {
            return JoinType.INNER_JOIN;
        }
        if (joinType.isFullOuterJoin()) {
            if (canFilterLeftNull && canFilterRightNull) {
                return JoinType.INNER_JOIN;
            }
            if (canFilterLeftNull) {
                return JoinType.LEFT_OUTER_JOIN;
            }
            if (canFilterRightNull) {
                return JoinType.RIGHT_OUTER_JOIN;
            }
        }
        return joinType;
    }

    /**
     * Adds {@code NOT (side IS NULL)} for each nullable side of a null-rejecting equality.
     *
     * <p>Only a plain {@link EqualTo} qualifies. Other {@link EqualPredicate} kinds, such as
     * null-safe equality ({@code <=>}), can match NULL with NULL, so they do not imply that
     * either side is non-null.
     *
     * @param equalPredicate the join equality to inspect
     * @param collection the conjunct collection to extend in place
     * @return {@code true} if {@code collection} was modified
     */
    private boolean createIsNotNullIfNecessary(EqualPredicate equalPredicate, Collection<Expression> collection) {
        if (!(equalPredicate instanceof EqualTo)) {
            return false;
        }
        boolean changed = false;
        if (equalPredicate.left().nullable()) {
            changed |= collection.add(generatedIsNotNull(equalPredicate.left()));
        }
        if (equalPredicate.right().nullable()) {
            changed |= collection.add(generatedIsNotNull(equalPredicate.right()));
        }
        return changed;
    }

    /** Builds {@code NOT (expr IS NULL)} with its {@code isGeneratedIsNotNull} flag set. */
    private static Not generatedIsNotNull(Expression expr) {
        Not isNotNull = new Not(new IsNull(expr));
        isNotNull.isGeneratedIsNotNull = true;
        return isNotNull;
    }
}
