package org.apache.doris.nereids.rules.analysis;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;

/**
 * Rewrite rule that normalizes a {@code LogicalAggregate} whose {@code isNormalized()} is false into
 * {@code LogicalProject(upper) -> LogicalAggregate(normalized) -> [LogicalProject(bottom)] -> child}.
 *
 * <p>The optional bottom project outputs:
 * <ul>
 *   <li>the group-by expressions;</li>
 *   <li>subquery expressions found inside non-distinct, non-windowed aggregate functions;</li>
 *   <li>an alias for every child of a distinct aggregate function that is neither a
 *       {@code SlotReference} nor a {@code Literal};</li>
 *   <li>every remaining input slot read by the aggregate functions.</li>
 * </ul>
 * It is omitted when that set is empty. The normalized aggregate outputs the pushed-down slots
 * plus the normalized aggregate functions. The upper project rebuilds the original output
 * expressions on top of it, keeping each original output {@code ExprId}.
 */
public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot {

    /**
     * Collects aggregate functions that are not used as window functions.
     *
     * <p>For a {@link WindowExpression} only the expressions returned by
     * {@code getExpressionsInWindowSpec()} are visited, so the windowed function itself is skipped.
     * An {@link AggregateFunction} is appended to the context list and its children are not visited.
     */
    public static class CollectNonWindowedAggFuncs extends DefaultExpressionVisitor<Void, List<AggregateFunction>> {
        private static final CollectNonWindowedAggFuncs INSTANCE = new CollectNonWindowedAggFuncs();

        private CollectNonWindowedAggFuncs() {
        }

        @Override
        public Void visitWindow(WindowExpression windowExpression, List<AggregateFunction> context) {
            for (Expression specExpr : windowExpression.getExpressionsInWindowSpec()) {
                specExpr.accept(this, context);
            }
            return null;
        }

        @Override
        public Void visitAggregateFunction(AggregateFunction aggregateFunction, List<AggregateFunction> context) {
            context.add(aggregateFunction);
            return null;
        }
    }

    @Override
    public Rule build() {
        return logicalAggregate().whenNot(agg -> agg.isNormalized()).then(aggregate -> {
            List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
            // aliases already present in the output; handed to both normalize contexts below
            Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);

            // gather every aggregate function of the output that is not used as a window function
            List<AggregateFunction> aggFuncs = Lists.newArrayList();
            for (NamedExpression output : aggregateOutput) {
                output.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs);
            }

            // subqueries inside non-distinct aggregate functions are pushed down with the group-by keys
            List<AggregateFunction> nonDistinctInput = aggFuncs.stream()
                    .filter(aggFunc -> !aggFunc.isDistinct())
                    .collect(Collectors.toList());
            Set<SubqueryExpr> subqueryExprs =
                    ExpressionUtils.mutableCollect(nonDistinctInput, SubqueryExpr.class::isInstance);
            Set<Expression> groupByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
            Set<Expression> bottomSourceExprs = Sets.union(groupByExprs, subqueryExprs);

            NormalizeToSlot.NormalizeToSlotContext bottomSlotContext =
                    NormalizeToSlot.NormalizeToSlotContext.buildContext(existsAlias, bottomSourceExprs);
            Set<NamedExpression> bottomOutputs = bottomSlotContext.pushDownToNamedExpression(bottomSourceExprs);

            List<AggregateFunction> normalizedAggFuncs = bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
            Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomOutputs);

            // split the normalized aggregate functions into distinct and non-distinct ones
            List<AggregateFunction> distinctAggFuncs = Lists.newArrayList();
            List<AggregateFunction> nonDistinctAggFuncs = Lists.newArrayList();
            for (AggregateFunction aggFunc : normalizedAggFuncs) {
                if (aggFunc.isDistinct()) {
                    distinctAggFuncs.add(aggFunc);
                } else {
                    nonDistinctAggFuncs.add(aggFunc);
                }
            }

            if (!distinctAggFuncs.isEmpty()) {
                // Push each child of a distinct aggregate that is neither a slot nor a literal into the
                // bottom project; equal child expressions share a single alias.
                List<AggregateFunction> newDistinctAggFuncs = Lists.newArrayList();
                Map<Expression, Expression> replaceMap = Maps.newHashMap();
                Map<Expression, NamedExpression> aliasCache = Maps.newHashMap();
                for (AggregateFunction distinctAggFunc : distinctAggFuncs) {
                    List<Expression> newChildren = Lists.newArrayList();
                    for (Expression child : distinctAggFunc.children()) {
                        if (child instanceof SlotReference || child instanceof Literal) {
                            newChildren.add(child);
                        } else {
                            NamedExpression alias = aliasCache.computeIfAbsent(child, expr -> new Alias(expr));
                            bottomProjects.add(alias);
                            newChildren.add(alias.toSlot());
                        }
                    }
                    AggregateFunction newDistinctAggFunc = distinctAggFunc.withChildren(newChildren);
                    replaceMap.put(distinctAggFunc, newDistinctAggFunc);
                    newDistinctAggFuncs.add(newDistinctAggFunc);
                }
                aggregateOutput = aggregateOutput.stream()
                        .map(output -> ExpressionUtils.replace(output, replaceMap))
                        .map(NamedExpression.class::cast)
                        .collect(Collectors.toList());
                distinctAggFuncs = newDistinctAggFuncs;
            }

            List<AggregateFunction> allAggFuncs = Lists.newArrayList(nonDistinctAggFuncs);
            allAggFuncs.addAll(distinctAggFuncs);
            NormalizeToSlot.NormalizeToSlotContext aggFuncSlotContext =
                    NormalizeToSlot.NormalizeToSlotContext.buildContext(existsAlias, allAggFuncs);

            // aggregate output = pushed-down group-by/subquery slots + normalized aggregate functions
            List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
                    .addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator())
                    .addAll(aggFuncSlotContext.pushDownToNamedExpression(allAggFuncs))
                    .build();

            // the bottom project must also expose every input slot the aggregate functions read
            Set<? extends Expression> bottomProjectSlots = bottomProjects.stream()
                    .map(NamedExpression::toSlot)
                    .collect(Collectors.toSet());
            Set<? extends NamedExpression> missingInputSlots = allAggFuncs.stream()
                    .flatMap(aggFunc -> aggFunc.getInputSlots().stream())
                    .filter(slot -> !bottomProjectSlots.contains(slot))
                    .collect(Collectors.toSet());
            bottomProjects.addAll(missingInputSlots);

            List<NamedExpression> upperProjects =
                    normalizeOutput(aggregateOutput, bottomSlotContext, aggFuncSlotContext);
            List<Expression> normalizedGroupByExprs = bottomSlotContext.normalizeToUseSlotRef(groupByExprs);

            Plan bottomPlan;
            if (bottomProjects.isEmpty()) {
                bottomPlan = aggregate.child();
            } else {
                bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
            }
            return new LogicalProject<>(upperProjects,
                    aggregate.withNormalized(normalizedGroupByExprs, normalizedAggOutput, bottomPlan));
        }).toRule(RuleType.NORMALIZE_AGGREGATE);
    }

    /**
     * Builds the upper project that re-expresses the aggregate output on top of the normalized
     * aggregate.
     *
     * <p>Each output expression is rewritten first with {@code groupByToSlotContext} and then with
     * {@code aggFuncToSlotContext}. An {@code Alias} that merely wraps a {@code SlotReference} with
     * the same {@code ExprId} is unwrapped to that slot. Any result whose {@code ExprId} differs from
     * the original output is re-aliased with the original {@code ExprId} and name, presumably so that
     * parent plans keep referencing the same output ids.
     *
     * @param aggregateOutput the aggregate's output expressions (after the distinct rewrite, if any)
     * @param groupByToSlotContext context built from the group-by and pushed-down subquery expressions
     * @param aggFuncToSlotContext context built from the normalized aggregate functions
     * @return one projection per element of {@code aggregateOutput}, in the same order
     */
    private List<NamedExpression> normalizeOutput(List<NamedExpression> aggregateOutput,
            NormalizeToSlot.NormalizeToSlotContext groupByToSlotContext,
            NormalizeToSlot.NormalizeToSlotContext aggFuncToSlotContext) {
        List<NamedExpression> upperProjects = aggFuncToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(
                groupByToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(aggregateOutput));

        ImmutableList.Builder<NamedExpression> builder = ImmutableList.builder();
        for (int i = 0; i < aggregateOutput.size(); i++) {
            NamedExpression original = aggregateOutput.get(i);
            NamedExpression project = upperProjects.get(i);
            // collapse Alias(SlotReference#n)#n into the slot itself
            if (project instanceof Alias && project.child(0) instanceof SlotReference) {
                SlotReference slot = (SlotReference) project.child(0);
                if (slot.getExprId().equals(project.getExprId())) {
                    project = slot;
                }
            }
            // keep the output ExprId identical to the original output expression
            if (!project.getExprId().equals(original.getExprId())) {
                project = new Alias(original.getExprId(), project, original.getName());
            }
            builder.add(project);
        }
        return builder.build();
    }
}
