package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;

/**
 * Plan rewriter that re-synchronizes the nullable flag of {@link SlotReference}s with the slots that
 * produce them.
 *
 * <p>The context map is keyed by {@link ExprId} and holds the latest version of every slot produced so
 * far. The overridden visitors follow a common pattern:
 * <ol>
 *   <li>rewrite the children first (via {@code super.visit});</li>
 *   <li>rewrite the node's own expressions through {@link SlotReferenceReplacer}, which copies the
 *       recorded nullable flag onto matching slot references;</li>
 *   <li>where the node defines new outputs, record them in the map for the ancestors.</li>
 * </ol>
 */
public class AdjustNullable extends DefaultPlanRewriter<Map<ExprId, Slot>> implements CustomRewriter {

    /**
     * Expression rewriter that replaces every {@link SlotReference} whose {@link ExprId} is present in the
     * context map with a copy carrying the recorded slot's nullable flag. References that are not in the
     * map are returned unchanged.
     */
    public static class SlotReferenceReplacer extends DefaultExpressionRewriter<Map<ExprId, Slot>> {
        public static final SlotReferenceReplacer INSTANCE = new SlotReferenceReplacer();

        private SlotReferenceReplacer() {
        }

        @Override
        public Expression visitSlotReference(SlotReference slotReference, Map<ExprId, Slot> replaceMap) {
            Slot producer = replaceMap.get(slotReference.getExprId());
            if (producer == null) {
                return slotReference;
            }
            return slotReference.withNullable(producer.nullable());
        }
    }

    /**
     * Entry point: rewrites the whole plan, starting with an empty slot map.
     *
     * @param plan the root of the plan to rewrite
     * @param jobContext unused by this rule
     * @return the rewritten plan
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
        return plan.accept(this, Maps.<ExprId, Slot>newHashMap());
    }

    /**
     * Fallback for nodes without a dedicated visitor.
     *
     * <p>It rewrites the children, recomputes this node's logical properties, and records every output
     * slot so that ancestors see the up-to-date nullability.
     */
    @Override
    public Plan visit(Plan plan, Map<ExprId, Slot> replaceMap) {
        Plan rewritten = ((LogicalPlan) super.visit(plan, replaceMap)).recomputeLogicalProperties();
        rewritten.getOutputSet().forEach(slot -> replaceMap.put(slot.getExprId(), slot));
        return rewritten;
    }

