/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler;

import java.util.AbstractMap;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
import org.apache.flink.runtime.executiongraph.IntermediateResultPartitionTest;
import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobGraphTestUtils;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
import org.apache.flink.runtime.scheduler.SsgNetworkMemoryCalculationUtils;
import org.apache.flink.runtime.scheduler.VertexParallelismStore;
import org.apache.flink.runtime.scheduler.adaptivebatch.AdaptiveBatchScheduler;
import org.apache.flink.runtime.shuffle.PartitionDescriptor;
import org.apache.flink.runtime.shuffle.ProducerDescriptor;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
import org.apache.flink.runtime.shuffle.ShuffleMaster;
import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.runtime.util.JobVertexConnectionUtils;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/**
 * Tests for {@link SsgNetworkMemoryCalculationUtils}.
 *
 * <p>Covers two things:
 *
 * <ol>
 *   <li>The network memory that ends up in the resource profiles of slot sharing groups after an
 *       execution graph has been built.
 *   <li>The per-input channel information computed for a vertex of a dynamic graph.
 * </ol>
 *
 * <p>Memory sizes come from {@link TestShuffleMaster}. It turns a task's total number of input
 * channels and subpartitions into a byte count. The expected sizes below are therefore written as
 * {@code (channels, subpartitions)} pairs, summed over the tasks that share a slot sharing group.
 */
class SsgNetworkMemoryCalculationUtilsTest {

    @RegisterExtension
    private static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_EXTENSION =
            TestingUtils.defaultExecutorExtension();

    /** Shared, stateless shuffle master that derives memory sizes from task I/O counts. */
    private static final TestShuffleMaster SHUFFLE_MASTER = new TestShuffleMaster();

    /** A known (non-UNKNOWN) resource profile that network memory can be added to. */
    private static final ResourceProfile DEFAULT_RESOURCE = ResourceProfile.fromResources(1.0, 100);

    @Test
    void testGenerateEnrichedResourceProfile() throws Exception {
        // Group 0 holds source and map; group 1 holds sink.
        // With PIPELINED edges, the two map -> sink edges are separate data sets. That is why the
        // map's subpartitions (12 vs 6) and the sink's channels (10 vs 5) are doubled compared to
        // the BLOCKING case, where both edges share one data set id.
        testGenerateEnrichedResourceProfile(
                ResultPartitionType.PIPELINED,
                new MemorySize(
                        TestShuffleMaster.computeRequiredShuffleMemoryBytes(0, 2)
                                + TestShuffleMaster.computeRequiredShuffleMemoryBytes(1, 12)),
                shuffleMemory(10, 0));
        testGenerateEnrichedResourceProfile(
                ResultPartitionType.BLOCKING,
                new MemorySize(
                        TestShuffleMaster.computeRequiredShuffleMemoryBytes(0, 2)
                                + TestShuffleMaster.computeRequiredShuffleMemoryBytes(1, 6)),
                shuffleMemory(5, 0));
    }

    /**
     * Builds a static graph (parallelisms 4, 5, 6). Source and map share group 0 and sink uses
     * group 1. Then checks the network memory written into both groups.
     *
     * @param resultPartitionType partition type used for all edges of the job graph
     * @param group0MemorySize expected network memory of the source/map group
     * @param group1MemorySize expected network memory of the sink group
     */
    private void testGenerateEnrichedResourceProfile(
            ResultPartitionType resultPartitionType,
            MemorySize group0MemorySize,
            MemorySize group1MemorySize)
            throws Exception {
        SlotSharingGroup slotSharingGroup0 = new SlotSharingGroup();
        slotSharingGroup0.setResourceProfile(DEFAULT_RESOURCE);
        SlotSharingGroup slotSharingGroup1 = new SlotSharingGroup();
        slotSharingGroup1.setResourceProfile(DEFAULT_RESOURCE);

        createExecutionGraphAndEnrichNetworkMemory(
                Arrays.asList(slotSharingGroup0, slotSharingGroup0, slotSharingGroup1),
                resultPartitionType);

        Assertions.assertThat(slotSharingGroup0.getResourceProfile().getNetworkMemory())
                .isEqualTo(group0MemorySize);
        Assertions.assertThat(slotSharingGroup1.getResourceProfile().getNetworkMemory())
                .isEqualTo(group1MemorySize);
    }

