package org.apache.doris.nereids.rules.exploration;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.PushdownExpressionsInHashCondition;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.AbstractPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.qe.ConnectContext;

/* loaded from: input_file:org/apache/doris/nereids/rules/exploration/OrExpansion.class */
public class OrExpansion extends OneExplorationRuleFactory {
    public static final OrExpansion INSTANCE = new OrExpansion();
    public static final ImmutableSet<JoinType> supportJoinType = new ImmutableSet.Builder().add(JoinType.INNER_JOIN).add(JoinType.LEFT_ANTI_JOIN).add(JoinType.LEFT_OUTER_JOIN).add(JoinType.FULL_OUTER_JOIN).build();

    @Override // org.apache.doris.nereids.rules.OneRuleFactory
    public Rule build() {
        return logicalJoin().when((v0) -> {
            return JoinUtils.shouldNestedLoopJoin(v0);
        }).when(logicalJoin -> {
            return supportJoinType.contains(logicalJoin.getJoinType()) && ConnectContext.get().getSessionVariable().getEnablePipelineEngine();
        }).thenApply(matchingContext -> {
            LogicalJoin<? extends Plan, ? extends Plan> logicalJoin2 = (LogicalJoin) matchingContext.root;
            Preconditions.checkArgument(logicalJoin2.getHashJoinConjuncts().isEmpty(), "Only Expansion nest loop join without hashCond");
            Pair<List<Expression>, List<Expression>> splitOrCondition = splitOrCondition(logicalJoin2);
            if (splitOrCondition == null) {
                return logicalJoin2;
            }
            LogicalCTEProducer<? extends Plan> logicalCTEProducer = new LogicalCTEProducer<>(matchingContext.statementContext.getNextCTEId(), logicalJoin2.left());
            LogicalCTEProducer<? extends Plan> logicalCTEProducer2 = new LogicalCTEProducer<>(matchingContext.statementContext.getNextCTEId(), logicalJoin2.right());
            ArrayList arrayList = new ArrayList();
            if (logicalJoin2.getJoinType().isInnerJoin()) {
                arrayList.addAll(expandInnerJoin(matchingContext.cascadesContext, splitOrCondition, logicalJoin2, logicalCTEProducer, logicalCTEProducer2));
            } else if (logicalJoin2.getJoinType().isOuterJoin()) {
                arrayList.addAll(expandInnerJoin(matchingContext.cascadesContext, splitOrCondition, logicalJoin2, logicalCTEProducer, logicalCTEProducer2));
                arrayList.add(expandLeftAntiJoin(matchingContext.cascadesContext, splitOrCondition, logicalJoin2, logicalCTEProducer, logicalCTEProducer2));
                if (logicalJoin2.getJoinType().equals(JoinType.FULL_OUTER_JOIN)) {
                    arrayList.add(expandLeftAntiJoin(matchingContext.cascadesContext, splitOrCondition, logicalJoin2, logicalCTEProducer2, logicalCTEProducer));
                }
            } else {
                if (!logicalJoin2.getJoinType().equals(JoinType.LEFT_ANTI_JOIN)) {
                    throw new RuntimeException("or-expansion is not supported for " + logicalJoin2);
                }
                arrayList.add(expandLeftAntiJoin(matchingContext.cascadesContext, splitOrCondition, logicalJoin2, logicalCTEProducer, logicalCTEProducer2));
            }
            return new LogicalCTEAnchor(logicalCTEProducer.getCteId(), logicalCTEProducer, new LogicalCTEAnchor(logicalCTEProducer2.getCteId(), logicalCTEProducer2, new LogicalUnion(SetOperation.Qualifier.ALL, new ArrayList(logicalJoin2.getOutput()), (List) arrayList.stream().map(plan -> {
                Stream<Slot> stream = plan.getOutput().stream();
                Class<SlotReference> cls = SlotReference.class;
                SlotReference.class.getClass();
                return (ImmutableList) stream.map((v1) -> {
                    return r1.cast(v1);
                }).collect(ImmutableList.toImmutableList());
            }).collect(ImmutableList.toImmutableList()), ImmutableList.of(), false, arrayList)));
        }).toRule(RuleType.OR_EXPANSION);
    }

