/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.executiongraph;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.JobException;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.executiongraph.IntermediateResultInfo;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.util.Preconditions;

/**
 * Utilities for computing, for every input of a job vertex, which partitions and which
 * subpartitions of the upstream intermediate result each consumer subtask reads.
 *
 * <p>The result for one input is a {@link JobVertexInputInfo} holding one {@link
 * ExecutionVertexInputInfo} (consumer subtask index, partition range, subpartition range) per
 * consumer subtask.
 */
public class VertexInputInfoComputationUtils {

    /**
     * Computes the input infos of all inputs of the given job vertex.
     *
     * @param ejv the consuming job vertex; its parallelism must already be decided
     * @param intermediateResultRetriever looks up the intermediate result for an input edge's
     *     source id; a {@code null} return means no such result exists
     * @return the input infos keyed by intermediate data set id, in input order
     * @throws JobException if the intermediate result of any input cannot be found
     * @throws IllegalStateException if the parallelism of {@code ejv} is not decided yet
     */
    public static Map<IntermediateDataSetID, JobVertexInputInfo> computeVertexInputInfos(
            ExecutionJobVertex ejv,
            Function<IntermediateDataSetID, IntermediateResult> intermediateResultRetriever)
            throws JobException {
        Preconditions.checkState(
                ejv.isParallelismDecided(),
                "The parallelism of the job vertex must be decided before computing its input infos.");

        final List<IntermediateResultWrapper> intermediateResultInfos = new ArrayList<>();
        for (JobEdge edge : ejv.getJobVertex().getInputs()) {
            final IntermediateResult ires = intermediateResultRetriever.apply(edge.getSourceId());
            if (ires == null) {
                throw new JobException(
                        "Cannot connect this job graph to the previous graph. No previous intermediate result found for ID "
                                + edge.getSourceId());
            }
            intermediateResultInfos.add(new IntermediateResultWrapper(ires));
        }
        return computeVertexInputInfos(
                ejv.getParallelism(), intermediateResultInfos, ejv.getGraph().isDynamic());
    }

    /**
     * Computes the input infos of a consumer with the given parallelism for the given inputs.
     *
     * <p>Pointwise inputs are handled by {@link #computeVertexInputInfoForPointwise}; all other
     * inputs by {@link #computeVertexInputInfoForAllToAll}. If several inputs share the same
     * result id, only the info computed for the first of them is kept.
     *
     * @param parallelism the parallelism of the consumer; must be positive
     * @param inputs the consumed intermediate results
     * @param isDynamicGraph whether the execution graph is dynamic
     * @return the input infos keyed by result id, in order of first occurrence
     * @throws IllegalArgumentException if {@code parallelism} is not positive
     */
    public static Map<IntermediateDataSetID, JobVertexInputInfo> computeVertexInputInfos(
            int parallelism,
            List<? extends IntermediateResultInfo> inputs,
            boolean isDynamicGraph) {
        Preconditions.checkArgument(
                parallelism > 0, "The parallelism must be positive, but was %s.", parallelism);

        final Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos =
                new LinkedHashMap<>();
        for (IntermediateResultInfo input : inputs) {
            final int sourceParallelism = input.getNumPartitions();
            final JobVertexInputInfo inputInfo;
            if (input.isPointwise()) {
                inputInfo =
                        computeVertexInputInfoForPointwise(
                                sourceParallelism,
                                parallelism,
                                input::getNumSubpartitions,
                                isDynamicGraph);
            } else {
                inputInfo =
                        computeVertexInputInfoForAllToAll(
                                sourceParallelism,
                                parallelism,
                                input::getNumSubpartitions,
                                isDynamicGraph,
                                input.isBroadcast(),
                                input.isSingleSubpartitionContainsAllData());
            }
            // A result consumed more than once keeps the info of its first occurrence.
            jobVertexInputInfos.putIfAbsent(input.getResultId(), inputInfo);
        }
        return jobVertexInputInfos;
    }

