package org.apache.doris.nereids.rules.rewrite;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.nereids.util.Utils;

/**
 * Rewrite rule that matches a {@link LogicalFilter} sitting on top of a tree of
 * {@link LogicalJoin}/{@link LogicalFilter} nodes, flattens that tree into a {@link MultiJoin}
 * ({@link #joinToMultiJoin}) and then rebuilds a tree of binary joins from it
 * ({@link #multiJoinToJoin}).
 *
 * <p>The rule is declared to depend on {@link MergeFilters}. {@link #joinToMultiJoin} casts the
 * child of a filter directly to {@link LogicalJoin}, so it presumably relies on consecutive
 * filters having been merged beforehand (NOTE(review): confirm against MergeFilters).
 */
@DependsRules({MergeFilters.class})
public class ReorderJoin extends OneRewriteRuleFactory {

    /**
     * Builds the rule. Filters whose direct child is a mark join are not matched, and the rule
     * returns {@code null} when the session variable {@code isDisableJoinReorder()} is set.
     *
     * @return the rule registered as {@link RuleType#REORDER_JOIN}
     */
    @Override
    public Rule build() {
        return logicalFilter(subTree(LogicalJoin.class, LogicalFilter.class))
                .whenNot(filter -> filter.child() instanceof LogicalJoin
                        && ((LogicalJoin<?, ?>) filter.child()).isMarkJoin())
                .thenApply(ctx -> {
                    if (ctx.statementContext.getConnectContext().getSessionVariable().isDisableJoinReorder()) {
                        return null;
                    }
                    Map<Plan, JoinHint.JoinHintType> planToHintType = Maps.newHashMap();
                    Plan flattened = joinToMultiJoin(ctx.root, planToHintType);
                    Preconditions.checkState(flattened instanceof MultiJoin);
                    MultiJoin multiJoin = (MultiJoin) flattened;
                    ctx.statementContext.setMaxNAryInnerJoin(multiJoin.children().size());
                    return multiJoinToJoin(multiJoin, planToHintType);
                }).toRule(RuleType.REORDER_JOIN);
    }

    /**
     * Recursively converts a join (optionally wrapped in a filter) into a {@link MultiJoin}.
     *
     * <p>Conjuncts of a wrapping filter and of inner/cross joins are collected as the join filter;
     * conjuncts of any other join type are kept separately as "not inner join" conditions. A
     * converted child is merged into this node only when {@link #canCombine} allows it, otherwise
     * it is kept as a single input.
     *
     * @param plan the plan to convert; returned unchanged when it is neither a join nor a filter,
     *             or when it is a filter whose child is neither
     * @param map  output map, filled with each join input and the hint recorded for that side
     * @return the resulting {@link MultiJoin}, or {@code plan} itself when nothing was converted
     */
    public Plan joinToMultiJoin(Plan plan, Map<Plan, JoinHint.JoinHintType> map) {
        if (nonJoinAndNonFilter(plan)
                || (plan instanceof LogicalFilter && nonJoinAndNonFilter(plan.child(0)))) {
            return plan;
        }

        List<Plan> inputs = Lists.newArrayList();
        List<Expression> joinFilter = Lists.newArrayList();
        List<Expression> notInnerJoinConditions = Lists.newArrayList();

        LogicalJoin<?, ?> join;
        if (plan instanceof LogicalFilter) {
            LogicalFilter<?> filter = (LogicalFilter<?>) plan;
            joinFilter.addAll(filter.getConjuncts());
            join = (LogicalJoin<?, ?>) filter.child();
        } else {
            join = (LogicalJoin<?, ?>) plan;
        }

        JoinType joinType = join.getJoinType();
        List<Expression> conditionSink = joinType.isInnerOrCrossJoin() ? joinFilter : notInnerJoinConditions;
        conditionSink.addAll(join.getHashJoinConjuncts());
        conditionSink.addAll(join.getOtherJoinConjuncts());

        // Record the hint of each side before descending, in the same order as the recursion.
        map.put(join.left(), join.getLeftHint());
        Plan left = joinToMultiJoin(join.left(), map);
        map.put(join.right(), join.getRightHint());
        Plan right = joinToMultiJoin(join.right(), map);

        // The left side is not flattened under RIGHT/FULL OUTER joins.
        if (canCombine(left, joinType.isRightJoin() || joinType.isFullOuterJoin())) {
            MultiJoin leftMultiJoin = (MultiJoin) left;
            inputs.addAll(leftMultiJoin.children());
            joinFilter.addAll(leftMultiJoin.getJoinFilter());
        } else {
            inputs.add(left);
        }

        // The right side is not flattened under LEFT/FULL OUTER joins.
        if (canCombine(right, joinType.isLeftJoin() || joinType.isFullOuterJoin())) {
            MultiJoin rightMultiJoin = (MultiJoin) right;
            inputs.addAll(rightMultiJoin.children());
            joinFilter.addAll(rightMultiJoin.getJoinFilter());
        } else {
            inputs.add(right);
        }

        return new MultiJoin(inputs, joinFilter, joinType, notInnerJoinConditions);
    }