    @Nullable
    private Pair<List<Expression>, List<Expression>> splitOrCondition(LogicalJoin<? extends Plan, ? extends Plan> logicalJoin) {
        ArrayList<Expression> arrayList = new ArrayList(logicalJoin.getOtherJoinConjuncts());
        for (Expression expression : arrayList) {
            List<Expression> extractDisjunction = ExpressionUtils.extractDisjunction(expression);
            if (((List) JoinUtils.extractExpressionForHashTable(logicalJoin.left().getOutput(), logicalJoin.right().getOutput(), extractDisjunction).second).isEmpty()) {
                arrayList.remove(expression);
                return Pair.of(extractDisjunction, arrayList);
            }
        }
        return null;
    }

    private Plan expandLeftAntiJoin(CascadesContext cascadesContext, Pair<List<Expression>, List<Expression>> pair, LogicalJoin<? extends Plan, ? extends Plan> logicalJoin, LogicalCTEProducer<? extends Plan> logicalCTEProducer, LogicalCTEProducer<? extends Plan> logicalCTEProducer2) {
        LogicalCTEConsumer logicalCTEConsumer = new LogicalCTEConsumer(cascadesContext.getStatementContext().getNextRelationId(), logicalCTEProducer.getCteId(), "", logicalCTEProducer);
        LogicalCTEConsumer logicalCTEConsumer2 = new LogicalCTEConsumer(cascadesContext.getStatementContext().getNextRelationId(), logicalCTEProducer2.getCteId(), "", logicalCTEProducer2);
        cascadesContext.putCTEIdToConsumer(logicalCTEConsumer);
        cascadesContext.putCTEIdToConsumer(logicalCTEConsumer2);
        HashMap hashMap = new HashMap(logicalCTEConsumer.getProducerToConsumerOutputMap());
        hashMap.putAll(logicalCTEConsumer2.getProducerToConsumerOutputMap());
        List list = (List) pair.first;
        List list2 = (List) ((List) pair.second).stream().map(expression -> {
            return expression.rewriteUp(expression -> {
                return hashMap.containsKey(expression) ? (Expression) hashMap.get(expression) : expression;
            });
        }).collect(Collectors.toList());
        Expression rewriteUp = ((Expression) list.get(0)).rewriteUp(expression2 -> {
            return hashMap.containsKey(expression2) ? (Expression) hashMap.get(expression2) : expression2;
        });
        AbstractPlan logicalJoin2 = new LogicalJoin(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(new Expression[]{rewriteUp}), list2, logicalJoin.getHint(), logicalJoin.getMarkJoinSlotReference(), logicalCTEConsumer, logicalCTEConsumer2);
        if (rewriteUp.children().stream().anyMatch(expression3 -> {
            return !(expression3 instanceof Slot);
        })) {
            logicalJoin2 = new LogicalProject(new ArrayList(logicalJoin2.getOutput()), PushdownExpressionsInHashCondition.pushDownHashExpression((LogicalJoin) logicalJoin2));
        }
        for (int i = 1; i < list.size(); i++) {
            Expression expression4 = (Expression) list.get(i);
            LogicalCTEConsumer logicalCTEConsumer3 = new LogicalCTEConsumer(cascadesContext.getStatementContext().getNextRelationId(), logicalCTEProducer2.getCteId(), "", logicalCTEProducer2);
            cascadesContext.putCTEIdToConsumer(logicalCTEConsumer3);
            HashMap hashMap2 = new HashMap(logicalCTEConsumer.getProducerToConsumerOutputMap());
            hashMap2.putAll(logicalCTEConsumer3.getProducerToConsumerOutputMap());
            Expression rewriteUp2 = expression4.rewriteUp(expression5 -> {
                return hashMap2.containsKey(expression5) ? (Expression) hashMap2.get(expression5) : expression5;
            });
            logicalJoin2 = new LogicalJoin(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(new Expression[]{rewriteUp2}), new ArrayList(), logicalJoin.getHint(), logicalJoin.getMarkJoinSlotReference(), logicalJoin2, logicalCTEConsumer3);
            if (rewriteUp2.children().stream().anyMatch(expression6 -> {
                return !(expression6 instanceof Slot);
            })) {
                logicalJoin2 = PushdownExpressionsInHashCondition.pushDownHashExpression((LogicalJoin) logicalJoin2);
            }
        }
        AbstractPlan abstractPlan = logicalJoin2;
        Stream<Slot> stream = logicalJoin.getOutput().stream();
        hashMap.getClass();
        return new LogicalProject((List) stream.map((v1) -> {
            return r1.get(v1);
        }).map(slot -> {
            return abstractPlan.getOutputSet().contains(slot) ? slot : new Alias(new NullLiteral(slot.getDataType()), slot.getName());
        }).collect(Collectors.toList()), logicalJoin2);
    }