    /**
     * Computes the input info of a pointwise-connected input.
     *
     * <p>If there are at least as many partitions as consumers, each consumer reads a contiguous,
     * evenly sized range of partitions. Otherwise each partition is read by a contiguous group of
     * consumers, and the consumers of one partition split its subpartitions among themselves.
     *
     * @param sourceCount the number of upstream partitions
     * @param targetCount the number of consumer subtasks
     * @param numOfSubpartitionsRetriever returns the number of subpartitions of a partition index
     *     (only invoked for dynamic graphs)
     * @param isDynamicGraph whether the execution graph is dynamic
     * @return one {@link ExecutionVertexInputInfo} per consumer subtask
     */
    public static JobVertexInputInfo computeVertexInputInfoForPointwise(
            int sourceCount,
            int targetCount,
            Function<Integer, Integer> numOfSubpartitionsRetriever,
            boolean isDynamicGraph) {
        final List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();

        if (sourceCount >= targetCount) {
            // Consumer 'index' reads the contiguous partition range [start, end).
            for (int index = 0; index < targetCount; index++) {
                final int start = evenSplitPoint(index, sourceCount, targetCount);
                final int end = evenSplitPoint(index + 1, sourceCount, targetCount);

                final IndexRange partitionRange = new IndexRange(start, end - 1);
                // NOTE(review): the subpartition count is taken from the first partition of the
                // range only; presumably all partitions of a result have the same count.
                final IndexRange subpartitionRange =
                        computeConsumedSubpartitionRange(
                                index,
                                1,
                                () -> numOfSubpartitionsRetriever.apply(start),
                                isDynamicGraph,
                                false,
                                false);
                executionVertexInputInfos.add(
                        new ExecutionVertexInputInfo(index, partitionRange, subpartitionRange));
            }
        } else {
            // Partition 'partitionNum' is read by the contiguous consumer range [start, end).
            for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {
                final int start = ceilSplitPoint(partitionNum, targetCount, sourceCount);
                final int end = ceilSplitPoint(partitionNum + 1, targetCount, sourceCount);
                final int numConsumers = end - start;

                final IndexRange partitionRange = new IndexRange(partitionNum, partitionNum);
                // Lambdas may only capture effectively final locals.
                final int finalPartitionNum = partitionNum;
                for (int i = start; i < end; i++) {
                    final IndexRange subpartitionRange =
                            computeConsumedSubpartitionRange(
                                    i,
                                    numConsumers,
                                    () -> numOfSubpartitionsRetriever.apply(finalPartitionNum),
                                    isDynamicGraph,
                                    false,
                                    false);
                    executionVertexInputInfos.add(
                            new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
                }
            }
        }
        return new JobVertexInputInfo(executionVertexInputInfos);
    }

    /**
     * Computes the input info of an all-to-all-connected input: every consumer reads all
     * partitions, and the subpartitions are assigned by {@link #computeConsumedSubpartitionRange}.
     *
     * @param sourceCount the number of upstream partitions
     * @param targetCount the number of consumer subtasks
     * @param numOfSubpartitionsRetriever returns the number of subpartitions of a partition index;
     *     only partition 0 is queried, and only for dynamic graphs
     * @param isDynamicGraph whether the execution graph is dynamic
     * @param isBroadcast whether the input is a broadcast input
     * @param isSingleSubpartitionContainsAllData whether a single subpartition holds all data
     *     of a broadcast partition
     * @return one {@link ExecutionVertexInputInfo} per consumer subtask
     */
    public static JobVertexInputInfo computeVertexInputInfoForAllToAll(
            int sourceCount,
            int targetCount,
            Function<Integer, Integer> numOfSubpartitionsRetriever,
            boolean isDynamicGraph,
            boolean isBroadcast,
            boolean isSingleSubpartitionContainsAllData) {
        final List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
        final IndexRange partitionRange = new IndexRange(0, sourceCount - 1);
        for (int i = 0; i < targetCount; i++) {
            final IndexRange subpartitionRange =
                    computeConsumedSubpartitionRange(
                            i,
                            targetCount,
                            () -> numOfSubpartitionsRetriever.apply(0),
                            isDynamicGraph,
                            isBroadcast,
                            isSingleSubpartitionContainsAllData);
            executionVertexInputInfos.add(
                    new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
        }
        return new JobVertexInputInfo(executionVertexInputInfos);
    }

    /**
     * Computes the range of subpartitions that one consumer reads from a partition.
     *
     * <p>For a non-dynamic graph the consumer reads the single subpartition whose index equals
     * its position ({@code consumerSubtaskIndex % numConsumers}) among the consumers. For a
     * dynamic graph, a broadcast input is read as {@code [0, 0]} when a single subpartition
     * contains all data (which requires exactly one subpartition), or as all subpartitions
     * otherwise. A non-broadcast input's subpartitions are split evenly among the consumers.
     *
     * @param consumerSubtaskIndex the subtask index of the consumer
     * @param numConsumers the number of consumers sharing the partition; must be positive
     * @param numOfSubpartitionsSupplier supplies the number of subpartitions (only invoked for
     *     dynamic graphs)
     * @param isDynamicGraph whether the execution graph is dynamic
     * @param isBroadcast whether the input is a broadcast input
     * @param isSingleSubpartitionContainsAllData whether a single subpartition holds all data
     *     of a broadcast partition
     * @return the inclusive subpartition index range read by the consumer
     * @throws IllegalArgumentException if {@code numConsumers} is not positive, if a
     *     single-subpartition broadcast partition does not have exactly one subpartition, or if
     *     there are more consumers than subpartitions
     */
    @VisibleForTesting
    static IndexRange computeConsumedSubpartitionRange(
            int consumerSubtaskIndex,
            int numConsumers,
            Supplier<Integer> numOfSubpartitionsSupplier,
            boolean isDynamicGraph,
            boolean isBroadcast,
            boolean isSingleSubpartitionContainsAllData) {
        Preconditions.checkArgument(
                numConsumers > 0,
                "The number of consumers must be positive, but was %s.",
                numConsumers);
        // With numConsumers > 0 this is always strictly smaller than numConsumers.
        final int consumerIndex = consumerSubtaskIndex % numConsumers;
        if (!isDynamicGraph) {
            return new IndexRange(consumerIndex, consumerIndex);
        }

        final int numSubpartitions = numOfSubpartitionsSupplier.get();
        if (isBroadcast) {
            if (isSingleSubpartitionContainsAllData) {
                Preconditions.checkArgument(
                        numSubpartitions == 1,
                        "Expected exactly one subpartition, but there are %s.",
                        numSubpartitions);
                return new IndexRange(0, 0);
            }
            return new IndexRange(0, numSubpartitions - 1);
        }

        Preconditions.checkArgument(
                numConsumers <= numSubpartitions,
                "The number of consumers (%s) must not exceed the number of subpartitions (%s).",
                numConsumers,
                numSubpartitions);
        final int start = evenSplitPoint(consumerIndex, numSubpartitions, numConsumers);
        final int nextStart = evenSplitPoint(consumerIndex + 1, numSubpartitions, numConsumers);
        return new IndexRange(start, nextStart - 1);
    }

    /** Returns {@code index * total / parts} (truncated), computed in long to avoid overflow. */
    private static int evenSplitPoint(int index, int total, int parts) {
        return (int) ((long) index * total / parts);
    }

    /** Returns {@code ceil(index * total / parts)} for non-negative operands, computed in long. */
    private static int ceilSplitPoint(int index, int total, int parts) {
        return (int) (((long) index * total + parts - 1) / parts);
    }

    private VertexInputInfoComputationUtils() {}

    /** Adapts an {@link IntermediateResult} to the {@link IntermediateResultInfo} interface. */
    private static final class IntermediateResultWrapper implements IntermediateResultInfo {
        private final IntermediateResult intermediateResult;

        IntermediateResultWrapper(IntermediateResult intermediateResult) {
            this.intermediateResult = Preconditions.checkNotNull(intermediateResult);
        }

        @Override
        public IntermediateDataSetID getResultId() {
            return intermediateResult.getId();
        }

        @Override
        public boolean isBroadcast() {
            return intermediateResult.isBroadcast();
        }

        @Override
        public boolean isSingleSubpartitionContainsAllData() {
            return intermediateResult.isSingleSubpartitionContainsAllData();
        }

        @Override
        public boolean isPointwise() {
            return intermediateResult.getConsumingDistributionPattern()
                    == DistributionPattern.POINTWISE;
        }

        /** Returns the number of partitions assigned to the wrapped result. */
        @Override
        public int getNumPartitions() {
            return intermediateResult.getNumberOfAssignedPartitions();
        }

        /**
         * Returns the number of subpartitions of the given partition.
         *
         * @throws IllegalStateException if the producer's graph is not dynamic
         */
        @Override
        public int getNumSubpartitions(int partitionIndex) {
            Preconditions.checkState(
                    intermediateResult.getProducer().getGraph().isDynamic(),
                    "This method should only be called for dynamic graph.");
            return intermediateResult.getPartitions()[partitionIndex].getNumberOfSubpartitions();
        }
    }
}