    /**
     * Rebuilds a tree of binary joins from a {@link MultiJoin}.
     *
     * <p>Nested {@link MultiJoin} children are converted first. For an inner/cross multi-join the
     * inputs are joined one at a time via {@link #findInnerJoin}; conjuncts consumed by a join
     * are removed from the remaining filter, and whatever is left is placed in a filter on top.
     * For LEFT/RIGHT joins the filter is split by whether a conjunct references the preserved
     * single side; the other inputs are rebuilt as an inner multi-join. Any other join type must
     * have exactly two inputs.
     *
     * @param multiJoin the multi-join to convert
     * @param map       hint recorded per join input, as filled by {@link #joinToMultiJoin}
     * @return the rebuilt plan, wrapped in a filter when conjuncts remain
     */
    public Plan multiJoinToJoin(MultiJoin multiJoin, Map<Plan, JoinHint.JoinHintType> map) {
        if (multiJoin.arity() == 1) {
            return PlanUtils.filterOrSelf(ImmutableSet.copyOf(multiJoin.getJoinFilter()), multiJoin.child(0));
        }

        ImmutableList.Builder<Plan> convertedChildren = ImmutableList.builder();
        for (Plan child : multiJoin.children()) {
            if (child instanceof MultiJoin) {
                convertedChildren.add(multiJoinToJoin((MultiJoin) child, map));
            } else {
                convertedChildren.add(child);
            }
        }
        MultiJoin current = multiJoin.withChildren2(convertedChildren.build());
        JoinType joinType = current.getJoinType();

        if (joinType.isInnerOrCrossJoin()) {
            Set<Expression> remaining = new HashSet<>(current.getJoinFilter());
            Plan left = current.child(0);
            Set<Integer> usedPlansIndex = new HashSet<>();
            usedPlansIndex.add(0);
            while (usedPlansIndex.size() != current.children().size()) {
                LogicalJoin<? extends Plan, ? extends Plan> join =
                        findInnerJoin(left, current.children(), remaining, usedPlansIndex, map);
                // Conjuncts consumed by the new join must not be re-applied above it.
                join.getHashJoinConjuncts().forEach(remaining::remove);
                join.getOtherJoinConjuncts().forEach(remaining::remove);
                left = join;
            }
            return PlanUtils.filterOrSelf(remaining, left);
        }

        List<Expression> remainingFilter;
        Plan left;
        Plan right;
        if (joinType.isLeftJoin()) {
            right = current.child(current.arity() - 1);
            Map<Boolean, List<Expression>> split = splitByReferenceTo(right, multiJoin.getJoinFilter());
            remainingFilter = split.get(true);
            left = multiJoinToJoin(new MultiJoin(
                    current.children().subList(0, current.arity() - 1),
                    split.get(false),
                    JoinType.INNER_JOIN,
                    ExpressionUtils.EMPTY_CONDITION), map);
        } else if (joinType.isRightJoin()) {
            left = current.child(0);
            Map<Boolean, List<Expression>> split = splitByReferenceTo(left, multiJoin.getJoinFilter());
            remainingFilter = split.get(true);
            right = multiJoinToJoin(new MultiJoin(
                    current.children().subList(1, current.arity()),
                    split.get(false),
                    JoinType.INNER_JOIN,
                    ExpressionUtils.EMPTY_CONDITION), map);
        } else {
            remainingFilter = multiJoin.getJoinFilter();
            Preconditions.checkState(current.arity() == 2);
            List<Plan> sides = current.children().stream()
                    .map(child -> child instanceof MultiJoin ? multiJoinToJoin((MultiJoin) child, map) : child)
                    .collect(Collectors.toList());
            left = sides.get(0);
            right = sides.get(1);
        }

        return PlanUtils.filterOrSelf(ImmutableSet.copyOf(remainingFilter), new LogicalJoin<>(
                joinType,
                ExpressionUtils.EMPTY_CONDITION,
                current.getNotInnerJoinConditions(),
                JoinHint.fromRightPlanHintType(map.getOrDefault(right, JoinHint.JoinHintType.NONE)),
                Optional.empty(),
                left, right));
    }

    /**
     * Partitions {@code filter}: key {@code true} holds the conjuncts whose input slots intersect
     * the output of {@code side}, key {@code false} holds the rest.
     */
    private static Map<Boolean, List<Expression>> splitByReferenceTo(Plan side, List<Expression> filter) {
        Set<ExprId> sideOutputExprIds = side.getOutputExprIdSet();
        return filter.stream().collect(Collectors.partitioningBy(
                expr -> Utils.isIntersecting(sideOutputExprIds, expr.getInputSlotExprIds())));
    }

