package org.apache.doris.nereids.rules.exploration;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Exploration rule "eager split": pre-aggregates both children of an inner join below the join.
 *
 * <p>It matches {@code Aggregate(InnerJoin(left, right))} whose aggregate functions are all {@code sum(slot)}, with
 * at least one such sum over a left-side slot and at least one over a right-side slot. Each join child is wrapped in
 * an aggregate grouped by that child's slots used in the top group-by or in hash join conjuncts. That aggregate
 * produces the child's partial sums ({@code left_sum<i>} / {@code right_sum<i>}) plus a row count
 * ({@code left_cnt} / {@code right_cnt}).
 *
 * <p>Each joined pair of pre-aggregated rows stands for {@code left_cnt * right_cnt} original joined rows. The top
 * aggregate therefore rolls a left partial sum up as {@code sum(left_sum * right_cnt)}, and a right partial sum as
 * {@code sum(right_sum * left_cnt)}.
 */
public class EagerSplit extends OneExplorationRuleFactory {
    /** Shared instance; the rule keeps no per-instance state. */
    public static final EagerSplit INSTANCE = new EagerSplit();

    /**
     * Builds the EAGER_SPLIT rule.
     *
     * <p>The transformation returns {@code null} instead of a rewritten plan in two cases:
     * <ul>
     *   <li>the sums do not span both join sides;</li>
     *   <li>some aggregate function is not itself a top-level {@code sum(slot) AS name} output and therefore cannot
     *       be rolled up.</li>
     * </ul>
     */
    @Override
    public Rule build() {
        // NOTE(review): the guard accepts any Sum instance. If Sum can carry a DISTINCT flag, splitting
        // sum(DISTINCT x) this way is not equivalent -- confirm such sums are excluded before this rule runs.
        return logicalAggregate(innerLogicalJoin())
                .when(agg -> agg.getAggregateFunctions().stream()
                        .allMatch(f -> f instanceof Sum && ((Sum) f).child() instanceof SlotReference))
                .then(agg -> {
                    LogicalJoin<GroupPlan, GroupPlan> join = agg.child();
                    List<Slot> leftOutput = join.left().getOutput();
                    List<Slot> rightOutput = join.right().getOutput();

                    // Pass 1: split the top-level "sum(slot) AS name" outputs by the join side that owns the slot.
                    List<Alias> leftSumOutputs = new ArrayList<>();
                    List<Alias> rightSumOutputs = new ArrayList<>();
                    Set<Sum> directSums = new HashSet<>();
                    for (NamedExpression output : agg.getOutputExpressions()) {
                        Sum sum = directSum(output);
                        if (sum == null) {
                            continue;
                        }
                        directSums.add(sum);
                        if (leftOutput.contains((Slot) sum.child())) {
                            leftSumOutputs.add((Alias) output);
                        } else {
                            rightSumOutputs.add((Alias) output);
                        }
                    }
                    // A sum that only appears inside a larger expression (e.g. "sum(a) + 1") has no output of its
                    // own to roll up. The split also needs sums on both sides. In either case, do not rewrite.
                    // NOTE(review): an output such as "sum(a) + 1" next to a direct "sum(a) AS x" still passes
                    // this containment check; such mixed usage is not detected here.
                    if (!directSums.containsAll(agg.getAggregateFunctions())
                            || leftSumOutputs.isEmpty() || rightSumOutputs.isEmpty()) {
                        return null;
                    }

                    // Partial aggregates under each join child; the i-th pushed sum belongs to the i-th sum output.
                    List<Alias> leftSums = pushDownSums(leftSumOutputs, "left_sum");
                    List<Alias> rightSums = pushDownSums(rightSumOutputs, "right_sum");
                    Alias leftCnt = new Alias(new Count(), "left_cnt");
                    Alias rightCnt = new Alias(new Count(), "right_cnt");
                    Plan leftAgg = eagerAggregate(eagerGroupBy(agg, join, leftOutput), leftSums, leftCnt,
                            join.left());
                    Plan rightAgg = eagerAggregate(eagerGroupBy(agg, join, rightOutput), rightSums, rightCnt,
                            join.right());
                    Plan newJoin = join.withChildren(leftAgg, rightAgg);

                    // Pass 2: rewrite the top outputs in their original order. Non-sum outputs are kept as-is.
                    // Each sum output is scaled by the other side's count and keeps its ExprId and name.
                    List<NamedExpression> newOutputs = new ArrayList<>();
                    int leftIndex = 0;
                    int rightIndex = 0;
                    for (NamedExpression output : agg.getOutputExpressions()) {
                        Sum sum = directSum(output);
                        if (sum == null) {
                            newOutputs.add(output);
                        } else if (leftOutput.contains((Slot) sum.child())) {
                            newOutputs.add(rollUp((Alias) output, leftSums.get(leftIndex++), rightCnt));
                        } else {
                            newOutputs.add(rollUp((Alias) output, rightSums.get(rightIndex++), leftCnt));
                        }
                    }
                    return agg.withAggOutput(newOutputs).withChildren(newJoin);
                }).toRule(RuleType.EAGER_SPLIT);
    }

    /** Returns the {@link Sum} of a top-level {@code sum(...) AS name} output, or {@code null} for any other output. */
    private static Sum directSum(NamedExpression output) {
        if (output instanceof Alias && ((Alias) output).child() instanceof Sum) {
            return (Sum) ((Alias) output).child();
        }
        return null;
    }

    /**
     * Collects the grouping keys for the partial aggregate over one join side.
     *
     * <p>These are the slots of {@code sideOutput} referenced by the top aggregate's group-by expressions or by the
     * join's hash conjuncts. Keys are de-duplicated in first-seen order so the generated plan is deterministic.
     */
    private static List<Slot> eagerGroupBy(LogicalAggregate<?> agg, LogicalJoin<?, ?> join, List<Slot> sideOutput) {
        return Stream.concat(agg.getGroupByExpressions().stream(), join.getHashJoinConjuncts().stream())
                .flatMap(expression -> expression.getInputSlots().stream())
                .filter(sideOutput::contains)
                .distinct()
                .collect(Collectors.toList());
    }

    /**
     * Creates one child-side {@code sum(slot) AS <prefix><i>} per top-level sum output, index-aligned with
     * {@code sumOutputs}.
     */
    private static List<Alias> pushDownSums(List<Alias> sumOutputs, String prefix) {
        List<Alias> pushed = new ArrayList<>(sumOutputs.size());
        for (int i = 0; i < sumOutputs.size(); i++) {
            Sum original = (Sum) sumOutputs.get(i).child();
            pushed.add(new Alias(new Sum(original.child()), prefix + i));
        }
        return pushed;
    }

    /** Builds the partial aggregate over one join child, outputting its keys, partial sums and row count. */
    private static Plan eagerAggregate(List<Slot> groupBy, List<Alias> sums, Alias count, Plan child) {
        List<NamedExpression> output = ImmutableList.<NamedExpression>builder()
                .addAll(groupBy)
                .addAll(sums)
                .add(count)
                .build();
        return new LogicalAggregate<>(ImmutableList.copyOf(groupBy), output, child);
    }

    /**
     * Rewrites {@code sum(x) AS name} into {@code sum(partialSum * otherSideCount) AS name}, reusing the original
     * ExprId so that plan nodes above keep resolving it.
     */
    private static Alias rollUp(Alias original, Alias partialSum, Alias otherSideCount) {
        Sum rolledUp = new Sum(new Multiply(partialSum.toSlot(), otherSideCount.toSlot()));
        return new Alias(original.getExprId(), rolledUp, original.getName());
    }
}
