/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.optrule;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.rules.AggregateMergeRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.kylin.guava30.shaded.common.collect.ImmutableList;
import org.apache.kylin.query.calcite.KylinRelDataTypeSystem;
import org.immutables.value.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Value.Enclosing
public class ExtendedAggregateMergeRule
extends AggregateMergeRule {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(ExtendedAggregateMergeRule.class);
    public static final ExtendedAggregateMergeRule INSTANCE = new ExtendedAggregateMergeRule(AggregateMergeRule.Config.DEFAULT);

    protected ExtendedAggregateMergeRule(AggregateMergeRule.Config config) {
        super(config);
    }

    public void onMatch(RelOptRuleCall call) {
        Aggregate topAgg = (Aggregate)call.rel(0);
        Aggregate bottomAgg = (Aggregate)call.rel(1);
        if (topAgg.getGroupCount() > bottomAgg.getGroupCount()) {
            return;
        }
        ImmutableBitSet bottomGroupSet = bottomAgg.getGroupSet();
        HashMap map = new HashMap();
        bottomGroupSet.forEach(v -> map.put(map.size(), v));
        Iterator iterator = topAgg.getGroupSet().iterator();
        while (iterator.hasNext()) {
            int k = (Integer)iterator.next();
            if (map.containsKey(k)) continue;
            return;
        }
        ImmutableBitSet topGroupSet = topAgg.getGroupSet().permute(map);
        if (!bottomGroupSet.contains(topGroupSet)) {
            return;
        }
        boolean hasEmptyGroup = topAgg.getGroupSets().stream().anyMatch(ImmutableBitSet::isEmpty);
        ArrayList<AggregateCall> finalCalls = new ArrayList<AggregateCall>();
        for (AggregateCall topCall : topAgg.getAggCallList()) {
            if (!ExtendedAggregateMergeRule.isAggregateSupported(topCall) || topCall.getArgList().size() == 0) {
                return;
            }
            int bottomIndex = (Integer)topCall.getArgList().get(0) - bottomGroupSet.cardinality();
            if (bottomIndex >= bottomAgg.getAggCallList().size() || bottomIndex < 0) {
                return;
            }
            AggregateCall bottomCall = (AggregateCall)bottomAgg.getAggCallList().get(bottomIndex);
            if (!ExtendedAggregateMergeRule.isAggregateSupported(bottomCall) || bottomCall.getAggregation() == SqlStdOperatorTable.COUNT && topCall.getAggregation().getKind() != SqlKind.SUM0 && hasEmptyGroup) {
                return;
            }
            AggregateCall finalCall = ExtendedAggregateMergeRule.mergeAggregateCalls(topCall, bottomCall);
            if (finalCall == null) {
                return;
            }
            finalCalls.add(finalCall);
        }
        ImmutableList newGroupingSets = null;
        if (topAgg.getGroupType() != Aggregate.Group.SIMPLE) {
            newGroupingSets = ImmutableBitSet.ORDERING.immutableSortedCopy(ImmutableBitSet.permute((Iterable)topAgg.getGroupSets(), map));
        }
        RelNode bottomAggInput = ExtendedAggregateMergeRule.replaceBottomAggInputIfNecessary(bottomAgg, call);
        Aggregate finalAgg = topAgg.copy(topAgg.getTraitSet(), bottomAggInput, topGroupSet, (List)newGroupingSets, finalCalls);
        call.transformTo((RelNode)finalAgg);
    }

    protected static AggregateCall sumSplitterSubstituteMerge(AggregateCall top, AggregateCall bottom) {
        SqlKind topKind = top.getAggregation().getKind();
        if (topKind == bottom.getAggregation().getKind() && (topKind == SqlKind.SUM || topKind == SqlKind.SUM0)) {
            RelDataType topType = top.getType();
            RelDataType bottomType = bottom.getType();
            RelDataType newAggCallType = topType.getSqlTypeName() == SqlTypeName.DECIMAL && bottomType.getSqlTypeName() == SqlTypeName.DECIMAL ? topType : bottomType;
            return AggregateCall.create((SqlAggFunction)bottom.getAggregation(), (boolean)bottom.isDistinct(), (boolean)bottom.isApproximate(), (boolean)false, (List)bottom.getArgList(), (int)bottom.filterArg, (ImmutableBitSet)bottom.distinctKeys, (RelCollation)bottom.getCollation(), (RelDataType)newAggCallType, (String)top.getName());
        }
        return null;
    }

    protected static RelNode replaceBottomAggInputIfNecessary(Aggregate bottomAgg, RelOptRuleCall call) {
        RelNode bottomAggInput;
        RelNode ret = bottomAggInput = bottomAgg.getInput();
        if (KylinRelDataTypeSystem.getProjectConfig().isImprovedSumDecimalPrecisionEnabled()) {
            RelSubset bottomAggInputSubSet;
            RelNode bestOrOriginal;
            RelBuilder relBuilder = call.builder();
            RexBuilder rexBuilder = call.builder().getRexBuilder();
            List sumDecimalAggCallArgs = bottomAgg.getAggCallList().stream().filter(aggCall -> aggCall.getAggregation().getKind() == SqlKind.SUM && SqlTypeUtil.isDecimal((RelDataType)aggCall.getType())).flatMap(aggCall -> aggCall.getArgList().stream()).collect(Collectors.toList());
            if (bottomAggInput instanceof RelSubset && (bestOrOriginal = (bottomAggInputSubSet = (RelSubset)bottomAggInput).getBestOrOriginal()) instanceof Project) {
                Project project = (Project)bestOrOriginal;
                relBuilder.push(project.getInput());
                ArrayList<RexNode> newProjects = new ArrayList<RexNode>();
                RelDataTypeFactory typeFactory = call.builder().getTypeFactory();
                RelDataTypeSystem typeSystem = typeFactory.getTypeSystem();
                for (int i = 0; i < project.getProjects().size(); ++i) {
                    RexNode rex = (RexNode)project.getProjects().get(i);
                    if (sumDecimalAggCallArgs.contains(i)) {
                        RelDataType newType = typeSystem.deriveSumType(typeFactory, rex.getType());
                        rex = rexBuilder.makeCast(newType, rex);
                    }
                    newProjects.add(rex);
                }
                ret = relBuilder.project(newProjects).build();
            }
        }
        return ret;
    }

    private static boolean isAggregateSupported(AggregateCall aggCall) {
        if (aggCall.isDistinct() || aggCall.hasFilter() || aggCall.isApproximate() || aggCall.getArgList().size() > 1) {
            return false;
        }
        return aggCall.getAggregation().maybeUnwrap(SqlSplittableAggFunction.class).isPresent();
    }

    private static AggregateCall mergeAggregateCalls(AggregateCall topCall, AggregateCall bottomCall) {
        SqlSplittableAggFunction splitter = (SqlSplittableAggFunction)bottomCall.getAggregation().unwrapOrThrow(SqlSplittableAggFunction.class);
        return splitter instanceof SqlSplittableAggFunction.SumSplitter && KylinRelDataTypeSystem.getProjectConfig().isImprovedSumDecimalPrecisionEnabled() ? ExtendedAggregateMergeRule.sumSplitterSubstituteMerge(topCall, bottomCall) : splitter.merge(topCall, bottomCall);
    }
}

