package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Stream;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;

/**
 * Rewrite rule over {@code LogicalAggregate} that replaces a single-argument
 * {@code COUNT(DISTINCT x)} in the aggregate's output expressions with a type-specific
 * aggregate when {@code x} has a bitmap or HLL type:
 * <ul>
 *   <li>bitmap-typed argument: {@link BitmapUnionCount}</li>
 *   <li>HLL-typed argument: {@link HllUnionAgg}</li>
 * </ul>
 * Any other {@link Count} is returned unchanged.
 */
public class CountDistinctRewrite extends OneRewriteRuleFactory {

    @Override
    public Rule build() {
        return logicalAggregate().then(aggregate -> {
            // CountDistinctRewriter only overrides visitCount, so the outermost node of each
            // output expression is expected to still be a NamedExpression after rewriting.
            // If it is not, Class.cast throws ClassCastException.
            List<NamedExpression> newOutputs = aggregate.getOutputExpressions().stream()
                    .map(CountDistinctRewriter::rewrite)
                    .map(NamedExpression.class::cast)
                    .collect(ImmutableList.toImmutableList());
            return aggregate.withAggOutput(newOutputs);
        }).toRule(RuleType.COUNT_DISTINCT_REWRITE);
    }

    /**
     * Expression rewriter that turns {@code COUNT(DISTINCT x)} over a bitmap or HLL typed
     * argument into the matching union aggregate. It has no instance state, so the single
     * shared {@link #INSTANCE} is used for every call.
     */
    private static class CountDistinctRewriter extends DefaultExpressionRewriter<Void> {
        private static final CountDistinctRewriter INSTANCE = new CountDistinctRewriter();

        private CountDistinctRewriter() {
        }

        /**
         * Applies this rewriter to {@code expression}.
         * Traversal of sub-expressions is handled by {@link DefaultExpressionRewriter}.
         *
         * @param expression the expression to rewrite
         * @return the rewritten expression
         */
        public static Expression rewrite(Expression expression) {
            return expression.accept(INSTANCE, null);
        }

        /**
         * Replaces {@code count} when it is DISTINCT with exactly one argument whose type is
         * bitmap (becomes {@link BitmapUnionCount}) or HLL (becomes {@link HllUnionAgg}).
         *
         * @param count the COUNT call being visited
         * @param context unused
         * @return the replacement aggregate, or {@code count} itself when no rewrite applies
         */
        @Override
        public Expression visitCount(Count count, Void context) {
            if (count.isDistinct() && count.arity() == 1) {
                Expression child = count.child(0);
                if (child.getDataType().isBitmapType()) {
                    return new BitmapUnionCount(child);
                }
                if (child.getDataType().isHllType()) {
                    return new HllUnionAgg(child);
                }
            }
            return count;
        }
    }
}
