package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSink;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;

/**
 * Column pruning rewriter.
 *
 * <p>Walks the plan top-down, carrying in {@link PruneContext} the set of slots the parent requires from the
 * node being visited. A node implementing {@link OutputPrunable} drops every output expression whose slot is
 * not required (keeping one column chosen by {@link ExpressionUtils#selectMinimumColumn} if nothing is
 * required). Each node then asks every child only for the child-output slots it still needs: the slots its
 * own expressions reference plus, for non-prunable nodes, the slots required by its parent. When a pruned
 * child still produces slots that were not asked for and its parent is not a {@link LogicalProject}, a
 * {@link LogicalProject} selecting exactly the requested slots is placed on top of that child.
 *
 * <p>EXCEPT, INTERSECT, DISTINCT UNION, sinks and CTE producers keep their own output and require every
 * output slot of their direct children. A {@link LogicalCTEProducer} reached as a child is left untouched.
 */
public class ColumnPruning extends DefaultPlanRewriter<ColumnPruning.PruneContext> implements CustomRewriter {

    /** State handed from a parent plan to the child being pruned. */
    public static class PruneContext {
        /** Slots the parent requires from the plan being visited. */
        public Set<Slot> requiredSlots;
        /** The plan whose child is being visited; empty when visiting the root. */
        public Optional<Plan> parent;

        /**
         * Create a prune context.
         *
         * @param requiredSlots slots the parent requires from the visited plan
         * @param parent the parent plan, or {@code null} for the root
         */
        public PruneContext(Set<Slot> requiredSlots, Plan parent) {
            this.requiredSlots = requiredSlots;
            this.parent = Optional.ofNullable(parent);
        }
    }

