package org.apache.doris.nereids.util;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.doris.catalog.ColocateTableIndex;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContains;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;

/* loaded from: input_file:org/apache/doris/nereids/util/JoinUtils.class */
/**
 * Static helpers shared by Nereids join handling: classifying join conjuncts into hash / other
 * conditions, deciding which distribution strategies a join may use, and computing the output
 * slots of a join.
 */
public class JoinUtils {

    /**
     * Decides whether an equality predicate can be used as a hash-join condition, i.e. whether
     * one side of the predicate references only slots of one join child and the other side
     * references only slots of the other child.
     */
    public static final class JoinSlotCoverageChecker {
        Set<ExprId> leftExprIds;
        Set<ExprId> rightExprIds;

        /**
         * @param leftSlots output slots of the left join child
         * @param rightSlots output slots of the right join child
         */
        public JoinSlotCoverageChecker(List<Slot> leftSlots, List<Slot> rightSlots) {
            this.leftExprIds = leftSlots.stream().map(Slot::getExprId).collect(Collectors.toSet());
            this.rightExprIds = rightSlots.stream().map(Slot::getExprId).collect(Collectors.toSet());
        }

        /**
         * Returns true when both sides of {@code equalPredicate} reference at least one input slot
         * and the two sides are fully covered by opposite join children (in either orientation).
         *
         * @param equalPredicate candidate equality conjunct
         * @return whether the predicate can drive a hash join
         */
        public boolean isHashJoinCondition(EqualPredicate equalPredicate) {
            Set<ExprId> lhsIds = equalPredicate.left().getInputSlotExprIds();
            if (lhsIds.isEmpty()) {
                return false;
            }
            Set<ExprId> rhsIds = equalPredicate.right().getInputSlotExprIds();
            if (rhsIds.isEmpty()) {
                return false;
            }
            return (leftExprIds.containsAll(lhsIds) && rightExprIds.containsAll(rhsIds))
                    || (leftExprIds.containsAll(rhsIds) && rightExprIds.containsAll(lhsIds));
        }
    }

    /**
     * Returns false for cross joins, null-aware left anti joins and mark joins; true otherwise.
     */
    public static boolean couldShuffle(Join join) {
        JoinType joinType = join.getJoinType();
        return !(joinType.isCrossJoin() || joinType.isNullAwareLeftAntiJoin() || join.isMarkJoin());
    }

    /**
     * Returns false for joins whose type reports {@code isRightJoin()} or {@code isFullOuterJoin()};
     * true otherwise.
     */
    public static boolean couldBroadcast(Join join) {
        JoinType joinType = join.getJoinType();
        return !(joinType.isRightJoin() || joinType.isFullOuterJoin());
    }

    /**
     * Splits join conditions into hash-join conditions and the remaining conditions.
     *
     * @param leftSlots output slots of the left child
     * @param rightSlots output slots of the right child
     * @param onConditions conjuncts of the join condition
     * @return pair of (hash-join conditions, other conditions); a missing group is returned as an
     *     empty immutable list
     */
    public static Pair<List<Expression>, List<Expression>> extractExpressionForHashTable(
            List<Slot> leftSlots, List<Slot> rightSlots, List<Expression> onConditions) {
        JoinSlotCoverageChecker checker = new JoinSlotCoverageChecker(leftSlots, rightSlots);
        Map<Boolean, List<Expression>> grouped = onConditions.stream().collect(Collectors.groupingBy(
                expr -> (expr instanceof EqualPredicate) && checker.isHashJoinCondition((EqualPredicate) expr)));
        return Pair.of(
                grouped.getOrDefault(true, ImmutableList.of()),
                grouped.getOrDefault(false, ImmutableList.of()));
    }