    /** Slot sharing groups with an UNKNOWN resource profile must be left untouched. */
    @Test
    void testGenerateUnknownResourceProfile() throws Exception {
        SlotSharingGroup slotSharingGroup0 = new SlotSharingGroup();
        slotSharingGroup0.setResourceProfile(ResourceProfile.UNKNOWN);
        SlotSharingGroup slotSharingGroup1 = new SlotSharingGroup();
        slotSharingGroup1.setResourceProfile(ResourceProfile.UNKNOWN);

        createExecutionGraphAndEnrichNetworkMemory(
                Arrays.asList(slotSharingGroup0, slotSharingGroup0, slotSharingGroup1),
                ResultPartitionType.PIPELINED);

        Assertions.assertThat(slotSharingGroup0.getResourceProfile())
                .isEqualTo(ResourceProfile.UNKNOWN);
        Assertions.assertThat(slotSharingGroup1.getResourceProfile())
                .isEqualTo(ResourceProfile.UNKNOWN);
    }

    /**
     * Dynamic graph setup:
     *
     * <ul>
     *   <li>The source parallelism is fixed at 4.
     *   <li>Map and sink start undecided, with a default max parallelism of 20.
     *   <li>The vertices are initialized one at a time, in topological order, and each vertex's
     *       parallelism is set right before it is initialized.
     * </ul>
     *
     * <p>Every vertex is in its own slot sharing group.
     */
    @Test
    void testGenerateEnrichedResourceProfileForDynamicGraph() throws Exception {
        List<SlotSharingGroup> slotSharingGroups =
                Arrays.asList(new SlotSharingGroup(), new SlotSharingGroup(), new SlotSharingGroup());
        for (SlotSharingGroup group : slotSharingGroups) {
            group.setResourceProfile(DEFAULT_RESOURCE);
        }

        DefaultExecutionGraph executionGraph = createDynamicExecutionGraph(slotSharingGroups, 20);
        Iterator<ExecutionJobVertex> jobVertices =
                executionGraph.getVerticesTopologically().iterator();
        ExecutionJobVertex source = jobVertices.next();
        ExecutionJobVertex map = jobVertices.next();
        ExecutionJobVertex sink = jobVertices.next();

        executionGraph.initializeJobVertex(source, 0L);
        triggerComputeNumOfSubpartitions(source.getProducedDataSets()[0]);

        map.setParallelism(5);
        executionGraph.initializeJobVertex(map, 0L);
        triggerComputeNumOfSubpartitions(map.getProducedDataSets()[0]);

        sink.setParallelism(7);
        executionGraph.initializeJobVertex(sink, 0L);

        assertNetworkMemory(
                slotSharingGroups,
                Arrays.asList(shuffleMemory(0, 5), shuffleMemory(5, 20), shuffleMemory(15, 0)));
    }

    /**
     * Calls {@link IntermediateResultPartition#getNumberOfSubpartitions()} on every partition of
     * {@code result} and discards the values.
     *
     * <p>The calls are made only for their side effect. Presumably the subpartition count is
     * computed lazily and fixed on first access, before the consumer's parallelism is decided.
     */
    private void triggerComputeNumOfSubpartitions(IntermediateResult result) {
        for (IntermediateResultPartition partition : result.getPartitions()) {
            partition.getNumberOfSubpartitions();
        }
    }

