package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableSet;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;

/* loaded from: input_file:org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.class */
/**
 * Analysis check that rejects aggregates mixing multiple distinct arguments with distinct
 * aggregate functions that are not on an allow-list.
 *
 * <p>When {@link LogicalAggregate#getDistinctArguments()} reports more than one distinct
 * argument, every distinct aggregate function of the aggregate must be exactly one of
 * {@link Count}, {@link Sum}, {@link Avg} or {@link GroupConcat}. Otherwise an
 * {@link AnalysisException} is thrown. The plan itself is returned unchanged.
 */
public class CheckMultiDistinct extends OneRewriteRuleFactory {
    /**
     * Aggregate function classes that may be distinct when the aggregate has several distinct
     * arguments. Membership is tested by exact runtime class, so subclasses are not accepted.
     * The set never varies, so it is shared by all instances.
     */
    private static final ImmutableSet<Class<? extends AggregateFunction>> SUPPORTED_FUNCTIONS =
            ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class);

    /**
     * Builds the rule that matches every logical aggregate and validates it.
     *
     * @return a {@link RuleType#CHECK_ANALYSIS} rule that returns the matched aggregate unchanged
     */
    @Override
    public Rule build() {
        return logicalAggregate()
                .then(aggregate -> checkDistinct(aggregate))
                .toRule(RuleType.CHECK_ANALYSIS);
    }

    /**
     * Validates the distinct aggregate functions of {@code aggregate}.
     *
     * @param aggregate the aggregate to validate
     * @param <P> the child plan type, preserved so the result is not a raw type
     * @return {@code aggregate} itself, unmodified
     * @throws AnalysisException if the aggregate has more than one distinct argument and
     *     contains a distinct aggregate function whose class is not in
     *     {@link #SUPPORTED_FUNCTIONS}
     */
    private <P extends Plan> LogicalAggregate<P> checkDistinct(LogicalAggregate<P> aggregate) {
        // The restriction only applies once more than one distinct argument is involved.
        if (aggregate.getDistinctArguments().size() > 1) {
            for (AggregateFunction function : aggregate.getAggregateFunctions()) {
                if (function.isDistinct() && !SUPPORTED_FUNCTIONS.contains(function.getClass())) {
                    throw new AnalysisException(function.toString() + " can't support multi distinct.");
                }
            }
        }
        return aggregate;
    }
}
