/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.ImmutableAggregateUnionTransposeRule;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlAnyValueAggFunction;
import org.apache.calcite.sql.fun.SqlBitOpAggFunction;
import org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList;
import org.apache.flink.calcite.shaded.org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;

/**
 * Planner rule that pushes an {@link Aggregate} past a non-distinct {@link Union}.
 *
 * <p>Each input of the {@code UNION ALL} is pre-aggregated on the original group set with the
 * original aggregate calls. A new aggregate on top of the new union then combines the partial
 * results.
 *
 * <p>The rule leaves the plan unchanged in any of these cases:
 * <ul>
 *   <li>the union is distinct;</li>
 *   <li>any aggregate call is DISTINCT;</li>
 *   <li>any aggregate call uses a function whose class is not in {@code SUPPORTED_AGGREGATES};</li>
 *   <li>every union input is already unique on the group columns.</li>
 * </ul>
 */
@Value.Enclosing
public class AggregateUnionTransposeRule
extends RelRule<AggregateUnionTransposeRule.Config>
implements TransformationRule {
    /**
     * Aggregate function classes that can be split into per-input partial aggregates plus a
     * combining aggregate. Used as an identity-keyed set; every value is {@code true}.
     */
    private static final IdentityHashMap<Class<? extends SqlAggFunction>, Boolean> SUPPORTED_AGGREGATES =
            new IdentityHashMap<>();

    static {
        SUPPORTED_AGGREGATES.put(SqlMinMaxAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlCountAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlSumAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlSumEmptyIsZeroAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlAnyValueAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlBitOpAggFunction.class, true);
    }

    /** Creates an AggregateUnionTransposeRule from the given configuration. */
    protected AggregateUnionTransposeRule(Config config) {
        super(config);
    }

    /**
     * Creates an AggregateUnionTransposeRule that matches the given aggregate and union classes.
     *
     * @deprecated Use {@link Config#DEFAULT}, {@link Config#withOperandFor} and
     *     {@link Config#toRule()} instead.
     */
    @Deprecated
    public AggregateUnionTransposeRule(Class<? extends Aggregate> aggregateClass, Class<? extends Union> unionClass, RelBuilderFactory relBuilderFactory) {
        this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(Config.class).withOperandFor(aggregateClass, unionClass));
    }

    /**
     * Creates an AggregateUnionTransposeRule whose RelBuilder uses the given factories.
     *
     * @deprecated Use {@link Config#DEFAULT}, {@link Config#withOperandFor} and
     *     {@link Config#toRule()} instead.
     */
    @Deprecated
    public AggregateUnionTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Union> unionClass, RelFactories.SetOpFactory setOpFactory) {
        this(aggregateClass, unionClass, RelBuilder.proto(aggregateFactory, setOpFactory));
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggRel = call.rel(0);
        final Union union = call.rel(1);

        if (!union.all) {
            // Only valid for UNION ALL. A distinct UNION removes duplicates across its inputs,
            // which per-input partial aggregates cannot reproduce. Example: t1(i) = {5, 5},
            // t2(i) = {5, 10}. SUM(i) over (t1 UNION t2) is 15, but the sum of the per-input
            // sums is 10 + 15 = 25.
            return;
        }

        final int groupCount = aggRel.getGroupSet().cardinality();

        // The copy (group set only, no grouping sets) is expected to have the same row type as
        // each per-input pre-aggregate built below: group columns first, then one column per
        // aggregate call. It serves as the input against which the combining calls are typed.
        final List<AggregateCall> transformedAggCalls =
                transformAggCalls(
                        aggRel.copy(aggRel.getTraitSet(), aggRel.getInput(), aggRel.getGroupSet(),
                                null, aggRel.getAggCallList()),
                        groupCount,
                        aggRel.getAggCallList());
        if (transformedAggCalls == null) {
            // Some call is DISTINCT or uses a function that cannot be split this way.
            return;
        }

        final RelMetadataQuery mq = call.getMetadataQuery();
        boolean hasUniqueKeyInAllInputs = true;
        for (RelNode input : union.getInputs()) {
            if (!RelMdUtil.areColumnsDefinitelyUnique(mq, input, aggRel.getGroupSet())) {
                hasUniqueKeyInAllInputs = false;
                break;
            }
        }
        if (hasUniqueKeyInAllInputs) {
            // Pre-aggregating inputs that are already unique on the group columns would not
            // reduce them. Presumably this also stops the rule from firing again on its own
            // output, whose union inputs are grouped on exactly these columns.
            return;
        }

        // Pre-aggregate every union input on the original group set with the original calls,
        // then union the partial results.
        final RelBuilder relBuilder = call.builder();
        for (RelNode input : union.getInputs()) {
            relBuilder.push(input);
            relBuilder.aggregate(relBuilder.groupKey(aggRel.getGroupSet()), aggRel.getAggCallList());
        }
        relBuilder.union(true, union.getInputs().size());

        // In every pre-aggregate the group columns come first. Remap the original group column
        // positions (relative to the old union's row type) onto 0 .. groupCount - 1.
        final ImmutableBitSet groupSet = aggRel.getGroupSet();
        final Mapping topGroupMapping =
                Mappings.create(MappingType.INVERSE_SURJECTION,
                        union.getRowType().getFieldCount(), aggRel.getGroupCount());
        for (int i = 0; i < groupSet.cardinality(); i++) {
            topGroupMapping.set(groupSet.nth(i), i);
        }
        final ImmutableBitSet topGroupSet = Mappings.apply(topGroupMapping, groupSet);
        final ImmutableList<ImmutableBitSet> topGroupSets =
                Mappings.apply2(topGroupMapping, aggRel.getGroupSets());

        // Combine the partial results while keeping the original grouping sets. The explicit
        // cast pins the groupKey(ImmutableBitSet, Iterable) overload.
        relBuilder.aggregate(
                relBuilder.groupKey(topGroupSet, (Iterable<? extends ImmutableBitSet>) topGroupSets),
                transformedAggCalls);
        call.transformTo(relBuilder.build());
    }

    /**
     * Builds the calls of the top-level aggregate that combines the per-input partial results.
     *
     * @param input relation whose row type is used to type the new calls; its columns are
     *     expected to be the group columns followed by one column per original call
     * @param groupCount number of group columns
     * @param origCalls aggregate calls of the original aggregate
     * @return the combining calls, in the same order as {@code origCalls}, or {@code null} if
     *     any call is DISTINCT or uses a function whose class is not in
     *     {@code SUPPORTED_AGGREGATES}
     */
    private static @Nullable List<AggregateCall> transformAggCalls(
            RelNode input, int groupCount, List<AggregateCall> origCalls) {
        final List<AggregateCall> newCalls = new ArrayList<>();
        for (Ord<AggregateCall> ord : Ord.zip(origCalls)) {
            final AggregateCall origCall = ord.e;
            if (origCall.isDistinct()
                    || !SUPPORTED_AGGREGATES.containsKey(origCall.getAggregation().getClass())) {
                return null;
            }
            final SqlAggFunction aggFun;
            final RelDataType aggType;
            if (origCall.getAggregation() == SqlStdOperatorTable.COUNT) {
                // Partial counts are combined by summing them with SUM0. Passing a null type lets
                // the result type be derived for SUM0 instead of reusing COUNT's type.
                aggFun = SqlStdOperatorTable.SUM0;
                aggType = null;
            } else {
                // Re-apply the same function to the partial results.
                aggFun = origCall.getAggregation();
                aggType = origCall.getType();
            }
            // The ord.i-th partial result sits right after the group columns.
            newCalls.add(
                    AggregateCall.create(aggFun, origCall.isDistinct(), origCall.isApproximate(),
                            origCall.ignoreNulls(), ImmutableList.of(groupCount + ord.i), -1,
                            origCall.distinctKeys, origCall.collation, groupCount, input, aggType,
                            origCall.getName()));
        }
        return newCalls;
    }

    /** Rule configuration. */
    @Value.Immutable
    public interface Config
    extends RelRule.Config {
        /** Default configuration: a {@link LogicalAggregate} on top of a {@link LogicalUnion}. */
        Config DEFAULT = ImmutableAggregateUnionTransposeRule.Config.of()
                .withOperandFor(LogicalAggregate.class, LogicalUnion.class);

        @Override
        default AggregateUnionTransposeRule toRule() {
            return new AggregateUnionTransposeRule(this);
        }

        /**
         * Returns a configuration whose operands match an aggregate of {@code aggregateClass}
         * whose single input is a union of {@code unionClass} (with any inputs).
         */
        default Config withOperandFor(Class<? extends Aggregate> aggregateClass, Class<? extends Union> unionClass) {
            return withOperandSupplier(b0 ->
                    b0.operand(aggregateClass).oneInput(b1 -> b1.operand(unionClass).anyInputs()))
                    .as(Config.class);
        }
    }
}

