package org.apache.flink.runtime.scheduler.adaptivebatch;

import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.class */
/**
 * Tests for {@link AllToAllBlockingResultInfo}.
 *
 * <p>Every test builds a result info with the constructor arguments {@code (id, 2, 2, flag)} and
 * feeds it one {@link ResultPartitionBytes} per partition index. The expected values in the
 * assertions are plain sums of the recorded byte arrays; the arithmetic is spelled out next to
 * each assertion.
 */
class AllToAllBlockingResultInfoTest {

    /** Non-broadcast case: the expected total is 32 + 32 + 64 + 64 = 192. */
    @Test
    void testGetNumBytesProducedForNonBroadcast() {
        testGetNumBytesProduced(false, 192L);
    }

    /**
     * Broadcast case: the expected total is 32 + 64 = 96, which is consistent with only a single
     * subpartition per partition being counted (NOTE(review): confirm against the implementation).
     */
    @Test
    void testGetNumBytesProducedForBroadcast() {
        testGetNumBytesProduced(true, 96L);
    }

    /** Verifies the ranged overload of {@code getNumBytesProduced}. */
    @Test
    void testGetNumBytesProducedWithIndexRange() {
        AllToAllBlockingResultInfo resultInfo =
                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L}));

        IndexRange firstRange = new IndexRange(0, 1);
        IndexRange secondRange = new IndexRange(0, 0);

        // 160 = 32 + 128, i.e. the first entry of each of the two recorded arrays.
        Assertions.assertThat(resultInfo.getNumBytesProduced(firstRange, secondRange))
                .isEqualTo(160L);
    }

    /** Verifies that the aggregated bytes are the element-wise sum of the recorded arrays. */
    @Test
    void testGetAggregatedSubpartitionBytes() {
        AllToAllBlockingResultInfo resultInfo =
                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L}));

        // {32 + 128, 64 + 256}
        Assertions.assertThat(resultInfo.getAggregatedSubpartitionBytes())
                .containsExactly(160L, 320L);
    }

    /**
     * Verifies that both byte getters throw {@link IllegalStateException} when only one of the
     * two partitions has been recorded.
     */
    @Test
    void testGetBytesWithPartialPartitionInfos() {
        AllToAllBlockingResultInfo resultInfo =
                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));

        Assertions.assertThatThrownBy(resultInfo::getNumBytesProduced)
                .isInstanceOf(IllegalStateException.class);
        Assertions.assertThatThrownBy(resultInfo::getAggregatedSubpartitionBytes)
                .isInstanceOf(IllegalStateException.class);
    }

    /**
     * Verifies the interplay of {@code recordPartitionInfo} and {@code resetPartitionInfo} when a
     * partition is recorded more than once.
     */
    @Test
    void testRecordPartitionInfoMultiTimes() {
        AllToAllBlockingResultInfo resultInfo =
                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);

        ResultPartitionBytes partitionBytes1 = new ResultPartitionBytes(new long[] {32L, 64L});
        ResultPartitionBytes partitionBytes2 = new ResultPartitionBytes(new long[] {64L, 128L});
        ResultPartitionBytes partitionBytes3 = new ResultPartitionBytes(new long[] {128L, 256L});
        ResultPartitionBytes partitionBytes4 = new ResultPartitionBytes(new long[] {256L, 512L});

        // Record partitionBytes1 for partition 0, then reset it again.
        resultInfo.recordPartitionInfo(0, partitionBytes1);
        Assertions.assertThat(resultInfo.getNumOfRecordedPartitions()).isOne();
        resultInfo.resetPartitionInfo(0);
        Assertions.assertThat(resultInfo.getNumOfRecordedPartitions()).isZero();

        // Record partitionBytes2 for partition 0 and partitionBytes3 for partition 1.
        resultInfo.recordPartitionInfo(0, partitionBytes2);
        resultInfo.recordPartitionInfo(1, partitionBytes3);

        // Expected values are partitionBytes2 + partitionBytes3:
        // 576 = 64 + 128 + 128 + 256 and {64 + 128, 128 + 256} = {192, 384}.
        Assertions.assertThat(resultInfo.getNumBytesProduced()).isEqualTo(576L);
        Assertions.assertThat(resultInfo.getAggregatedSubpartitionBytes())
                .containsExactly(192L, 384L);
        // With both partitions recorded, the recorded-partition count is expected to be zero
        // (presumably the per-partition info is dropped once aggregated -- confirm against
        // AllToAllBlockingResultInfo).
        Assertions.assertThat(resultInfo.getNumOfRecordedPartitions()).isZero();

        // Reset partition 0 and record partitionBytes4 for it.
        resultInfo.resetPartitionInfo(0);
        resultInfo.recordPartitionInfo(0, partitionBytes4);

        // The reported values are expected to stay at partitionBytes2 + partitionBytes3.
        Assertions.assertThat(resultInfo.getNumBytesProduced()).isEqualTo(576L);
        Assertions.assertThat(resultInfo.getAggregatedSubpartitionBytes())
                .containsExactly(192L, 384L);
        Assertions.assertThat(resultInfo.getNumOfRecordedPartitions()).isZero();
    }

    /**
     * Records {@code {32, 32}} for partition 0 and {@code {64, 64}} for partition 1, then asserts
     * the value returned by the no-argument {@code getNumBytesProduced}.
     *
     * @param isBroadcast value passed as the fourth constructor argument of
     *     {@link AllToAllBlockingResultInfo}
     * @param expectedBytes the total the result info is expected to report
     */
    private void testGetNumBytesProduced(boolean isBroadcast, long expectedBytes) {
        AllToAllBlockingResultInfo resultInfo =
                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, isBroadcast);
        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 32L}));
        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {64L, 64L}));

        Assertions.assertThat(resultInfo.getNumBytesProduced()).isEqualTo(expectedBytes);
    }
}