    /**
     * Asserts that the i-th slot sharing group carries the i-th expected network memory size.
     *
     * @param slotSharingGroups groups to check
     * @param networkMemory expected network memory per group; must have the same size
     */
    private void assertNetworkMemory(
            List<SlotSharingGroup> slotSharingGroups, List<MemorySize> networkMemory) {
        Assertions.assertThat(networkMemory).hasSameSizeAs(slotSharingGroups);
        for (int i = 0; i < slotSharingGroups.size(); ++i) {
            Assertions.assertThat(slotSharingGroups.get(i).getResourceProfile().getNetworkMemory())
                    .isEqualTo(networkMemory.get(i));
        }
    }

    @Test
    void testGetMaxInputChannelNumForResultForAllToAll() throws Exception {
        testGetMaxInputChannelNumForResult(DistributionPattern.ALL_TO_ALL, 5, 20, 7, 15);
    }

    @Test
    void testGetMaxInputChannelNumForResultForPointWise() throws Exception {
        testGetMaxInputChannelNumForResult(DistributionPattern.POINTWISE, 5, 20, 3, 8);
        testGetMaxInputChannelNumForResult(DistributionPattern.POINTWISE, 5, 20, 5, 4);
        testGetMaxInputChannelNumForResult(DistributionPattern.POINTWISE, 5, 20, 7, 4);
    }

    /**
     * Builds a producer -> consumer dynamic graph and initializes the producer. It then forces the
     * producer's subpartition counts to be computed and initializes the consumer with
     * {@code decidedConsumerParallelism}. Finally it checks the channel information that {@link
     * SsgNetworkMemoryCalculationUtils#getMaxInputChannelInfoForDynamicGraph} reports for the
     * consumer's single input.
     *
     * @param distributionPattern pattern of the producer -> consumer edge
     * @param producerParallelism parallelism of the producer
     * @param consumerMaxParallelism max parallelism of the consumer
     * @param decidedConsumerParallelism parallelism set on the consumer before it is initialized
     * @param expectedNumChannels expected max number of input channels for the consumed result
     */
    private void testGetMaxInputChannelNumForResult(
            DistributionPattern distributionPattern,
            int producerParallelism,
            int consumerMaxParallelism,
            int decidedConsumerParallelism,
            int expectedNumChannels)
            throws Exception {
        // NOTE(review): -1 presumably leaves the consumer parallelism undecided; `true` presumably
        // requests a dynamic graph. Confirm against IntermediateResultPartitionTest.
        DefaultExecutionGraph eg =
                (DefaultExecutionGraph)
                        IntermediateResultPartitionTest.createExecutionGraph(
                                producerParallelism,
                                -1,
                                consumerMaxParallelism,
                                distributionPattern,
                                true,
                                EXECUTOR_EXTENSION.getExecutor());
        Iterator<ExecutionJobVertex> vertexIterator = eg.getVerticesTopologically().iterator();
        ExecutionJobVertex producer = vertexIterator.next();
        ExecutionJobVertex consumer = vertexIterator.next();

        eg.initializeJobVertex(producer, 0L);
        IntermediateResult result = producer.getProducedDataSets()[0];
        triggerComputeNumOfSubpartitions(result);

        consumer.setParallelism(decidedConsumerParallelism);
        eg.initializeJobVertex(consumer, 0L);

        Map<IntermediateDataSetID, Integer> maxInputChannelNums = new HashMap<>();
        Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes = new HashMap<>();
        SsgNetworkMemoryCalculationUtils.getMaxInputChannelInfoForDynamicGraph(
                consumer, maxInputChannelNums, inputPartitionTypes);

        Assertions.assertThat(maxInputChannelNums)
                .containsExactly(
                        new AbstractMap.SimpleEntry<IntermediateDataSetID, Integer>(
                                result.getId(), expectedNumChannels));
        Assertions.assertThat(inputPartitionTypes)
                .containsExactly(
                        new AbstractMap.SimpleEntry<IntermediateDataSetID, ResultPartitionType>(
                                result.getId(), result.getResultType()));
    }

