package org.apache.doris.nereids.rules.rewrite;

import java.util.Set;
import java.util.stream.Collectors;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Filter;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

/**
 * Rewrite rule that infers {@code IS NOT NULL} predicates from the argument(s) of a single global
 * aggregate function and places them in a filter directly below the aggregate.
 *
 * <p>The rule only matches a {@link LogicalAggregate} that
 * <ul>
 *   <li>has no GROUP BY expressions,</li>
 *   <li>has exactly one aggregate function, and</li>
 *   <li>whose function is a {@link Count}, {@link Avg}, {@link Sum}, {@link Max} or {@link Min}.</li>
 * </ul>
 * The function's children are passed to {@link ExpressionUtils#inferNotNull}. If any predicates are
 * inferred, and the child is not already a {@link Filter} containing all of them, the aggregate's
 * child is wrapped in a new filter built by {@link PlanUtils#filter}.
 *
 * <p>NOTE(review): pre-filtering is presumably result-preserving because, in standard SQL, these
 * aggregates skip NULL inputs. Confirm this against {@code inferNotNull}'s semantics.
 */
public class InferAggNotNull extends OneRewriteRuleFactory {
    @Override // org.apache.doris.nereids.rules.OneRuleFactory
    public Rule build() {
        return logicalAggregate()
                .when(agg -> agg.getGroupByExpressions().isEmpty())
                .when(agg -> agg.getAggregateFunctions().size() == 1)
                // With exactly one function (previous predicate), this just checks its kind.
                .when(agg -> agg.getAggregateFunctions().stream()
                        .allMatch(InferAggNotNull::isSupportedAggregateFunction))
                .thenApply(ctx -> {
                    LogicalAggregate<? extends Plan> agg = ctx.root;

                    // Collect every argument expression of the aggregate function(s).
                    Set<Expression> arguments = agg.getAggregateFunctions().stream()
                            .flatMap(function -> function.children().stream())
                            .collect(Collectors.toSet());
                    Set<Expression> notNulls = ExpressionUtils.inferNotNull(arguments, ctx.cascadesContext);
                    if (notNulls.isEmpty()) {
                        // Nothing inferred (e.g. count(*) has no arguments): leave the plan unchanged.
                        return null;
                    }

                    Plan child = agg.child();
                    // Skip when the child filter already enforces every inferred predicate.
                    // Adding another filter would be redundant, and could make the rule re-fire
                    // forever if filters are merged between passes.
                    if (child instanceof Filter && ((Filter) child).getConjuncts().containsAll(notNulls)) {
                        return null;
                    }

                    // notNulls is non-empty here. If PlanUtils.filter still yields no filter,
                    // treat it as "no rewrite" rather than failing with NoSuchElementException.
                    return PlanUtils.filter(notNulls, child)
                            .map(filter -> (Plan) agg.withChildren(filter))
                            .orElse(null);
                }).toRule(RuleType.INFER_AGG_NOT_NULL);
    }

    /** Returns whether {@code function} is one of the aggregate kinds this rule handles. */
    private static boolean isSupportedAggregateFunction(AggregateFunction function) {
        return function instanceof Count
                || function instanceof Avg
                || function instanceof Sum
                || function instanceof Max
                || function instanceof Min;
    }
}