    /** Prune the whole tree, requiring every output slot of the root. */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
        return plan.accept(this, new PruneContext(plan.getOutputSet(), null));
    }

    /**
     * Default handling for every plan without a dedicated visitor.
     *
     * <p>An {@link OutputPrunable} plan first drops its unrequired outputs and then asks its children only
     * for the slots its remaining expressions use. Any other plan keeps its output and forwards the parent's
     * required slots (together with the slots it uses itself) to its children.
     */
    @Override
    public Plan visit(Plan plan, PruneContext context) {
        if (plan instanceof OutputPrunable) {
            OutputPrunable outputPrunable = (OutputPrunable) plan;
            Plan prunedPlan = pruneOutput(plan, outputPrunable.getOutputs(), outputPrunable::pruneOutputs, context);
            return pruneChildren(prunedPlan);
        }
        return pruneChildren(plan, context.requiredSlots);
    }

    /**
     * UNION ALL prunes its output columns and, for every child, keeps only the regular child outputs located
     * at the positions of the surviving union outputs. A DISTINCT union keeps all of its columns (removing a
     * column could change which rows are de-duplicated) and requires all outputs of its children.
     */
    @Override
    public Plan visitLogicalUnion(LogicalUnion union, PruneContext context) {
        if (union.getQualifier() == SetOperation.Qualifier.DISTINCT) {
            return skipPruneThisAndFirstLevelChildren(union);
        }
        LogicalUnion prunedUnion = (LogicalUnion) pruneOutput(union, union.getOutputs(), union::pruneOutputs, context);

        // positions, in the original union output, of the columns that survived pruning
        List<Slot> originOutput = union.getOutput();
        Set<Slot> prunedOutputSet = prunedUnion.getOutputSet();
        List<Integer> keptIndexes = IntStream.range(0, originOutput.size())
                .filter(index -> prunedOutputSet.contains(originOutput.get(index)))
                .boxed()
                .collect(ImmutableList.toImmutableList());

        ImmutableList.Builder<Plan> prunedChildren = ImmutableList.builder();
        ImmutableList.Builder<List<SlotReference>> prunedChildrenOutputs = ImmutableList.builder();
        for (int i = 0; i < prunedUnion.arity(); i++) {
            List<SlotReference> regularChildOutput = prunedUnion.getRegularChildOutput(i);
            // pick the child's output at the same positions as the kept union outputs
            List<SlotReference> prunedChildOutput = keptIndexes.stream()
                    .map(regularChildOutput::get)
                    .collect(ImmutableList.toImmutableList());
            Plan prunedChild = doPruneChild(
                    prunedUnion, prunedUnion.child(i), ImmutableSet.<Slot>copyOf(prunedChildOutput));
            prunedChildrenOutputs.add(prunedChildOutput);
            prunedChildren.add(prunedChild);
        }
        return prunedUnion.withChildrenAndTheirOutputs(prunedChildren.build(), prunedChildrenOutputs.build());
    }

    /** EXCEPT is not pruned; its children must keep all of their outputs. */
    @Override
    public Plan visitLogicalExcept(LogicalExcept except, PruneContext context) {
        return skipPruneThisAndFirstLevelChildren(except);
    }

    /** INTERSECT is not pruned; its children must keep all of their outputs. */
    @Override
    public Plan visitLogicalIntersect(LogicalIntersect intersect, PruneContext context) {
        return skipPruneThisAndFirstLevelChildren(intersect);
    }

    /** A sink is not pruned; its children must keep all of their outputs. */
    @Override
    public Plan visitLogicalSink(LogicalSink<? extends Plan> logicalSink, PruneContext context) {
        return skipPruneThisAndFirstLevelChildren(logicalSink);
    }

    /** Prune an aggregate's outputs, then its children. */
    @Override
    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, PruneContext context) {
        return pruneAggregate(aggregate, context);
    }

    /** Prune a repeat (handled like an aggregate), then its children. */
    @Override
    public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, PruneContext context) {
        return pruneAggregate(repeat, context);
    }

    /** A CTE producer is not pruned; its children must keep all of their outputs. */
    @Override
    public Plan visitLogicalCTEProducer(LogicalCTEProducer<? extends Plan> cteProducer, PruneContext context) {
        return skipPruneThisAndFirstLevelChildren(cteProducer);
    }

    /**
     * Prune the outputs of an aggregate-like plan. Any group-by expression that pruning removed from the
     * output is added back in front of the remaining outputs. The children are pruned afterwards.
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // Aggregate is used raw, as in the visitor signatures above
    private Plan pruneAggregate(Aggregate aggregate, PruneContext context) {
        Aggregate prunedAggregate = (Aggregate) pruneOutput(
                aggregate, aggregate.getOutputs(), aggregate::pruneOutputs, context);

        List<Expression> groupByExpressions = prunedAggregate.getGroupByExpressions();
        List<NamedExpression> outputExpressions = prunedAggregate.getOutputExpressions();
        Aggregate filledUpAggregate = fillUpGroupByToOutput(groupByExpressions, outputExpressions)
                .map(fullOutput -> (Aggregate) prunedAggregate.withAggOutput(fullOutput))
                .orElse(prunedAggregate);
        return pruneChildren(filledUpAggregate);
    }

    /** Keep this plan as is and require every output slot of each direct child. */
    private Plan skipPruneThisAndFirstLevelChildren(Plan plan) {
        Set<Slot> allChildrenOutput = plan.children().stream()
                .flatMap(child -> child.getOutputSet().stream())
                .collect(Collectors.toSet());
        return pruneChildren(plan, allChildrenOutput);
    }

    /**
     * Ensure every group-by expression appears in the output.
     *
     * @param groupBy group-by expressions of the aggregate
     * @param output current output expressions of the aggregate
     * @return empty if {@code output} already contains every group-by expression; otherwise the group-by
     *         expressions followed by the outputs that are not group-by expressions
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static Optional<List<NamedExpression>> fillUpGroupByToOutput(
            List<Expression> groupBy, List<NamedExpression> output) {
        if (output.containsAll(groupBy)) {
            return Optional.empty();
        }
        List<NamedExpression> nonGroupByOutput = Lists.newArrayList(output);
        nonGroupByOutput.removeAll(groupBy);
        // NOTE(review): group-by expressions are copied into the list unchecked; this presumably relies on
        // them being NamedExpressions (e.g. slots) at this point - confirm.
        return Optional.of(ImmutableList.<NamedExpression>builder()
                .addAll((List) groupBy)
                .addAll(nonGroupByOutput)
                .build());
    }

    /**
     * Keep only the outputs whose slot is required by the parent.
     *
     * @param plan the plan being pruned
     * @param originOutput the plan's current outputs
     * @param withPrunedOutput builds a copy of {@code plan} with the given outputs
     * @param context carries the slots required by the parent
     * @return {@code plan} itself when nothing changes, otherwise the result of {@code withPrunedOutput}
     */
    public static <P extends Plan> P pruneOutput(P plan, List<NamedExpression> originOutput,
            Function<List<NamedExpression>, P> withPrunedOutput, PruneContext context) {
        if (originOutput.isEmpty()) {
            // nothing to prune, and no column to fall back to below
            return plan;
        }
        List<NamedExpression> prunedOutput = originOutput.stream()
                .filter(output -> context.requiredSlots.contains(output.toSlot()))
                .collect(ImmutableList.toImmutableList());
        if (prunedOutput.isEmpty()) {
            // never prune to zero columns: keep the one chosen by selectMinimumColumn
            prunedOutput = ImmutableList.of(ExpressionUtils.selectMinimumColumn(originOutput));
        }
        return prunedOutput.equals(originOutput) ? plan : withPrunedOutput.apply(prunedOutput);
    }

    /** Prune the children, requiring only the slots this plan's own expressions use. */
    private <P extends Plan> P pruneChildren(P plan) {
        return pruneChildren(plan, ImmutableSet.of());
    }

    /**
     * Prune every child of {@code plan}, requiring from each child the slots of its output that are either
     * in {@code parentRequiredSlots} or used by {@code plan}'s own expressions.
     *
     * @return {@code plan} itself if no child changed, otherwise a copy with the pruned children
     */
    @SuppressWarnings("unchecked")
    private <P extends Plan> P pruneChildren(P plan, Set<Slot> parentRequiredSlots) {
        if (plan.arity() == 0) {
            return plan;
        }
        Set<Slot> currentUsedSlots = plan.getInputSlots();
        Set<Slot> childrenRequiredSlots = parentRequiredSlots.isEmpty()
                ? currentUsedSlots
                : ImmutableSet.<Slot>builder().addAll(parentRequiredSlots).addAll(currentUsedSlots).build();

        List<Plan> newChildren = new ArrayList<>(plan.arity());
        boolean anyChildChanged = false;
        for (Plan child : plan.children()) {
            Set<Slot> childRequiredSlots = child.getOutputSet().stream()
                    .filter(childrenRequiredSlots::contains)
                    .collect(Collectors.toSet());
            Plan prunedChild = doPruneChild(plan, child, childRequiredSlots);
            if (prunedChild != child) {
                anyChildChanged = true;
            }
            newChildren.add(prunedChild);
        }
        // NOTE(review): withChildren2 looks like a decompiler rename of withChildren - confirm against Plan.
        return anyChildChanged ? (P) plan.withChildren2(newChildren) : plan;
    }

    /**
     * Prune one child of {@code plan}.
     *
     * <p>A {@link LogicalCTEProducer} child is returned unchanged. Otherwise the child is visited with
     * {@code childRequiredSlots}; if the result still outputs other slots and {@code plan} is not a
     * {@link LogicalProject}, the result is wrapped in a {@link LogicalProject} over exactly
     * {@code childRequiredSlots}, in that set's iteration order.
     */
    private Plan doPruneChild(Plan plan, Plan child, Set<Slot> childRequiredSlots) {
        if (child instanceof LogicalCTEProducer) {
            return child;
        }
        Plan prunedChild = child.accept(this, new PruneContext(childRequiredSlots, plan));
        boolean parentIsProject = plan instanceof LogicalProject;
        if (!parentIsProject && !Sets.difference(prunedChild.getOutputSet(), childRequiredSlots).isEmpty()) {
            // NOTE(review): an empty childRequiredSlots yields a project with no columns - confirm this is
            // handled downstream.
            prunedChild = new LogicalProject<>(ImmutableList.<NamedExpression>copyOf(childRequiredSlots), prunedChild);
        }
        return prunedChild;
    }
}