    /**
     * Builds a dynamic execution graph for a BLOCKING {@code source -> map -> sink} job.
     *
     * <p>Only the source parallelism is set (4). The other vertices get their parallelism store
     * entries from {@link AdaptiveBatchScheduler#computeVertexParallelismStoreForDynamicGraph}
     * using {@code defaultMaxParallelism}.
     */
    private DefaultExecutionGraph createDynamicExecutionGraph(
            List<SlotSharingGroup> slotSharingGroups, int defaultMaxParallelism) throws Exception {
        JobGraph jobGraph =
                createJobGraph(
                        slotSharingGroups, Arrays.asList(4, -1, -1), ResultPartitionType.BLOCKING);
        VertexParallelismStore vertexParallelismStore =
                AdaptiveBatchScheduler.computeVertexParallelismStoreForDynamicGraph(
                        jobGraph.getVertices(), defaultMaxParallelism);
        return TestingDefaultExecutionGraphBuilder.newBuilder()
                .setJobGraph(jobGraph)
                .setVertexParallelismStore(vertexParallelismStore)
                .setShuffleMaster(SHUFFLE_MASTER)
                .buildDynamicGraph(EXECUTOR_EXTENSION.getExecutor());
    }

    /**
     * Builds (and discards) a static execution graph with parallelisms 4, 5 and 6.
     *
     * <p>Callers then inspect {@code slotSharingGroups}: the tests expect the groups' resource
     * profiles to have been enriched with network memory as a side effect of building the graph.
     */
    private void createExecutionGraphAndEnrichNetworkMemory(
            List<SlotSharingGroup> slotSharingGroups, ResultPartitionType resultPartitionType)
            throws Exception {
        TestingDefaultExecutionGraphBuilder.newBuilder()
                .setJobGraph(
                        createJobGraph(
                                slotSharingGroups, Arrays.asList(4, 5, 6), resultPartitionType))
                .setShuffleMaster(SHUFFLE_MASTER)
                .build(EXECUTOR_EXTENSION.getExecutor());
    }

    /**
     * Creates the three-vertex job graph {@code source -> map -> sink}.
     *
     * <p>The edges are wired as follows:
     *
     * <ul>
     *   <li>{@code source -> map} is a single POINTWISE edge.
     *   <li>{@code map -> sink} consists of two ALL_TO_ALL edges. For BLOCKING results both edges
     *       are connected with the same {@link IntermediateDataSetID}; otherwise each edge gets a
     *       new data set.
     * </ul>
     *
     * <p>Blocking (or blocking-persistent) result types yield a batch job graph; all others yield
     * a streaming one.
     *
     * @param slotSharingGroups exactly three groups, assigned to source, map and sink in order
     * @param parallelisms exactly three parallelisms for source, map and sink; non-positive values
     *     leave the vertex parallelism unset
     * @param resultPartitionType result partition type used for every edge
     */
    private static JobGraph createJobGraph(
            List<SlotSharingGroup> slotSharingGroups,
            List<Integer> parallelisms,
            ResultPartitionType resultPartitionType) {
        Assertions.assertThat(slotSharingGroups).hasSize(3);
        Assertions.assertThat(parallelisms).hasSize(3);

        JobVertex source = new JobVertex("source");
        source.setInvokableClass(NoOpInvokable.class);
        trySetParallelism(source, parallelisms.get(0));
        source.setSlotSharingGroup(slotSharingGroups.get(0));

        JobVertex map = new JobVertex("map");
        map.setInvokableClass(NoOpInvokable.class);
        trySetParallelism(map, parallelisms.get(1));
        map.setSlotSharingGroup(slotSharingGroups.get(1));

        JobVertex sink = new JobVertex("sink");
        sink.setInvokableClass(NoOpInvokable.class);
        trySetParallelism(sink, parallelisms.get(2));
        sink.setSlotSharingGroup(slotSharingGroups.get(2));

        JobVertexConnectionUtils.connectNewDataSetAsInput(
                map, source, DistributionPattern.POINTWISE, resultPartitionType);
        if (resultPartitionType == ResultPartitionType.BLOCKING) {
            // Both edges reuse one data set id. testGenerateEnrichedResourceProfile expects the
            // data set to be counted only once for this case.
            IntermediateDataSetID dataSetId = new IntermediateDataSetID();
            JobVertexConnectionUtils.connectNewDataSetAsInput(
                    sink, map, DistributionPattern.ALL_TO_ALL, resultPartitionType, dataSetId, false);
            JobVertexConnectionUtils.connectNewDataSetAsInput(
                    sink, map, DistributionPattern.ALL_TO_ALL, resultPartitionType, dataSetId, false);
        } else {
            JobVertexConnectionUtils.connectNewDataSetAsInput(
                    sink, map, DistributionPattern.ALL_TO_ALL, resultPartitionType);
            JobVertexConnectionUtils.connectNewDataSetAsInput(
                    sink, map, DistributionPattern.ALL_TO_ALL, resultPartitionType);
        }

        if (!resultPartitionType.isBlockingOrBlockingPersistentResultPartition()) {
            return JobGraphTestUtils.streamingJobGraph(source, map, sink);
        }
        return JobGraphTestUtils.batchJobGraph(source, map, sink);
    }

