package org.apache.doris.nereids.rules.analysis;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Streams;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.doris.analysis.SetUserPropertyVar;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.util.ExpressionUtils;

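/**
 * Analysis rules that make the slots referenced by ORDER BY or HAVING available from the plan below them.
 *
 * <p>Missing sort slots are added to a child project. Expressions resolved against a child aggregate are appended
 * to the aggregate's output and projected away again on top. A {@code LogicalHaving} is converted into a
 * {@code LogicalFilter}.
 */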
public class FillUpMissingSlots implements AnalysisRuleFactory {

    /**
     * Callback used by {@code createPlan} to rebuild the plan above an aggregate once a {@link Resolver} has
     * collected the substitution and the extra output expressions for that aggregate.
     */
    @FunctionalInterface
    public interface PlanGenerator {
        /**
         * Builds the rewritten plan.
         *
         * @param resolver the resolver holding the expression substitution and the new output slots
         * @param aggregate the aggregate to place under the rewritten plan; already extended with the resolver's
         *        new output slots when there are any
         * @return the rewritten plan, or {@code null} when no rewrite is needed
         */
        Plan apply(Resolver resolver, Aggregate aggregate);
    }

    /**
     * Resolves expressions of a HAVING or ORDER BY clause that sits directly above an aggregate.
     *
     * <p>Each resolved expression is mapped, in {@link #getSubstitution()}, to the slot that should replace it.
     * Every expression the aggregate must additionally output for those slots to exist is collected in
     * {@link #getNewOutputSlots()}.
     */
    public static class Resolver {
        private final List<NamedExpression> outputExpressions;
        private final List<Expression> groupByExpressions;
        // Resolved expression -> slot that replaces it above the aggregate.
        private final Map<Expression, Slot> substitution = Maps.newHashMap();
        // Group-by slots and new aliases that must be appended to the aggregate's output.
        private final List<NamedExpression> newOutputSlots = Lists.newArrayList();

        Resolver(Aggregate<?> aggregate) {
            this.outputExpressions = aggregate.getOutputExpressions();
            this.groupByExpressions = aggregate.getGroupByExpressions();
        }

        /**
         * Resolves {@code expression} against the aggregate's output and group-by expressions.
         *
         * <p>If the expression is found in the aggregate, it is replaced by that match's slot. When it was found
         * only among the group-by expressions, that slot is also added to the new output. An aggregate function
         * that is not found gets a new alias. Any other expression is resolved child by child.
         *
         * @param expression an expression taken from the clause above the aggregate
         * @throws AnalysisException if a slot reference is neither output nor grouped by, or if an aggregate
         *         function nests another aggregate function
         */
        public void resolve(Expression expression) {
            // An equal expression may occur several times (e.g. HAVING sum(a) > 1 AND sum(a) < 10). Resolving it
            // again would only append a duplicate slot or alias to the aggregate's output.
            if (substitution.containsKey(expression)) {
                return;
            }
            Pair<Optional<Expression>, Boolean> lookUpResult = lookUp(expression);
            Optional<Expression> found = lookUpResult.first;
            boolean foundInOutput = lookUpResult.second;
            if (found.isPresent()) {
                Expression matched = found.get();
                if (matched instanceof NamedExpression) {
                    Slot slot = ((NamedExpression) matched).toSlot();
                    substitution.put(expression, slot);
                    if (!foundInOutput) {
                        // Only grouped by, not output yet: the aggregate has to output it as well.
                        newOutputSlots.add(slot);
                    }
                } else {
                    // A non-named group-by expression: output it under a fresh alias.
                    generateAliasForNewOutputSlots(expression);
                }
                return;
            }
            // NOTE(review): this resolver is also used for ORDER BY, yet these messages always say "having clause".
            if (expression instanceof SlotReference) {
                throw new AnalysisException(expression.toSql() + " in having clause should be grouped by.");
            }
            if (expression instanceof AggregateFunction) {
                if (checkWhetherNestedAggregateFunctionsExist((AggregateFunction) expression)) {
                    throw new AnalysisException("Aggregate functions in having clause can't be nested: "
                            + expression.toSql() + SetUserPropertyVar.DOT_SEPARATOR);
                }
                generateAliasForNewOutputSlots(expression);
                return;
            }
            for (Expression child : expression.children()) {
                resolve(child);
            }
        }

        /**
         * Looks {@code expression} up in the aggregate's output expressions first, then in its group-by expressions.
         *
         * @return the first equivalent expression (empty if none) paired with {@code true} when it came from the
         *         output expressions and {@code false} otherwise
         */
        private Pair<Optional<Expression>, Boolean> lookUp(Expression expression) {
            Optional<Expression> inOutput = outputExpressions.stream()
                    .filter(output -> isEquivalent(output, expression))
                    .map(Expression.class::cast)
                    .findFirst();
            if (inOutput.isPresent()) {
                return Pair.of(inOutput, true);
            }
            Optional<Expression> inGroupBy = groupByExpressions.stream()
                    .filter(groupBy -> isEquivalent(groupBy, expression))
                    .findFirst();
            return Pair.of(inGroupBy, false);
        }

        /** Returns true if {@code candidate} equals {@code target}, or is an alias whose slot or child equals it. */
        private boolean isEquivalent(Expression candidate, Expression target) {
            if (candidate.equals(target)) {
                return true;
            }
            if (!(candidate instanceof Alias)) {
                return false;
            }
            Alias alias = (Alias) candidate;
            return alias.toSlot().equals(target) || alias.child().equals(target);
        }

        /** Returns true if any argument of {@code aggregateFunction} contains another aggregate function. */
        private boolean checkWhetherNestedAggregateFunctionsExist(AggregateFunction aggregateFunction) {
            return aggregateFunction.children().stream()
                    .anyMatch(child -> child.anyMatch(AggregateFunction.class::isInstance));
        }

        /** Wraps {@code expression} in a new alias, outputs that alias and substitutes the alias's slot. */
        private void generateAliasForNewOutputSlots(Expression expression) {
            Alias alias = new Alias(expression);
            newOutputSlots.add(alias);
            substitution.put(expression, alias.toSlot());
        }

        /** Returns the live map from resolved expressions to their replacement slots. */
        public Map<Expression, Slot> getSubstitution() {
            return substitution;
        }

        /** Returns the live list of expressions to append to the aggregate's output. */
        public List<NamedExpression> getNewOutputSlots() {
            return newOutputSlots;
        }
    }

    /**
     * Builds the fill-up rules.
     *
     * <p>The rules add the slots needed by a sort to the project below it. They add the expressions needed by a
     * sort and/or having to the aggregate below it. They turn a having into a filter.
     */
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                // Sort above a project that lacks slots used by the order keys: add those slots to a project under
                // the sort, then project back to the original output on top.
                RuleType.FILL_UP_SORT_PROJECT.build(logicalSort(logicalProject()).then(sort -> {
                    LogicalProject<? extends Plan> project = (LogicalProject<? extends Plan>) sort.child();
                    Set<Slot> projectOutput = project.getOutputSet();
                    Set<Slot> missingSlots = sort.getOrderKeys().stream()
                            .map(OrderKey::getExpr)
                            .flatMap(expr -> expr.getInputSlots().stream())
                            .filter(slot -> !projectOutput.contains(slot))
                            .collect(Collectors.toSet());
                    if (missingSlots.isEmpty()) {
                        return null;
                    }
                    List<NamedExpression> widenedProjects = ImmutableList.<NamedExpression>builder()
                            .addAll(project.getProjects())
                            .addAll(missingSlots)
                            .build();
                    Plan widenedSort = (Plan) sort.withChildren(
                            new LogicalProject<>(widenedProjects, (Plan) project.child()));
                    return new LogicalProject<>(
                            ImmutableList.<NamedExpression>copyOf(project.getOutput()), widenedSort);
                })),
                // Sort directly above an aggregate.
                RuleType.FILL_UP_SORT_AGGREGATE.build(logicalSort(aggregate()).when(this::checkSort).then(sort -> {
                    Aggregate<? extends Plan> aggregate = (Aggregate<? extends Plan>) sort.child();
                    Resolver resolver = new Resolver(aggregate);
                    sort.getExpressions().forEach(resolver::resolve);
                    return createPlan(resolver, aggregate, (res, newAggregate) -> {
                        List<OrderKey> newOrderKeys = replaceOrderKeys(sort.getOrderKeys(), res.getSubstitution());
                        boolean orderKeysUnchanged = newOrderKeys.equals(sort.getOrderKeys());
                        if (orderKeysUnchanged && newAggregate.equals(aggregate)) {
                            return null;
                        }
                        if (orderKeysUnchanged) {
                            return (Plan) sort.withChildren(newAggregate);
                        }
                        return new LogicalSort<>(newOrderKeys, (Plan) newAggregate);
                    });
                })),
                // Sort above a having above an aggregate.
                RuleType.FILL_UP_SORT_HAVING_AGGREGATE.build(
                        logicalSort(logicalHaving(aggregate())).when(this::checkSort).then(sort -> {
                            LogicalHaving<? extends Plan> having = (LogicalHaving<? extends Plan>) sort.child();
                            Aggregate<? extends Plan> aggregate = (Aggregate<? extends Plan>) having.child();
                            Resolver resolver = new Resolver(aggregate);
                            sort.getExpressions().forEach(resolver::resolve);
                            return createPlan(resolver, aggregate, (res, newAggregate) -> {
                                List<OrderKey> newOrderKeys =
                                        replaceOrderKeys(sort.getOrderKeys(), res.getSubstitution());
                                boolean orderKeysUnchanged = newOrderKeys.equals(sort.getOrderKeys());
                                if (orderKeysUnchanged && newAggregate.equals(aggregate)) {
                                    return null;
                                }
                                Plan newHaving = (Plan) having.withChildren(newAggregate);
                                if (orderKeysUnchanged) {
                                    return (Plan) sort.withChildren(newHaving);
                                }
                                return new LogicalSort<>(newOrderKeys, newHaving);
                            });
                        })),
                // Having directly above an aggregate.
                RuleType.FILL_UP_HAVING_AGGREGATE.build(logicalHaving(aggregate()).then(having -> {
                    Aggregate<? extends Plan> aggregate = (Aggregate<? extends Plan>) having.child();
                    Resolver resolver = new Resolver(aggregate);
                    having.getConjuncts().forEach(resolver::resolve);
                    return createPlan(resolver, aggregate, (res, newAggregate) -> {
                        Set<Expression> newConjuncts =
                                ExpressionUtils.replace(having.getConjuncts(), res.getSubstitution());
                        boolean conjunctsUnchanged = newConjuncts.equals(having.getConjuncts());
                        if (conjunctsUnchanged && newAggregate.equals(aggregate)) {
                            return null;
                        }
                        if (conjunctsUnchanged) {
                            return (Plan) having.withChildren(newAggregate);
                        }
                        return new LogicalHaving<>(newConjuncts, (Plan) newAggregate);
                    });
                })),
                // Any having: turn it into a filter with the same conjuncts.
                RuleType.FILL_UP_HAVING_PROJECT.build(logicalHaving().then(having ->
                        new LogicalFilter<>(having.getConjuncts(), (Plan) having.child())))
        );
    }

    /**
     * Rewrites the expression of each order key through {@code substitution} using
     * {@link OrderKey#withExpression}.
     */
    private static List<OrderKey> replaceOrderKeys(List<OrderKey> orderKeys, Map<Expression, Slot> substitution) {
        return orderKeys.stream()
                .map(orderKey -> orderKey.withExpression(ExpressionUtils.replace(orderKey.getExpr(), substitution)))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Extends {@code aggregate} with the resolver's new output slots (if any) and lets {@code planGenerator}
     * rebuild the plan above it. The result is then wrapped in a project that restores the aggregate's original
     * output slots.
     *
     * @param resolver a resolver that has already resolved the clause above the aggregate
     * @param aggregate the original aggregate
     * @param planGenerator builds the rewritten plan above the (possibly extended) aggregate
     * @return the rewritten plan under a project of the original output, or {@code null} if
     *         {@code planGenerator} returned {@code null}
     */
    private Plan createPlan(Resolver resolver, Aggregate<? extends Plan> aggregate, PlanGenerator planGenerator) {
        List<NamedExpression> newOutputSlots = resolver.getNewOutputSlots();
        Plan rewritten = planGenerator.apply(resolver, newOutputSlots.isEmpty()
                ? aggregate
                : aggregate.withAggOutput(ImmutableList.<NamedExpression>builder()
                        .addAll(aggregate.getOutputExpressions())
                        .addAll(newOutputSlots)
                        .build()));
        if (rewritten == null) {
            return null;
        }
        // Hide the extra outputs again so the parent sees exactly the original aggregate's columns.
        List<NamedExpression> originalOutput = aggregate.getOutputExpressions().stream()
                .<NamedExpression>map(NamedExpression::toSlot)
                .collect(ImmutableList.toImmutableList());
        return new LogicalProject<>(originalOutput, rewritten);
    }

    /**
     * Decides whether a sort needs its missing slots filled up.
     *
     * @return true if an order key uses a slot its child does not output, or contains an aggregate function
     */
    private boolean checkSort(LogicalSort<? extends Plan> logicalSort) {
        // Loop-invariant: previously recomputed for every slot inside anyMatch.
        Set<Slot> childOutput = logicalSort.child().getOutputSet();
        boolean usesMissingSlot = logicalSort.getOrderKeys().stream()
                .map(OrderKey::getExpr)
                .flatMap(expr -> expr.getInputSlots().stream())
                .anyMatch(slot -> !childOutput.contains(slot));
        return usesMissingSlot || logicalSort.getOrderKeys().stream()
                .map(OrderKey::getExpr)
                .anyMatch(expr -> expr.containsType(AggregateFunction.class));
    }
}