    /**
     * Tells whether a converted child may be merged into its parent: it must be an inner/cross
     * {@link MultiJoin} and merging must not be forbidden by the parent's join type.
     */
    private static boolean canCombine(Plan plan, boolean forbidden) {
        if (forbidden || !(plan instanceof MultiJoin)) {
            return false;
        }
        return ((MultiJoin) plan).getJoinType().isInnerOrCrossJoin();
    }

    /**
     * Picks the next input to join with {@code left}.
     *
     * <p>The first unused candidate for which hash-join conditions can be extracted from
     * {@code joinFilter} is joined with an INNER JOIN. If no candidate qualifies, the first unused
     * candidate is joined with a CROSS JOIN carrying the non-hash conditions extracted for
     * <em>that</em> candidate. (Previously the conditions left over from the last candidate
     * examined were used, which could attach predicates over another relation to the cross join.)
     *
     * @param left           the plan built so far
     * @param candidates     all inputs of the multi-join
     * @param joinFilter     conjuncts not yet consumed by a join; not modified here
     * @param usedPlansIndex indexes of inputs already joined; the chosen index is added
     * @param planToHintType hint recorded per join input
     * @return the new join of {@code left} with the chosen candidate
     * @throws RuntimeException if every candidate is already used
     */
    private LogicalJoin<? extends Plan, ? extends Plan> findInnerJoin(Plan left, List<Plan> candidates,
            Set<Expression> joinFilter, Set<Integer> usedPlansIndex,
            Map<Plan, JoinHint.JoinHintType> planToHintType) {
        Set<ExprId> leftOutputExprIds = left.getOutputExprIdSet();
        int fallbackIndex = -1;
        List<Expression> fallbackOtherConditions = ExpressionUtils.EMPTY_CONDITION;

        for (int i = 0; i < candidates.size(); i++) {
            if (usedPlansIndex.contains(i)) {
                continue;
            }
            Plan candidate = candidates.get(i);
            Set<ExprId> candidateOutputExprIds = candidate.getOutputExprIdSet();
            Set<ExprId> joinOutputExprIds = JoinUtils.getJoinOutputExprIdSet(left, candidate);

            // Keep only conjuncts that need both sides: fully covered by the joined output but
            // not by either side alone.
            List<Expression> applicable = joinFilter.stream()
                    .filter(expr -> {
                        Set<ExprId> inputSlots = expr.getInputSlotExprIds();
                        return !leftOutputExprIds.containsAll(inputSlots)
                                && !candidateOutputExprIds.containsAll(inputSlots)
                                && joinOutputExprIds.containsAll(inputSlots);
                    })
                    .collect(Collectors.toList());

            Pair<List<Expression>, List<Expression>> extracted = JoinUtils.extractExpressionForHashTable(
                    left.getOutput(), candidate.getOutput(), applicable);
            List<Expression> hashJoinConditions = extracted.first;
            List<Expression> otherJoinConditions = extracted.second;

            if (!hashJoinConditions.isEmpty()) {
                usedPlansIndex.add(i);
                return new LogicalJoin<>(JoinType.INNER_JOIN,
                        hashJoinConditions,
                        otherJoinConditions,
                        JoinHint.fromRightPlanHintType(
                                planToHintType.getOrDefault(candidate, JoinHint.JoinHintType.NONE)),
                        Optional.empty(),
                        left, candidate);
            }
            if (fallbackIndex < 0) {
                // Remember the candidate the cross-join fallback will pick, with its own conditions.
                fallbackIndex = i;
                fallbackOtherConditions = otherJoinConditions;
            }
        }

        if (fallbackIndex < 0) {
            throw new RuntimeException("findInnerJoin: can't reach here");
        }

        // No candidate shares a hash-join condition with left: emit a cross join.
        usedPlansIndex.add(fallbackIndex);
        Plan right = candidates.get(fallbackIndex);
        return new LogicalJoin<>(JoinType.CROSS_JOIN,
                ExpressionUtils.EMPTY_CONDITION,
                fallbackOtherConditions,
                JoinHint.fromRightPlanHintType(planToHintType.getOrDefault(right, JoinHint.JoinHintType.NONE)),
                Optional.empty(),
                left, right);
    }

    /** Returns {@code true} when {@code plan} is neither a {@link LogicalJoin} nor a {@link LogicalFilter}. */
    private boolean nonJoinAndNonFilter(Plan plan) {
        return !(plan instanceof LogicalJoin) && !(plan instanceof LogicalFilter);
    }
}