    /** Sets the parallelism of {@code jobVertex} only if {@code parallelism} is positive. */
    private static void trySetParallelism(JobVertex jobVertex, int parallelism) {
        if (parallelism > 0) {
            jobVertex.setParallelism(parallelism);
        }
    }

    /**
     * Network memory that {@link TestShuffleMaster} reports for a task with the given totals.
     *
     * @param numTotalChannels total number of input channels of the task
     * @param numTotalSubpartitions total number of subpartitions produced by the task
     */
    private static MemorySize shuffleMemory(int numTotalChannels, int numTotalSubpartitions) {
        return new MemorySize(
                TestShuffleMaster.computeRequiredShuffleMemoryBytes(
                        numTotalChannels, numTotalSubpartitions));
    }

    /**
     * Shuffle master stub whose only meaningful behavior is {@link
     * #computeShuffleMemorySizeForTask}.
     *
     * <p>That method sums all input channels and all subpartitions of a task. It then converts the
     * sums to bytes via {@link #computeRequiredShuffleMemoryBytes(int, int)}.
     */
    private static class TestShuffleMaster implements ShuffleMaster<ShuffleDescriptor> {

        private TestShuffleMaster() {}

        /**
         * Stub that returns {@code null}. It is presumably never called here, since these tests
         * never deploy tasks.
         */
        @Override
        public CompletableFuture<ShuffleDescriptor> registerPartitionWithProducer(
                JobID jobID,
                PartitionDescriptor partitionDescriptor,
                ProducerDescriptor producerDescriptor) {
            return null;
        }

        /** No-op. */
        @Override
        public void releasePartitionExternally(ShuffleDescriptor shuffleDescriptor) {}

        /** Encodes the task's total input channels and subpartitions as a memory size. */
        @Override
        public MemorySize computeShuffleMemorySizeForTask(TaskInputsOutputsDescriptor desc) {
            int numTotalChannels =
                    desc.getInputChannelNums().values().stream().mapToInt(Integer::intValue).sum();
            int numTotalSubpartitions =
                    desc.getSubpartitionNums().values().stream().mapToInt(Integer::intValue).sum();
            return new MemorySize(
                    computeRequiredShuffleMemoryBytes(numTotalChannels, numTotalSubpartitions));
        }

        /**
         * Encodes channel and subpartition counts into a byte count.
         *
         * <p>Each channel costs 10000 bytes and each subpartition 1 byte. This keeps both counts
         * readable in assertion failures while there are fewer than 10000 subpartitions. The
         * arithmetic is done in {@code long} so large counts cannot overflow.
         */
        static long computeRequiredShuffleMemoryBytes(
                int numTotalChannels, int numTotalSubpartitions) {
            return numTotalChannels * 10_000L + numTotalSubpartitions;
        }
    }
}