    /**
     * Selects the conditions usable as bitmap runtime filters: a {@code BitmapContains}, or a
     * {@code Not} whose child is (after conjunction extraction) exactly one {@code BitmapContains},
     * where the slots of its second argument all come from the left child and the slots of its
     * first argument all come from the right child.
     *
     * @param leftSlots output slots of the left child
     * @param rightSlots output slots of the right child
     * @param onConditions conjuncts of the join condition
     * @return the qualifying conditions (the original expressions, including any {@code Not})
     */
    public static List<Expression> extractBitmapRuntimeFilterConditions(
            List<Slot> leftSlots, List<Slot> rightSlots, List<Expression> onConditions) {
        List<Expression> result = Lists.newArrayList();
        for (Expression expr : onConditions) {
            BitmapContains bitmapContains = null;
            if (expr instanceof Not) {
                List<Expression> negated = ExpressionUtils.extractConjunction(expr.child(0));
                if (negated.size() == 1 && negated.get(0) instanceof BitmapContains) {
                    bitmapContains = (BitmapContains) negated.get(0);
                }
            } else if (expr instanceof BitmapContains) {
                bitmapContains = (BitmapContains) expr;
            }
            if (bitmapContains == null) {
                continue;
            }
            // child(1) must be fed by the left side, child(0) by the right side.
            if (leftSlots.containsAll(bitmapContains.child(1).collect(Slot.class::isInstance))
                    && rightSlots.containsAll(bitmapContains.child(0).collect(Slot.class::isInstance))) {
                result.add(expr);
            }
        }
        return result;
    }

    /** Returns true when the join has no hash-join conjuncts. */
    public static boolean shouldNestedLoopJoin(Join join) {
        return join.getHashJoinConjuncts().isEmpty();
    }

    /**
     * Returns true when {@code hashConjuncts} is empty. {@code joinType} is currently not consulted.
     */
    public static boolean shouldNestedLoopJoin(JoinType joinType, List<Expression> hashConjuncts) {
        return hashConjuncts.isEmpty();
    }

    /**
     * Orients an equality so that its left side's input slots are contained in {@code leftSlots}.
     *
     * @return {@code equalPredicate} unchanged if already oriented, otherwise its commuted form
     */
    public static EqualPredicate swapEqualToForChildrenOrder(EqualPredicate equalPredicate, Set<Slot> leftSlots) {
        return leftSlots.containsAll(equalPredicate.left().getInputSlots())
                ? equalPredicate
                : equalPredicate.commute();
    }

    /**
     * Returns true when the right child's distribution is a hash distribution with shuffle type
     * {@code STORAGE_BUCKETED}.
     */
    public static boolean shouldBucketShuffleJoin(AbstractPhysicalJoin<PhysicalPlan, PhysicalPlan> join) {
        DistributionSpec rightSpec = join.right().getPhysicalProperties().getDistributionSpec();
        return rightSpec instanceof DistributionSpecHash
                && ((DistributionSpecHash) rightSpec).getShuffleType() == DistributionSpecHash.ShuffleType.STORAGE_BUCKETED;
    }

    /**
     * Returns true when the right child is a {@code PhysicalDistribute} with a replicated
     * distribution spec.
     */
    public static boolean shouldBroadcastJoin(AbstractPhysicalJoin<PhysicalPlan, PhysicalPlan> join) {
        PhysicalPlan right = join.right();
        return right instanceof PhysicalDistribute
                && ((PhysicalDistribute) right).getDistributionSpec() instanceof DistributionSpecReplicated;
    }

    /**
     * Returns true when colocate planning is enabled for the current session, both children have
     * hash distributions, and {@link #couldColocateJoin} accepts them.
     */
    public static boolean shouldColocateJoin(AbstractPhysicalJoin<PhysicalPlan, PhysicalPlan> join) {
        ConnectContext ctx = ConnectContext.get();
        if (ctx == null || ctx.getSessionVariable().isDisableColocatePlan()) {
            return false;
        }
        DistributionSpec leftSpec = join.left().getPhysicalProperties().getDistributionSpec();
        DistributionSpec rightSpec = join.right().getPhysicalProperties().getDistributionSpec();
        if (leftSpec instanceof DistributionSpecHash && rightSpec instanceof DistributionSpecHash) {
            return couldColocateJoin((DistributionSpecHash) leftSpec, (DistributionSpecHash) rightSpec);
        }
        return false;
    }