    private List<Plan> expandInnerJoin(CascadesContext cascadesContext, Pair<List<Expression>, List<Expression>> pair, LogicalJoin<? extends Plan, ? extends Plan> logicalJoin, LogicalCTEProducer<? extends Plan> logicalCTEProducer, LogicalCTEProducer<? extends Plan> logicalCTEProducer2) {
        List<Expression> list = (List) pair.first;
        List list2 = (List) pair.second;
        List<Expression> list3 = (List) list.stream().map(Not::new).collect(Collectors.toList());
        ArrayList newArrayList = Lists.newArrayList();
        for (int i = 0; i < list.size(); i++) {
            Pair<List<Expression>, List<Expression>> extractHashAndOtherConditions = extractHashAndOtherConditions(i, list, list3);
            ((List) extractHashAndOtherConditions.second).addAll(list2);
            LogicalCTEConsumer logicalCTEConsumer = new LogicalCTEConsumer(cascadesContext.getStatementContext().getNextRelationId(), logicalCTEProducer.getCteId(), "", logicalCTEProducer);
            LogicalCTEConsumer logicalCTEConsumer2 = new LogicalCTEConsumer(cascadesContext.getStatementContext().getNextRelationId(), logicalCTEProducer2.getCteId(), "", logicalCTEProducer2);
            cascadesContext.putCTEIdToConsumer(logicalCTEConsumer);
            cascadesContext.putCTEIdToConsumer(logicalCTEConsumer2);
            HashMap hashMap = new HashMap(logicalCTEConsumer.getProducerToConsumerOutputMap());
            hashMap.putAll(logicalCTEConsumer2.getProducerToConsumerOutputMap());
            LogicalJoin logicalJoin2 = new LogicalJoin(JoinType.INNER_JOIN, (List) ((List) extractHashAndOtherConditions.first).stream().map(expression -> {
                return expression.rewriteUp(expression -> {
                    return hashMap.containsKey(expression) ? (Expression) hashMap.get(expression) : expression;
                });
            }).collect(Collectors.toList()), (List) ((List) extractHashAndOtherConditions.second).stream().map(expression2 -> {
                return expression2.rewriteUp(expression2 -> {
                    return hashMap.containsKey(expression2) ? (Expression) hashMap.get(expression2) : expression2;
                });
            }).collect(Collectors.toList()), logicalJoin.getHint(), logicalJoin.getMarkJoinSlotReference(), logicalCTEConsumer, logicalCTEConsumer2);
            if (logicalJoin2.getHashJoinConjuncts().stream().anyMatch(expression3 -> {
                return expression3.children().stream().anyMatch(expression3 -> {
                    return !(expression3 instanceof Slot);
                });
            })) {
                newArrayList.add(new LogicalProject(new ArrayList(logicalJoin2.getOutput()), PushdownExpressionsInHashCondition.pushDownHashExpression(logicalJoin2)));
            } else {
                newArrayList.add(logicalJoin2);
            }
        }
        return newArrayList;
    }

    private Pair<List<Expression>, List<Expression>> extractHashAndOtherConditions(int i, List<Expression> list, List<Expression> list2) {
        ArrayList arrayList = new ArrayList();
        for (int i2 = 0; i2 < i; i2++) {
            arrayList.add(list2.get(i2));
        }
        return Pair.of(Lists.newArrayList(new Expression[]{list.get(i)}), arrayList);
    }
}
