/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch.util;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Aggregated statistics over a list of {@link BlockingInputInfo}s.
 *
 * <p>Instances hold:
 *
 * <ul>
 *   <li>the per-subpartition byte counts summed across all inputs;
 *   <li>optionally, per-partition-index subpartition byte counts summed across all inputs;
 *   <li>a skew threshold and a target size derived from those sums.
 * </ul>
 *
 * <p>Instances are created through {@link #createAggregatedBlockingInputInfo}.
 */
public class AggregatedBlockingInputInfo {
    private static final Logger LOG = LoggerFactory.getLogger(AggregatedBlockingInputInfo.class);

    /** Result of {@code getMaxNumPartitions(inputInfos)} for the aggregated inputs. */
    private final int maxPartitionNum;

    /** A subpartition whose aggregated byte count is strictly above this value is "skewed". */
    private final long skewedThreshold;

    /** Target size computed by {@code computeTargetSize} from the aggregated bytes. */
    private final long targetSize;

    /** Whether the inputs were reported as intra-input key correlated. */
    private final boolean intraInputKeyCorrelated;

    /**
     * Partition index -> subpartition bytes summed across inputs.
     *
     * <p>Empty when the inputs are intra-input key correlated, or when they do not share the same
     * number of partitions.
     */
    private final Map<Integer, long[]> subpartitionBytesByPartition;

    /** Subpartition index -> bytes summed across all inputs. */
    private final long[] aggregatedSubpartitionBytes;

    private AggregatedBlockingInputInfo(
            long targetSize,
            long skewedThreshold,
            int maxPartitionNum,
            boolean intraInputKeyCorrelated,
            Map<Integer, long[]> subpartitionBytesByPartition,
            long[] aggregatedSubpartitionBytes) {
        this.maxPartitionNum = maxPartitionNum;
        this.skewedThreshold = skewedThreshold;
        this.targetSize = targetSize;
        this.intraInputKeyCorrelated = intraInputKeyCorrelated;
        this.subpartitionBytesByPartition =
                Preconditions.checkNotNull(subpartitionBytesByPartition);
        this.aggregatedSubpartitionBytes = Preconditions.checkNotNull(aggregatedSubpartitionBytes);
    }

    /** Returns the max partition number recorded at construction time. */
    public int getMaxPartitionNum() {
        return this.maxPartitionNum;
    }

    /** Returns the target size computed at construction time. */
    public long getTargetSize() {
        return this.targetSize;
    }

    /**
     * Returns an unmodifiable view of the per-partition subpartition bytes.
     *
     * <p>The map itself cannot be modified through this view. The {@code long[]} values, however,
     * are the internal arrays and are not copied, so callers must not modify them.
     */
    public Map<Integer, long[]> getSubpartitionBytesByPartition() {
        return Collections.unmodifiableMap(this.subpartitionBytesByPartition);
    }

    /**
     * Returns the bytes of the given subpartition, summed across all inputs.
     *
     * @throws ArrayIndexOutOfBoundsException if the index is outside {@code [0,
     *     getNumSubpartitions())}
     */
    public long getAggregatedSubpartitionBytes(int subpartitionIndex) {
        return this.aggregatedSubpartitionBytes[subpartitionIndex];
    }

    /**
     * Returns whether the aggregated input can be split by partition.
     *
     * <p>This is true only when the inputs are not intra-input key correlated and per-partition
     * byte statistics are available.
     */
    public boolean isSplittable() {
        return !this.intraInputKeyCorrelated && !this.subpartitionBytesByPartition.isEmpty();
    }

    /** Returns whether the subpartition's aggregated bytes exceed the skew threshold. */
    public boolean isSkewedSubpartition(int subpartitionIndex) {
        return this.aggregatedSubpartitionBytes[subpartitionIndex] > this.skewedThreshold;
    }

    /** Returns the number of subpartitions covered by the aggregated statistics. */
    public int getNumSubpartitions() {
        return this.aggregatedSubpartitionBytes.length;
    }

    /**
     * Sums {@link BlockingInputInfo#getAggregatedSubpartitionBytes()} element-wise across inputs.
     *
     * <p>Assumes no input reports more than {@code subpartitionNum} entries; presumably this is
     * guaranteed by {@code checkAndGetSubpartitionNum} (TODO confirm).
     */
    private static long[] computeAggregatedSubpartitionBytes(
            List<BlockingInputInfo> inputInfos, int subpartitionNum) {
        long[] aggregatedSubpartitionBytes = new long[subpartitionNum];
        for (BlockingInputInfo inputInfo : inputInfos) {
            List<Long> subpartitionBytes = inputInfo.getAggregatedSubpartitionBytes();
            for (int i = 0; i < subpartitionBytes.size(); ++i) {
                aggregatedSubpartitionBytes[i] += subpartitionBytes.get(i);
            }
        }
        return aggregatedSubpartitionBytes;
    }

    /**
     * Sums the per-partition subpartition bytes across inputs, keyed by partition index.
     *
     * <p>Returns an empty map, after logging a warning, when the inputs do not all have the same
     * number of partitions.
     */
    private static Map<Integer, long[]> computeSubpartitionBytesByPartitionIndex(
            List<BlockingInputInfo> inputInfos, int subpartitionNum) {
        if (!VertexParallelismAndInputInfosDeciderUtils.hasSameNumPartitions(inputInfos)) {
            LOG.warn(
                    "Input infos have different num partitions, skip calculate SubpartitionBytesByPartitionIndex");
            return Collections.emptyMap();
        }
        Map<Integer, long[]> subpartitionBytesByPartitionIndex = new HashMap<>();
        for (BlockingInputInfo inputInfo : inputInfos) {
            inputInfo
                    .getSubpartitionBytesByPartitionIndex()
                    .forEach(
                            (partitionIdx, subPartitionBytes) -> {
                                long[] subpartitionBytes =
                                        subpartitionBytesByPartitionIndex.computeIfAbsent(
                                                partitionIdx, ignored -> new long[subpartitionNum]);
                                for (int i = 0; i < subpartitionNum; ++i) {
                                    subpartitionBytes[i] += subPartitionBytes[i];
                                }
                            });
        }
        return subpartitionBytesByPartitionIndex;
    }

    /**
     * Builds the aggregated statistics for the given inputs.
     *
     * @param defaultSkewedThreshold passed to {@code computeSkewThreshold}
     * @param skewedFactor factor applied to the median of the aggregated subpartition bytes when
     *     computing the skew threshold
     * @param dataVolumePerTask passed to {@code computeTargetSize}
     * @param inputInfos the blocking inputs to aggregate
     * @return the aggregated info
     */
    public static AggregatedBlockingInputInfo createAggregatedBlockingInputInfo(
            long defaultSkewedThreshold,
            double skewedFactor,
            long dataVolumePerTask,
            List<BlockingInputInfo> inputInfos) {
        int subpartitionNum =
                VertexParallelismAndInputInfosDeciderUtils.checkAndGetSubpartitionNum(inputInfos);
        long[] aggregatedSubpartitionBytes =
                computeAggregatedSubpartitionBytes(inputInfos, subpartitionNum);
        long skewedThreshold =
                VertexParallelismAndInputInfosDeciderUtils.computeSkewThreshold(
                        VertexParallelismAndInputInfosDeciderUtils.median(
                                aggregatedSubpartitionBytes),
                        skewedFactor,
                        defaultSkewedThreshold);
        long targetSize =
                VertexParallelismAndInputInfosDeciderUtils.computeTargetSize(
                        aggregatedSubpartitionBytes, skewedThreshold, dataVolumePerTask);
        boolean isIntraInputKeyCorrelated =
                VertexParallelismAndInputInfosDeciderUtils.checkAndGetIntraCorrelation(inputInfos);

        Map<Integer, long[]> subpartitionBytesByPartitionIndex;
        if (isIntraInputKeyCorrelated) {
            // Per-partition statistics are not used when the inputs are intra-input key
            // correlated: isSplittable() is false in that case regardless of the map's content.
            subpartitionBytesByPartitionIndex = Collections.emptyMap();
        } else {
            subpartitionBytesByPartitionIndex =
                    computeSubpartitionBytesByPartitionIndex(inputInfos, subpartitionNum);
        }

        return new AggregatedBlockingInputInfo(
                targetSize,
                skewedThreshold,
                VertexParallelismAndInputInfosDeciderUtils.getMaxNumPartitions(inputInfos),
                isIntraInputKeyCorrelated,
                subpartitionBytesByPartitionIndex,
                aggregatedSubpartitionBytes);
    }
}

