package org.apache.doris.nereids.rules.analysis;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.expressions.typecoercion.TypeCheckResult;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;

/* loaded from: input_file:org/apache/doris/nereids/rules/analysis/CheckAnalysis.class */
/**
 * Analysis rule factory that validates a plan node's expressions.
 *
 * <p>Three checks are performed:
 * <ul>
 *   <li>every expression of the plan must pass {@code checkInputDataTypes()};</li>
 *   <li>a plan must not contain expression types listed for its class in
 *       {@link #UNEXPECTED_EXPRESSION_TYPE_MAP};</li>
 *   <li>aggregate-specific restrictions on DISTINCT functions and GROUP BY expressions.</li>
 * </ul>
 * Every violation is reported by throwing {@link AnalysisException}.
 */
public class CheckAnalysis implements AnalysisRuleFactory {

    /**
     * For each plan class, the expression classes that are rejected when found anywhere
     * inside one of that plan's expressions. Plan classes that are not keys of this map
     * are not restricted by {@link #checkUnexpectedExpressions(Plan)}.
     *
     * <p>The explicit type arguments on {@code builder()} are required: without them the
     * builder at the head of the call chain is inferred as {@code Builder<Object, Object>}
     * and the result is not assignable to the declared field type.
     */
    private static final Map<Class<? extends LogicalPlan>, Set<Class<? extends Expression>>>
            UNEXPECTED_EXPRESSION_TYPE_MAP = ImmutableMap
            .<Class<? extends LogicalPlan>, Set<Class<? extends Expression>>>builder()
            .put(LogicalAggregate.class, ImmutableSet.of(
                    TableGeneratingFunction.class))
            .put(LogicalFilter.class, ImmutableSet.of(
                    AggregateFunction.class,
                    GroupingScalarFunction.class,
                    TableGeneratingFunction.class,
                    WindowExpression.class))
            .put(LogicalGenerate.class, ImmutableSet.of(
                    AggregateFunction.class,
                    GroupingScalarFunction.class,
                    WindowExpression.class))
            .put(LogicalHaving.class, ImmutableSet.of(
                    TableGeneratingFunction.class,
                    WindowExpression.class))
            .put(LogicalJoin.class, ImmutableSet.of(
                    AggregateFunction.class,
                    GroupingScalarFunction.class,
                    TableGeneratingFunction.class,
                    WindowExpression.class))
            .put(LogicalOneRowRelation.class, ImmutableSet.of(
                    AggregateFunction.class,
                    GroupingScalarFunction.class,
                    SlotReference.class,
                    TableGeneratingFunction.class,
                    WindowExpression.class))
            .put(LogicalProject.class, ImmutableSet.of(
                    TableGeneratingFunction.class))
            .put(LogicalSort.class, ImmutableSet.of(
                    AggregateFunction.class,
                    GroupingScalarFunction.class,
                    TableGeneratingFunction.class,
                    WindowExpression.class))
            .put(LogicalWindow.class, ImmutableSet.of(
                    GroupingScalarFunction.class,
                    TableGeneratingFunction.class))
            .build();

    /**
     * Builds the two checking rules: a generic one matching any plan and one matching
     * logical aggregates only.
     *
     * @return an immutable list holding the CHECK_ANALYSIS and CHECK_AGGREGATE_ANALYSIS rules
     */
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                RuleType.CHECK_ANALYSIS.build(
                        any().then(plan -> {
                            checkExpressionInputTypes(plan);
                            checkUnexpectedExpressions(plan);
                            // NOTE(review): null is returned deliberately here; presumably the rule
                            // framework treats it as "plan unchanged" - confirm before altering.
                            return null;
                        })
                ),
                RuleType.CHECK_AGGREGATE_ANALYSIS.build(
                        logicalAggregate().then(aggregate -> {
                            checkAggregate(aggregate);
                            return aggregate;
                        })
                )
        );
    }

    /**
     * Rejects the plan if any node of any of its expressions is an instance of an
     * expression class listed for the plan's class in {@link #UNEXPECTED_EXPRESSION_TYPE_MAP}.
     *
     * @param plan the plan whose expressions are inspected
     * @throws AnalysisException on the first forbidden expression node encountered
     */
    private void checkUnexpectedExpressions(Plan plan) {
        Set<Class<? extends Expression>> unexpectedTypes =
                UNEXPECTED_EXPRESSION_TYPE_MAP.getOrDefault(plan.getClass(), Collections.emptySet());
        if (unexpectedTypes.isEmpty()) {
            // Nothing is forbidden for this plan class; skip the tree walk entirely.
            return;
        }
        plan.getExpressions().forEach(expression -> expression.foreachUp(node -> {
            for (Class<? extends Expression> unexpectedType : unexpectedTypes) {
                if (unexpectedType.isInstance(node)) {
                    throw new AnalysisException(plan.getType() + " can not contains "
                            + unexpectedType.getSimpleName() + " expression: " + ((Expression) node).toSql());
                }
            }
        }));
    }

    /**
     * Runs {@code checkInputDataTypes()} on the plan's expressions and fails on the first
     * result that reports failure.
     *
     * @param plan the plan whose expressions are type-checked
     * @throws AnalysisException carrying the message of the first failed check result
     */
    private void checkExpressionInputTypes(Plan plan) {
        Optional<TypeCheckResult> firstFailure = plan.getExpressions().stream()
                .map(Expression::checkInputDataTypes)
                .filter(TypeCheckResult::failed)
                .findFirst();
        if (firstFailure.isPresent()) {
            throw new AnalysisException(firstFailure.get().getMessage());
        }
    }

    /**
     * Validates an aggregate node.
     *
     * <p>Two conditions are rejected:
     * <ul>
     *   <li>more than one DISTINCT aggregate function is present while at least one DISTINCT
     *       function has an argument after the first whose input slots are non-empty;</li>
     *   <li>a GROUP BY expression contains an aggregate function.</li>
     * </ul>
     *
     * @param logicalAggregate the aggregate node to validate
     * @throws AnalysisException if either condition holds
     */
    private void checkAggregate(LogicalAggregate<? extends Plan> logicalAggregate) {
        Set<AggregateFunction> aggregateFunctions = logicalAggregate.getAggregateFunctions();

        boolean distinctMultiColumns = false;
        for (AggregateFunction function : aggregateFunctions) {
            if (!function.isDistinct() || function.arity() <= 1) {
                continue;
            }
            // Only arguments after the first are examined; an extra argument whose input
            // slots are empty does not mark the function as multi-column.
            for (int i = 1; i < function.arity(); i++) {
                if (!function.child(i).getInputSlots().isEmpty()) {
                    distinctMultiColumns = true;
                    break;
                }
            }
            if (distinctMultiColumns) {
                break;
            }
        }

        long distinctFunctionCount = aggregateFunctions.stream()
                .filter(AggregateFunction::isDistinct)
                .count();
        if (distinctMultiColumns && distinctFunctionCount > 1) {
            throw new AnalysisException(
                    "The query contains multi count distinct or sum distinct, each can't have multi columns");
        }

        Optional<Expression> groupByWithAggregate = logicalAggregate.getGroupByExpressions().stream()
                .filter(groupBy -> groupBy.containsType(AggregateFunction.class))
                .findFirst();
        if (groupByWithAggregate.isPresent()) {
            throw new AnalysisException("GROUP BY expression must not contain aggregate functions: "
                    + groupByWithAggregate.get().toSql());
        }
    }
}