    /** Rewrites group-by keys and output expressions, then records the aggregate's outputs. */
    @Override
    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Map<ExprId, Slot> replaceMap) {
        LogicalAggregate<? extends Plan> newAggregate =
                (LogicalAggregate<? extends Plan>) super.visit(aggregate, replaceMap);
        List<NamedExpression> outputs = updateExpressions(newAggregate.getOutputExpressions(), replaceMap);
        List<Expression> groupBy = updateExpressions(newAggregate.getGroupByExpressions(), replaceMap);
        outputs.forEach(output -> replaceMap.put(output.getExprId(), output.toSlot()));
        return newAggregate.withGroupByAndOutput(groupBy, outputs);
    }

    /** Rewrites the filter's conjuncts; a filter defines no new slots, so nothing is recorded here. */
    @Override
    public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, Map<ExprId, Slot> replaceMap) {
        LogicalFilter<? extends Plan> newFilter = (LogicalFilter<? extends Plan>) super.visit(filter, replaceMap);
        Set<Expression> conjuncts = updateExpressions(newFilter.getConjuncts(), replaceMap);
        return newFilter.withConjuncts(conjuncts).recomputeLogicalProperties();
    }

    /** Rewrites the generator functions, recomputes properties, and records every output slot. */
    @Override
    public Plan visitLogicalGenerate(LogicalGenerate<? extends Plan> generate, Map<ExprId, Slot> replaceMap) {
        LogicalGenerate<? extends Plan> newGenerate =
                (LogicalGenerate<? extends Plan>) super.visit(generate, replaceMap);
        Plan rewritten = newGenerate.withGenerators(updateExpressions(newGenerate.getGenerators(), replaceMap))
                .recomputeLogicalProperties();
        rewritten.getOutputSet().forEach(slot -> replaceMap.put(slot.getExprId(), slot));
        return rewritten;
    }

    /**
     * Rewrites hash and other join conjuncts.
     *
     * <p>The two kinds of conjunct are rewritten against different slot versions: hash conjuncts against
     * the children's outputs, other conjuncts against the join's own output set.
     */
    @Override
    public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Map<ExprId, Slot> replaceMap) {
        LogicalJoin<? extends Plan, ? extends Plan> newJoin =
                (LogicalJoin<? extends Plan, ? extends Plan>) super.visit(join, replaceMap);
        List<Expression> hashConjuncts = updateExpressions(newJoin.getHashJoinConjuncts(), replaceMap);
        // Order matters: record the join's output before rewriting the other conjuncts, so that they pick
        // up the join-level nullability instead of the children's.
        newJoin.getOutputSet().forEach(slot -> replaceMap.put(slot.getExprId(), slot));
        List<Expression> otherConjuncts = updateExpressions(newJoin.getOtherJoinConjuncts(), replaceMap);
        return newJoin.withJoinConjuncts(hashConjuncts, otherConjuncts).recomputeLogicalProperties();
    }

    /** Rewrites the projection list and records each projected slot. */
    @Override
    public Plan visitLogicalProject(LogicalProject<? extends Plan> project, Map<ExprId, Slot> replaceMap) {
        LogicalProject<? extends Plan> newProject =
                (LogicalProject<? extends Plan>) super.visit(project, replaceMap);
        List<NamedExpression> projects = updateExpressions(newProject.getProjects(), replaceMap);
        projects.forEach(p -> replaceMap.put(p.getExprId(), p.toSlot()));
        return newProject.withProjects(projects);
    }

    /**
     * Rewrites repeat outputs and records them.
     *
     * <p>Outputs that literally appear in a grouping set are kept exactly as they are; all other outputs
     * are rewritten.
     */
    @Override
    public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, Map<ExprId, Slot> replaceMap) {
        LogicalRepeat<? extends Plan> newRepeat = (LogicalRepeat<? extends Plan>) super.visit(repeat, replaceMap);
        Set<Expression> groupingSetExprs =
                ImmutableSet.copyOf(ExpressionUtils.flatExpressions(newRepeat.getGroupingSets()));
        List<NamedExpression> outputs = Lists.newArrayList();
        for (NamedExpression output : newRepeat.getOutputExpressions()) {
            outputs.add(groupingSetExprs.contains(output) ? output : updateExpression(output, replaceMap));
        }
        outputs.forEach(output -> replaceMap.put(output.getExprId(), output.toSlot()));
        return newRepeat.withGroupSetsAndOutput(newRepeat.getGroupingSets(), outputs).recomputeLogicalProperties();
    }

    /**
     * Recomputes set-operation output nullability.
     *
     * <p>Output column {@code i} becomes nullable when the matching column of any child (or, for a union,
     * any constant row) is nullable. Each child's regular output list is refreshed to the child's current
     * slots, and the resulting outputs are recorded.
     */
    @Override
    public Plan visitLogicalSetOperation(LogicalSetOperation setOperation, Map<ExprId, Slot> replaceMap) {
        LogicalSetOperation newSetOperation = (LogicalSetOperation) super.visit(setOperation, replaceMap);
        if (newSetOperation.children().isEmpty()) {
            return newSetOperation;
        }
        // Seeded from child 0 and updated in place with set(), so the list must be mutable
        // (Collectors.toList() does not guarantee that).
        List<Boolean> outputNullable = newSetOperation.child(0).getOutput().stream()
                .map(Slot::nullable)
                .collect(Collectors.toCollection(ArrayList::new));
        ImmutableList.Builder<List<SlotReference>> newChildrenOutputs = ImmutableList.builder();
        for (int childIndex = 0; childIndex < newSetOperation.arity(); childIndex++) {
            List<Slot> childOutput = newSetOperation.child(childIndex).getOutput();
            List<SlotReference> regularChildOutput = newSetOperation.getRegularChildOutput(childIndex);
            ImmutableList.Builder<SlotReference> newChildOutput = ImmutableList.builder();
            for (int column = 0; column < regularChildOutput.size(); column++) {
                ExprId wanted = regularChildOutput.get(column).getExprId();
                for (Slot slot : childOutput) {
                    if (slot.getExprId().equals(wanted)) {
                        outputNullable.set(column, slot.nullable() || outputNullable.get(column));
                        newChildOutput.add((SlotReference) slot);
                        // Take the first match only; without this the scan would never terminate cleanly.
                        break;
                    }
                }
            }
            newChildrenOutputs.add(newChildOutput.build());
        }
        if (newSetOperation instanceof LogicalUnion) {
            // Constant rows of a union also feed the output columns.
            for (List<NamedExpression> constantRow : ((LogicalUnion) newSetOperation).getConstantExprsList()) {
                for (int column = 0; column < constantRow.size(); column++) {
                    if (constantRow.get(column).nullable()) {
                        outputNullable.set(column, true);
                    }
                }
            }
        }
        List<NamedExpression> outputs = newSetOperation.getOutputs();
        List<NamedExpression> newOutputs = Lists.newArrayListWithCapacity(outputs.size());
        for (int i = 0; i < outputNullable.size(); i++) {
            NamedExpression output = outputs.get(i);
            Slot slot = output instanceof Alias ? (Slot) ((Alias) output).child() : (Slot) output;
            if (outputNullable.get(i)) {
                slot = slot.withNullable(true);
            }
            newOutputs.add(output instanceof Alias ? (NamedExpression) output.withChildren(slot) : slot);
        }
        newOutputs.forEach(output -> replaceMap.put(output.getExprId(), output.toSlot()));
        return newSetOperation.withNewOutputs(newOutputs)
                .withChildrenAndTheirOutputs(newSetOperation.children(), newChildrenOutputs.build())
                .recomputeLogicalProperties();
    }

    /** Rewrites the expression inside each order key. */
    @Override
    public Plan visitLogicalSort(LogicalSort<? extends Plan> sort, Map<ExprId, Slot> replaceMap) {
        LogicalSort<? extends Plan> newSort = (LogicalSort<? extends Plan>) super.visit(sort, replaceMap);
        return newSort.withOrderKeys(newSort.getOrderKeys().stream()
                .map(key -> key.withExpression(updateExpression(key.getExpr(), replaceMap)))
                .collect(ImmutableList.toImmutableList()))
                .recomputeLogicalProperties();
    }

    /** Rewrites the expression inside each order key. */
    @Override
    public Plan visitLogicalTopN(LogicalTopN<? extends Plan> topN, Map<ExprId, Slot> replaceMap) {
        LogicalTopN<? extends Plan> newTopN = (LogicalTopN<? extends Plan>) super.visit(topN, replaceMap);
        return newTopN.withOrderKeys(newTopN.getOrderKeys().stream()
                .map(key -> key.withExpression(updateExpression(key.getExpr(), replaceMap)))
                .collect(ImmutableList.toImmutableList()))
                .recomputeLogicalProperties();
    }

    /** Rewrites the window expressions and records their slots. */
    @Override
    public Plan visitLogicalWindow(LogicalWindow<? extends Plan> window, Map<ExprId, Slot> replaceMap) {
        LogicalWindow<? extends Plan> newWindow = (LogicalWindow<? extends Plan>) super.visit(window, replaceMap);
        List<NamedExpression> windowExpressions = updateExpressions(newWindow.getWindowExpressions(), replaceMap);
        windowExpressions.forEach(w -> replaceMap.put(w.getExprId(), w.toSlot()));
        return newWindow.withExpression(windowExpressions, newWindow.child());
    }

    /** Rewrites partition keys and order keys. */
    @Override
    public Plan visitLogicalPartitionTopN(LogicalPartitionTopN<? extends Plan> partitionTopN,
            Map<ExprId, Slot> replaceMap) {
        LogicalPartitionTopN<? extends Plan> newPartitionTopN =
                (LogicalPartitionTopN<? extends Plan>) super.visit(partitionTopN, replaceMap);
        return newPartitionTopN.withPartitionKeysAndOrderKeys(
                updateExpressions(newPartitionTopN.getPartitionKeys(), replaceMap),
                updateExpressions(newPartitionTopN.getOrderKeys(), replaceMap));
    }

    /**
     * Rebuilds both consumer/producer slot maps.
     *
     * <p>Each producer slot is refreshed from the context map, and its consumer counterpart takes the same
     * nullability. The consumer slots are then recorded for the ancestors.
     */
    @Override
    public Plan visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Map<ExprId, Slot> replaceMap) {
        Map<Slot, Slot> consumerToProducer = new LinkedHashMap<>();
        Map<Slot, Slot> producerToConsumer = new LinkedHashMap<>();
        for (Slot producerSlot : cteConsumer.getConsumerToProducerOutputMap().values()) {
            Slot newProducerSlot = updateExpression(producerSlot, replaceMap);
            Slot newConsumerSlot = cteConsumer.getProducerToConsumerOutputMap().get(producerSlot)
                    .withNullable(newProducerSlot.nullable());
            producerToConsumer.put(newProducerSlot, newConsumerSlot);
            consumerToProducer.put(newConsumerSlot, newProducerSlot);
            replaceMap.put(newConsumerSlot.getExprId(), newConsumerSlot);
        }
        return cteConsumer.withTwoMaps(consumerToProducer, producerToConsumer);
    }

    /** Applies {@link SlotReferenceReplacer} to {@code input} via {@code rewriteDownShortCircuit}. */
    @SuppressWarnings("unchecked")
    private <T extends Expression> T updateExpression(T input, Map<ExprId, Slot> replaceMap) {
        return (T) input.rewriteDownShortCircuit(e -> e.accept(SlotReferenceReplacer.INSTANCE, replaceMap));
    }

    /** Applies {@link #updateExpression} to every element, preserving order; returns an immutable list. */
    private <T extends Expression> List<T> updateExpressions(List<T> inputs, Map<ExprId, Slot> replaceMap) {
        return inputs.stream()
                .map(input -> updateExpression(input, replaceMap))
                .collect(ImmutableList.toImmutableList());
    }

    /** Applies {@link #updateExpression} to every element; returns an immutable set. */
    private <T extends Expression> Set<T> updateExpressions(Set<T> inputs, Map<ExprId, Slot> replaceMap) {
        return inputs.stream()
                .map(input -> updateExpression(input, replaceMap))
                .collect(ImmutableSet.toImmutableSet());
    }
}
