package org.apache.doris.nereids.rules.analysis;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;

/* loaded from: input_file:org/apache/doris/nereids/rules/analysis/AvgDistinctToSumDivCount.class */
/**
 * Rewrites {@code avg(DISTINCT x)} inside a {@link LogicalAggregate} into
 * {@code sum(DISTINCT x) / count(DISTINCT x)}.
 *
 * <p>The rule only fires when the aggregate has more than one distinct argument. When no distinct
 * {@link Avg} is present, the aggregate is returned unchanged. Otherwise, every distinct {@code Avg}
 * in the output expressions is replaced by a type-coerced {@link Divide} of a {@link Sum} and a
 * {@link Count}. The group-by expressions and the child plan are kept as they are.
 */
public class AvgDistinctToSumDivCount extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return RuleType.AVG_DISTINCT_TO_SUM_DIV_COUNT.build(
                logicalAggregate()
                        .when(agg -> agg.getDistinctArguments().size() > 1)
                        .then(agg -> {
                            // Map each distinct avg to its sum/count replacement. The merge function
                            // keeps the first entry, so a repeated function does not make
                            // toImmutableMap throw on a duplicate key.
                            Map<Expression, Expression> avgToSumDivCount = agg.getAggregateFunctions().stream()
                                    .filter(function -> function instanceof Avg && function.isDistinct())
                                    .collect(ImmutableMap.toImmutableMap(
                                            function -> function,
                                            function -> toSumDivCount((Avg) function),
                                            (first, duplicate) -> first));
                            if (avgToSumDivCount.isEmpty()) {
                                return agg;
                            }
                            List<NamedExpression> newOutput = agg.getOutputExpressions().stream()
                                    .map(output -> (NamedExpression) ExpressionUtils.replace(output, avgToSumDivCount))
                                    .collect(ImmutableList.toImmutableList());
                            return new LogicalAggregate<>(agg.getGroupByExpressions(), newOutput, agg.child());
                        }));
    }

    /**
     * Builds the replacement expression for a single distinct {@code avg}.
     *
     * <p>The leading {@code true} passed to the {@link Sum} and {@link Count} constructors is
     * presumably the distinct flag (TODO confirm against their constructors). The sum inherits the
     * avg's {@code isAlwaysNullable()} value. Both functions and the resulting division are passed
     * through {@link TypeCoercionUtils}, so the rewritten expression is type-coerced.
     *
     * @param avg a distinct {@code Avg} aggregate function
     * @return the coerced {@code sum(...) / count(...)} expression over the avg's argument
     */
    private static Expression toSumDivCount(Avg avg) {
        Expression argument = avg.child();
        Sum sum = (Sum) TypeCoercionUtils.processBoundFunction(new Sum(true, avg.isAlwaysNullable(), argument));
        Count count = (Count) TypeCoercionUtils.processBoundFunction(new Count(true, argument));
        return TypeCoercionUtils.processDivide(new Divide(sum, count), sum, count);
    }
}
