package org.apache.doris.nereids.rules.exploration;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;

import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Exploration rule that pushes a partial aggregation (an "eager group-by" plus a row count) below
 * the left input of an inner join.
 *
 * <pre>
 *   aggregate: SUM(x), SUM(y)            aggregate: SUM(sum0), SUM(y * cnt)
 *   |                                    |
 *   inner join               ==&gt;         inner join
 *   |        \                           |                                         \
 *   left(x)   right(y)                   aggregate: SUM(x) AS sum0, COUNT AS cnt    right(y)
 *                                        |
 *                                        left(x)
 * </pre>
 *
 * <p>The rule matches an aggregate directly above an inner join that has no "other" (non-hash)
 * join conjuncts and whose aggregate functions are all {@code SUM} over a plain slot. It only
 * rewrites when at least one SUM reads a left-side slot and at least one reads a right-side slot;
 * otherwise the action returns {@code null}.
 *
 * <p>The new bottom aggregate groups the left input by every left slot referenced by the original
 * group-by expressions or by the hash join conjuncts. Rows folded into one bottom group therefore
 * agree on every left value the join and the top aggregate look at. Left-side sums are re-summed
 * on top, and right-side sums are multiplied by the per-group row count {@code cnt}.
 */
public class EagerGroupByCount extends OneExplorationRuleFactory {
    public static final EagerGroupByCount INSTANCE = new EagerGroupByCount();

    @Override // org.apache.doris.nereids.rules.OneRuleFactory
    public Rule build() {
        return logicalAggregate(innerLogicalJoin())
                .when(agg -> ((LogicalJoin<?, ?>) agg.child()).getOtherJoinConjuncts().isEmpty())
                // NOTE(review): SUM(DISTINCT ...) is not excluded here, while the rewrite below
                // rebuilds plain Sum(...) nodes; confirm distinct sums cannot reach this rule.
                .when(agg -> agg.getAggregateFunctions().stream()
                        .allMatch(f -> f instanceof Sum && ((Sum) f).child() instanceof Slot))
                .then(agg -> rewrite(agg))
                .toRule(RuleType.EAGER_GROUP_BY_COUNT);
    }

    /**
     * Builds the eagerly-aggregated plan for a matched aggregate-over-inner-join.
     *
     * @param agg matched aggregate; its child is the inner join and every aggregate function is
     *            {@code SUM(slot)} (guaranteed by the matcher in {@link #build()})
     * @return the rewritten aggregate, or {@code null} when the SUMs do not read from both sides
     */
    private Plan rewrite(LogicalAggregate<?> agg) {
        LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) agg.child();
        List<Slot> leftOutput = join.left().getOutput();

        // Partition SUM arguments by side. Distinct left arguments are kept in first-seen order
        // so the generated "sum<i>" names are deterministic.
        Set<Slot> leftSumArgs = new LinkedHashSet<>();
        boolean hasRightSum = false;
        for (AggregateFunction function : agg.getAggregateFunctions()) {
            Slot arg = (Slot) ((Sum) function).child();
            if (leftOutput.contains(arg)) {
                leftSumArgs.add(arg);
            } else {
                hasRightSum = true;
            }
        }
        if (leftSumArgs.isEmpty() || !hasRightSum) {
            return null;
        }

        // Bottom group-by: left slots used by the group-by keys (including slots nested inside
        // non-slot keys) and by the hash join conjuncts.
        Set<Slot> leftGroupBy = Stream.concat(
                        agg.getGroupByExpressions().stream(), join.getHashJoinConjuncts().stream())
                .flatMap(e -> e.getInputSlots().stream())
                .filter(leftOutput::contains)
                .collect(Collectors.toCollection(LinkedHashSet::new));

        // One bottom SUM per distinct left argument, keyed by that argument so the top aggregate
        // can find it regardless of output order or duplicate SUM(x) outputs.
        Map<Slot, Alias> bottomSumByArg = new LinkedHashMap<>();
        for (Slot arg : leftSumArgs) {
            bottomSumByArg.put(arg, new Alias(new Sum(arg), "sum" + bottomSumByArg.size()));
        }
        Alias cnt = new Alias(new Count(), "cnt");

        List<NamedExpression> bottomOutput = ImmutableList.<NamedExpression>builder()
                .addAll(leftGroupBy)
                .addAll(bottomSumByArg.values())
                .add(cnt)
                .build();
        Plan bottomAgg = new LogicalAggregate<>(
                ImmutableList.copyOf(leftGroupBy), bottomOutput, join.left());
        Plan newJoin = join.withChildren(bottomAgg, join.right());

        // Rewrite each top-level output in place so the aggregate's output order is unchanged.
        // Each rewritten SUM keeps the original alias's ExprId and name.
        // NOTE(review): outputs that are not a direct Alias(Sum(slot)) are copied unchanged;
        // presumably every SUM appears in that form. TODO confirm — a SUM nested in a larger
        // expression would still reference left slots the bottom aggregate no longer outputs.
        Slot cntSlot = cnt.toSlot();
        List<NamedExpression> newOutputs = new ArrayList<>();
        for (NamedExpression expr : agg.getOutputExpressions()) {
            if (!(expr instanceof Alias) || !(((Alias) expr).child() instanceof Sum)) {
                newOutputs.add(expr);
                continue;
            }
            Alias oldSum = (Alias) expr;
            Slot arg = (Slot) ((Sum) oldSum.child()).child();
            Sum newSum = leftOutput.contains(arg)
                    // left: re-aggregate the partial sums computed below the join
                    ? new Sum(bottomSumByArg.get(arg).toSlot())
                    // right: each joined row now stands for `cnt` original left rows
                    : new Sum(new Multiply(arg, cntSlot));
            newOutputs.add(new Alias(oldSum.getExprId(), newSum, oldSum.getName()));
        }
        return agg.withAggOutput(newOutputs).withChildren(newJoin);
    }
}
