/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlJsonArrayAggAggFunction;
import org.apache.calcite.sql.fun.SqlJsonObjectAggAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableWrapJsonAggFunctionArgumentsRule;
import org.immutables.value.Value;

@Internal
@Value.Enclosing
public class WrapJsonAggFunctionArgumentsRule
extends RelRule<Config> {
    public static final RelOptRule INSTANCE = new WrapJsonAggFunctionArgumentsRule(Config.DEFAULT);
    private static final RelHint MARKER_HINT = RelHint.builder("JSON_AGGREGATE_WRAPPED").build();

    public WrapJsonAggFunctionArgumentsRule(Config config) {
        super(config);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        LogicalAggregate aggregate = (LogicalAggregate)call.rel(0);
        RelNode aggInput = aggregate.getInput();
        RelBuilder relBuilder = call.builder().push(aggInput);
        LogicalAggregate wrappedAggregate = this.wrapJsonAggregate(aggregate, relBuilder);
        call.transformTo(wrappedAggregate.withHints(Collections.singletonList(MARKER_HINT)));
    }

    private LogicalAggregate wrapJsonAggregate(LogicalAggregate aggregate, RelBuilder relBuilder) {
        int inputCount = aggregate.getInput().getRowType().getFieldCount();
        ArrayList<AggregateCall> aggCallList = new ArrayList<AggregateCall>(aggregate.getAggCallList());
        HashMap<Integer, Integer> wrapIndicesMap = new HashMap<Integer, Integer>();
        for (int i = 0; i < aggCallList.size(); ++i) {
            int valueIndex;
            AggregateCall currentCall = (AggregateCall)aggCallList.get(i);
            if (currentCall.getAggregation() instanceof SqlJsonObjectAggAggFunction) {
                valueIndex = currentCall.getArgList().get(1);
                wrapIndicesMap.put(i, valueIndex);
                continue;
            }
            if (!(currentCall.getAggregation() instanceof SqlJsonArrayAggAggFunction)) continue;
            valueIndex = currentCall.getArgList().get(0);
            wrapIndicesMap.put(i, valueIndex);
        }
        HashMap<Integer, Integer> valueIndicesAfterProjection = new HashMap<Integer, Integer>();
        this.addProjections(aggregate.getCluster(), relBuilder, wrapIndicesMap.values().stream().distinct().sorted().collect(Collectors.toList()), inputCount, valueIndicesAfterProjection);
        ArrayList<AggregateCall> newWrappedArgCallList = new ArrayList<AggregateCall>(aggCallList);
        int newInputCount = inputCount + valueIndicesAfterProjection.size();
        for (Integer jsonAggCallIndex : wrapIndicesMap.keySet()) {
            Mapping argsMapping = Mappings.create(MappingType.BIJECTION, newInputCount, newInputCount);
            Integer valueIndex = (Integer)wrapIndicesMap.get(jsonAggCallIndex);
            argsMapping.set(valueIndex, (Integer)valueIndicesAfterProjection.get(valueIndex));
            AggregateCall newAggregateCall = ((AggregateCall)newWrappedArgCallList.get(jsonAggCallIndex)).transform(argsMapping);
            newWrappedArgCallList.set(jsonAggCallIndex, newAggregateCall);
        }
        return aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.getGroupSet(), aggregate.getGroupSets(), newWrappedArgCallList);
    }

    private void addProjections(RelOptCluster cluster, RelBuilder relBuilder, List<Integer> affectedArgs, int inputCount, Map<Integer, Integer> valueIndicesAfterProjection) {
        BridgingSqlFunction operandToStringOperator = BridgingSqlFunction.of(cluster, BuiltInFunctionDefinitions.JSON_STRING);
        ArrayList<RexNode> projects = new ArrayList<RexNode>();
        for (Integer argIdx : affectedArgs) {
            valueIndicesAfterProjection.put(argIdx, inputCount + projects.size());
            projects.add(relBuilder.call((SqlOperator)operandToStringOperator, relBuilder.field(argIdx)));
        }
        relBuilder.projectPlus(projects);
    }

    private static boolean isJsonAggregation(AggregateCall aggCall) {
        SqlAggFunction aggregation = aggCall.getAggregation();
        return aggregation instanceof SqlJsonObjectAggAggFunction || aggregation instanceof SqlJsonArrayAggAggFunction;
    }

    @Value.Immutable(singleton=false)
    public static interface Config
    extends RelRule.Config {
        public static final Config DEFAULT = ImmutableWrapJsonAggFunctionArgumentsRule.Config.builder().build().as(Config.class).onJsonAggregateFunctions();

        @Override
        default public RelOptRule toRule() {
            return new WrapJsonAggFunctionArgumentsRule(this);
        }

        default public Config onJsonAggregateFunctions() {
            Predicate<LogicalAggregate> jsonAggPredicate = aggregate -> aggregate.getAggCallList().stream().anyMatch(x$0 -> WrapJsonAggFunctionArgumentsRule.isJsonAggregation(x$0));
            RelRule.OperandTransform aggTransform = operandBuilder -> operandBuilder.operand(LogicalAggregate.class).predicate(jsonAggPredicate).anyInputs();
            return this.withOperandSupplier(aggTransform).as(Config.class);
        }
    }
}