    /**
     * Decides whether two NATURAL hash distributions can be joined colocated.
     *
     * <p>Returns false when there is no session, colocate planning is disabled, or either side is
     * not NATURAL. Otherwise returns true when both sides read the same table with the same
     * selected index (neither being -1) and identical partition sets of size at most one, or when
     * the two tables are in the same, stable colocate group.
     */
    public static boolean couldColocateJoin(DistributionSpecHash leftHashSpec, DistributionSpecHash rightHashSpec) {
        ConnectContext ctx = ConnectContext.get();
        if (ctx == null || ctx.getSessionVariable().isDisableColocatePlan()) {
            return false;
        }
        if (leftHashSpec.getShuffleType() != DistributionSpecHash.ShuffleType.NATURAL
                || rightHashSpec.getShuffleType() != DistributionSpecHash.ShuffleType.NATURAL) {
            return false;
        }
        long leftTableId = leftHashSpec.getTableId();
        long rightTableId = rightHashSpec.getTableId();
        long leftIndexId = leftHashSpec.getSelectedIndexId();
        long rightIndexId = rightHashSpec.getSelectedIndexId();
        Set<Long> leftPartitionIds = leftHashSpec.getPartitionIds();
        // NOTE(review): -1 presumably means "no index selected" — confirm in DistributionSpecHash.
        boolean sameSelectedIndex = leftTableId == rightTableId
                && leftIndexId != -1L && rightIndexId != -1L
                && leftIndexId == rightIndexId;
        if (sameSelectedIndex
                && leftPartitionIds.equals(rightHashSpec.getPartitionIds())
                && leftPartitionIds.size() <= 1) {
            return true;
        }
        // Only consult the colocate index when the fast path did not already decide.
        ColocateTableIndex colocateIndex = Env.getCurrentColocateIndex();
        return colocateIndex.isSameGroup(leftTableId, rightTableId)
                && !colocateIndex.isGroupUnstable(colocateIndex.getGroup(leftTableId));
    }

    /** Returns a new mutable set holding the output expression ids of both plans. */
    public static Set<ExprId> getJoinOutputExprIdSet(Plan left, Plan right) {
        Set<ExprId> exprIds = new HashSet<>(left.getOutputExprIdSet());
        exprIds.addAll(right.getOutputExprIdSet());
        return exprIds;
    }

    /** Returns an immutable copy of {@code slots} with each slot's nullability set to {@code nullable}. */
    private static List<Slot> applyNullable(List<Slot> slots, boolean nullable) {
        return slots.stream()
                .map(slot -> slot.withNullable(nullable))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Computes the output slots of a join of {@code left} and {@code right}.
     *
     * <p>Left semi/anti (including null-aware left anti) joins output the left slots; right
     * semi/anti joins output the right slots; outer joins output both sides with the outer-joined
     * side(s) made nullable; every other join type outputs left then right slots unchanged.
     *
     * @return an immutable list of slots
     */
    public static List<Slot> getJoinOutput(JoinType joinType, Plan left, Plan right) {
        switch (joinType) {
            case LEFT_SEMI_JOIN:
            case LEFT_ANTI_JOIN:
            case NULL_AWARE_LEFT_ANTI_JOIN:
                return ImmutableList.copyOf(left.getOutput());
            case RIGHT_SEMI_JOIN:
            case RIGHT_ANTI_JOIN:
                return ImmutableList.copyOf(right.getOutput());
            case LEFT_OUTER_JOIN:
                return ImmutableList.<Slot>builder()
                        .addAll(left.getOutput())
                        .addAll(applyNullable(right.getOutput(), true))
                        .build();
            case RIGHT_OUTER_JOIN:
                return ImmutableList.<Slot>builder()
                        .addAll(applyNullable(left.getOutput(), true))
                        .addAll(right.getOutput())
                        .build();
            case FULL_OUTER_JOIN:
                return ImmutableList.<Slot>builder()
                        .addAll(applyNullable(left.getOutput(), true))
                        .addAll(applyNullable(right.getOutput(), true))
                        .build();
            default:
                return ImmutableList.<Slot>builder()
                        .addAll(left.getOutput())
                        .addAll(right.getOutput())
                        .build();
        }
    }
}
