package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableList;
import java.util.List;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.JoinUtils;

/**
 * Rewrite rule that promotes join predicates usable as hash-join keys.
 *
 * <p>For every {@link LogicalJoin}, the join's "other" (non-hash) conjuncts are passed to
 * {@link JoinUtils#extractExpressionForHashTable} together with the output slots of both children.
 * The conjuncts it returns in {@code first} are appended to the existing hash-join conjuncts. The
 * conjuncts returned in {@code second} become the new other-join conjuncts. If at least one hash
 * conjunct results, a {@link JoinType#CROSS_JOIN} is turned into an {@link JoinType#INNER_JOIN}.
 * Joins for which nothing can be extracted are returned unchanged.
 *
 * <p>Declared to depend on {@link PushFilterInsideJoin}. Presumably this is so that filters moved
 * into the join's other conjuncts are also considered here — NOTE(review): confirm.
 */
@DependsRules({PushFilterInsideJoin.class})
public class FindHashConditionForJoin extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalJoin().then(join -> {
            Pair<List<Expression>, List<Expression>> split = JoinUtils.extractExpressionForHashTable(
                    join.left().getOutput(), join.right().getOutput(), join.getOtherJoinConjuncts());
            List<Expression> extractedHashConjuncts = split.first;
            List<Expression> remainingOtherConjuncts = split.second;

            // Nothing to promote: keep the original plan node so the rewrite is a no-op.
            if (extractedHashConjuncts.isEmpty()) {
                return join;
            }

            List<Expression> combinedHashConjuncts = new ImmutableList.Builder<Expression>()
                    .addAll(join.getHashJoinConjuncts())
                    .addAll(extractedHashConjuncts)
                    .build();

            // A cross join that now carries hash-join conditions is really an inner join.
            JoinType joinType = join.getJoinType();
            if (joinType == JoinType.CROSS_JOIN && !combinedHashConjuncts.isEmpty()) {
                joinType = JoinType.INNER_JOIN;
            }

            return new LogicalJoin<>(joinType,
                    combinedHashConjuncts,
                    remainingOtherConjuncts,
                    join.getHint(),
                    join.getMarkJoinSlotReference(),
                    join.children());
        }).toRule(RuleType.FIND_HASH_CONDITION_FOR_JOIN);
    }
}
