/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.io.network.partition;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collection;
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.io.disk.BatchShuffleReadBufferPool;
import org.apache.flink.runtime.io.disk.FileChannelManager;
import org.apache.flink.runtime.io.disk.FileChannelManagerImpl;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.CompositeBuffer;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
import org.apache.flink.runtime.io.network.partition.BufferWithSubpartition;
import org.apache.flink.runtime.io.network.partition.DataBufferTest;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.ResultSubpartition;
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
import org.apache.flink.runtime.io.network.partition.SortMergeResultPartition;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
import org.apache.flink.testutils.junit.utils.TempDirUtils;
import org.apache.flink.util.Preconditions;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.io.TempDir;

/**
 * Tests for {@link SortMergeResultPartition}: writing and reading records, broadcast, large records,
 * release/close handling, byte counters and a consumption-order deadlock scenario.
 *
 * <p>Every test runs twice, once per value of {@link #useHashDataBuffer}. The flag is not passed to
 * the partition; it only selects the size of the per-partition buffer pool (100 vs. 15 buffers) and
 * the expected number of used buffers after writing. NOTE(review): presumably the partition picks
 * its data-buffer implementation from the pool size — confirm against SortMergeResultPartition.
 */
@ExtendWith(value = {ParameterizedTestExtension.class})
class SortMergeResultPartitionTest {
    /** Size in bytes of every network buffer and read buffer used by these tests. */
    private static final int bufferSize = 1024;

    /** Number of buffers in the shared {@link NetworkBufferPool}. */
    private static final int totalBuffers = 1000;

    /** Capacity in bytes (32 MiB) of the shared {@link BatchShuffleReadBufferPool}. */
    private static final int totalBytes = 32 * 1024 * 1024;

    /** Number of threads of the executor handed to partitions for reading. */
    private static final int numThreads = 4;

    @Parameter
    public boolean useHashDataBuffer;

    private final TestBufferAvailabilityListener listener = new TestBufferAvailabilityListener();

    private FileChannelManager fileChannelManager;

    private NetworkBufferPool globalPool;

    private BatchShuffleReadBufferPool readBufferPool;

    private ExecutorService readIOExecutor;

    @TempDir
    private Path tmpFolder;

    SortMergeResultPartitionTest() {
    }

    @BeforeEach
    void setUp() throws IOException {
        fileChannelManager =
                new FileChannelManagerImpl(
                        new String[] {TempDirUtils.newFolder(tmpFolder).toString()}, "testing");
        globalPool = new NetworkBufferPool(totalBuffers, bufferSize);
        readBufferPool = new BatchShuffleReadBufferPool(totalBytes, bufferSize);
        readIOExecutor = Executors.newFixedThreadPool(numThreads);
    }

    @AfterEach
    void shutdown() throws Exception {
        fileChannelManager.close();
        globalPool.destroy();
        readBufferPool.destroy();
        readIOExecutor.shutdown();
    }

    @Parameters(name = "useHashDataBuffer={0}")
    public static Collection<Boolean> parameters() {
        return Arrays.asList(false, true);
    }

    /**
     * Writes random records (a random mix of broadcast and unicast) plus the implicit
     * end-of-partition events, reads everything back and compares per-subpartition content and
     * byte counts.
     */
    @TestTemplate
    void testWriteAndRead() throws Exception {
        int numBuffers = useHashDataBuffer ? 100 : 15;
        int numSubpartitions = 10;
        int numRecords = 1000;
        Random random = new Random();

        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
        SortMergeResultPartition partition = createSortMergedPartition(numSubpartitions, bufferPool);

        // Generic arrays cannot be created directly; the unchecked conversion is safe because
        // every slot is filled with a matching ArrayDeque below.
        @SuppressWarnings({"unchecked", "rawtypes"})
        Queue<DataBufferTest.DataAndType>[] dataWritten = new Queue[numSubpartitions];
        @SuppressWarnings({"unchecked", "rawtypes"})
        Queue<Buffer>[] buffersRead = new Queue[numSubpartitions];
        for (int i = 0; i < numSubpartitions; ++i) {
            dataWritten[i] = new ArrayDeque<>();
            buffersRead[i] = new ArrayDeque<>();
        }
        // Java zero-initializes int arrays, so no explicit fill is needed.
        int[] numBytesWritten = new int[numSubpartitions];
        int[] numBytesRead = new int[numSubpartitions];

        for (int i = 0; i < numRecords; ++i) {
            ByteBuffer record = generateRandomData(random.nextInt(2 * bufferSize) + 1, random);
            if (random.nextBoolean()) {
                partition.broadcastRecord(record);
                for (int subpartition = 0; subpartition < numSubpartitions; ++subpartition) {
                    recordDataWritten(
                            record,
                            dataWritten,
                            subpartition,
                            numBytesWritten,
                            Buffer.DataType.DATA_BUFFER);
                }
            } else {
                int subpartition = random.nextInt(numSubpartitions);
                partition.emitRecord(record, subpartition);
                recordDataWritten(
                        record, dataWritten, subpartition, numBytesWritten, Buffer.DataType.DATA_BUFFER);
            }
        }
        partition.finish();
        partition.close();

        // finish() is expected to append one end-of-partition event to every subpartition.
        for (int subpartition = 0; subpartition < numSubpartitions; ++subpartition) {
            ByteBuffer event = EventSerializer.toSerializedEvent(EndOfPartitionEvent.INSTANCE);
            recordDataWritten(
                    event, dataWritten, subpartition, numBytesWritten, Buffer.DataType.END_OF_PARTITION);
        }

        ResultSubpartitionView[] views = createSubpartitionViews(partition, numSubpartitions);
        readData(
                views,
                bufferWithSubpartition -> {
                    Buffer buffer = bufferWithSubpartition.getBuffer();
                    int subpartition = bufferWithSubpartition.getSubpartitionIndex();
                    int numBytes = buffer.readableBytes();
                    numBytesRead[subpartition] += numBytes;

                    // The merged buffer is recycled at the end of this callback, so its bytes are
                    // copied into a separately allocated segment owned by the stored buffer.
                    MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(numBytes);
                    Buffer fullBuffer =
                            ((CompositeBuffer) buffer)
                                    .getFullBufferData(
                                            MemorySegmentFactory.allocateUnpooledSegment(numBytes));
                    segment.put(0, fullBuffer.getNioBufferReadable(), fullBuffer.readableBytes());
                    buffersRead[subpartition].add(
                            new NetworkBuffer(
                                    segment,
                                    ignored -> {},
                                    fullBuffer.getDataType(),
                                    fullBuffer.isCompressed(),
                                    fullBuffer.readableBytes()));
                    fullBuffer.recycleBuffer();
                });

        DataBufferTest.checkWriteReadResult(
                numSubpartitions, numBytesWritten, numBytesRead, dataWritten, buffersRead);
    }

    /**
     * Records {@code record} as expected output of {@code subpartition} and adds its size to that
     * subpartition's written-bytes counter. The buffer is rewound first because the same instance
     * may already have been consumed (by the partition or a previous call for a broadcast).
     */
    private void recordDataWritten(
            ByteBuffer record,
            Queue<DataBufferTest.DataAndType>[] dataWritten,
            int subpartition,
            int[] numBytesWritten,
            Buffer.DataType dataType) {
        record.rewind();
        dataWritten[subpartition].add(new DataBufferTest.DataAndType(record, dataType));
        numBytesWritten[subpartition] += record.remaining();
    }

    /** Returns a heap buffer wrapping {@code dataSize} random bytes drawn from {@code random}. */
    private ByteBuffer generateRandomData(int dataSize, Random random) {
        byte[] dataWritten = new byte[dataSize];
        random.nextBytes(dataWritten);
        return ByteBuffer.wrap(dataWritten);
    }

    /**
     * Drains all views until each has delivered a non-data buffer (the end-of-partition event),
     * passing every buffer to {@code bufferProcessor}. Each finished view is asserted to have no
     * further data and is then released.
     *
     * <p>Blocks on {@link TestBufferAvailabilityListener#waitForData()} between rounds and has no
     * timeout, so a missing notification makes the calling test hang.
     *
     * @return total readable bytes across all delivered buffers, including the events
     */
    private long readData(
            ResultSubpartitionView[] views, Consumer<BufferWithSubpartition> bufferProcessor)
            throws Exception {
        // long to match the return type; an int accumulator could overflow for large reads.
        long dataSize = 0L;
        int numEndOfPartitionEvents = 0;
        while (numEndOfPartitionEvents < views.length) {
            listener.waitForData();
            for (int subpartition = 0; subpartition < views.length; ++subpartition) {
                ResultSubpartitionView view = views[subpartition];
                ResultSubpartition.BufferAndBacklog bufferAndBacklog = view.getNextBuffer();
                while (bufferAndBacklog != null) {
                    Buffer buffer = bufferAndBacklog.buffer();
                    bufferProcessor.accept(new BufferWithSubpartition(buffer, subpartition));
                    dataSize += buffer.readableBytes();
                    if (!buffer.isBuffer()) {
                        ++numEndOfPartitionEvents;
                        Assertions.assertThat(view.getAvailabilityAndBacklog(true).isAvailable())
                                .isFalse();
                        view.releaseAllResources();
                    }
                    bufferAndBacklog = view.getNextBuffer();
                }
            }
        }
        return dataSize;
    }

    /** Creates one view per subpartition, all registered with the shared {@link #listener}. */
    private ResultSubpartitionView[] createSubpartitionViews(
            SortMergeResultPartition partition, int numSubpartitions) throws Exception {
        ResultSubpartitionView[] views = new ResultSubpartitionView[numSubpartitions];
        for (int subpartition = 0; subpartition < numSubpartitions; ++subpartition) {
            views[subpartition] =
                    partition.createSubpartitionView(
                            new ResultSubpartitionIndexSet(subpartition), listener);
        }
        return views;
    }

    /** Writes one record as large as the whole buffer pool and checks it reads back intact. */
    @TestTemplate
    void testWriteLargeRecord() throws Exception {
        int numBuffers = useHashDataBuffer ? 100 : 15;
        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
        SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);

        ByteBuffer recordWritten = generateRandomData(bufferSize * numBuffers, new Random());
        partition.emitRecord(recordWritten, 0);
        Assertions.assertThat(bufferPool.bestEffortGetNumOfUsedBuffers())
                .isEqualTo(useHashDataBuffer ? numBuffers : 0);

        partition.finish();
        partition.close();

        ResultSubpartitionView view =
                partition.createSubpartitionView(new ResultSubpartitionIndexSet(0), listener);
        ByteBuffer recordRead = ByteBuffer.allocate(bufferSize * numBuffers);
        readData(
                new ResultSubpartitionView[] {view},
                bufferWithSubpartition -> {
                    Buffer buffer = bufferWithSubpartition.getBuffer();
                    Buffer fullBuffer =
                            ((CompositeBuffer) buffer)
                                    .getFullBufferData(
                                            MemorySegmentFactory.allocateUnpooledSegment(
                                                    buffer.readableBytes()));
                    // Only data buffers belong to the record; the end-of-partition event is skipped.
                    if (fullBuffer.isBuffer()) {
                        recordRead.put(fullBuffer.getNioBufferReadable());
                    }
                    fullBuffer.recycleBuffer();
                });

        recordWritten.rewind();
        recordRead.flip();
        Assertions.assertThat(recordRead).isEqualTo(recordWritten);
    }

    /**
     * Broadcasts many records and checks the on-disk footprint and the total number of bytes read
     * back from all subpartitions.
     */
    @TestTemplate
    void testDataBroadcast() throws Exception {
        int numSubpartitions = 10;
        int numBuffers = useHashDataBuffer ? 100 : 15;
        int numRecords = 10000;
        Random random = new Random();

        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
        SortMergeResultPartition partition = createSortMergedPartition(numSubpartitions, bufferPool);
        for (int i = 0; i < numRecords; ++i) {
            partition.broadcastRecord(generateRandomData(bufferSize, random));
        }
        partition.finish();
        partition.close();

        int eventSize = EventSerializer.toSerializedEvent(EndOfPartitionEvent.INSTANCE).remaining();
        // Computed in long arithmetic so larger test parameters cannot silently overflow int.
        long recordBytesPerSubpartition = (long) numRecords * bufferSize;
        long logicalRecordBytes = numSubpartitions * recordBytesPerSubpartition;
        long dataSize = logicalRecordBytes + (long) numSubpartitions * eventSize;

        Assertions.assertThat(partition.getResultFile()).isNotNull();
        Assertions.assertThat(
                        Preconditions.checkNotNull(fileChannelManager.getPaths()[0].list()).length)
                .isEqualTo(2);
        // The data file must be smaller than one copy of every record per subpartition.
        // NOTE(review): presumably broadcast records are not duplicated on disk — confirm.
        for (File file : Preconditions.checkNotNull(fileChannelManager.getPaths()[0].listFiles())) {
            if (file.getName().endsWith(".shuffle.data")) {
                Assertions.assertThat(file.length()).isLessThan(logicalRecordBytes);
            }
        }

        ResultSubpartitionView[] views = createSubpartitionViews(partition, numSubpartitions);
        long dataRead =
                readData(
                        views,
                        bufferWithSubpartition ->
                                bufferWithSubpartition.getBuffer().recycleBuffer());
        Assertions.assertThat(dataRead).isEqualTo(dataSize);
    }

    /** Releasing an unfinished partition rejects further writes and deletes its files. */
    @TestTemplate
    void testReleaseWhileWriting() throws Exception {
        int numBuffers = useHashDataBuffer ? 100 : 15;
        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
        SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
        Assertions.assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);

        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 0);
        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 1);
        partition.emitRecord(ByteBuffer.allocate(bufferSize), 2);

        Assertions.assertThat(partition.getResultFile()).isNull();
        Assertions.assertThat(
                        Preconditions.checkNotNull(fileChannelManager.getPaths()[0].list()).length)
                .isEqualTo(2);

        partition.release();
        Assertions.assertThatThrownBy(
                        () -> partition.emitRecord(ByteBuffer.allocate(bufferSize * numBuffers), 2))
                .isInstanceOf(IllegalStateException.class);
        Assertions.assertThat(
                        Preconditions.checkNotNull(fileChannelManager.getPaths()[0].list()).length)
                .isEqualTo(0);
    }

    /**
     * Releasing a finished partition while a view is open eventually clears the result file and
     * deletes the partition's files.
     */
    @TestTemplate
    void testRelease() throws Exception {
        int numBuffers = useHashDataBuffer ? 100 : 15;
        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
        SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
        Assertions.assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);

        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 0);
        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 1);
        partition.finish();
        partition.close();

        Assertions.assertThat(partition.getResultFile().getNumRegions()).isEqualTo(3);
        Assertions.assertThat(
                        Preconditions.checkNotNull(fileChannelManager.getPaths()[0].list()).length)
                .isEqualTo(2);

        ResultSubpartitionView view =
                partition.createSubpartitionView(new ResultSubpartitionIndexSet(0), listener);
        partition.release();

        // Keep draining (and recycling) until the view is released or the file is gone.
        while (!view.isReleased() && partition.getResultFile() != null) {
            ResultSubpartition.BufferAndBacklog bufferAndBacklog = view.getNextBuffer();
            if (bufferAndBacklog != null) {
                bufferAndBacklog.buffer().recycleBuffer();
            }
        }
        // The file reference may be cleared asynchronously; poll without a timeout.
        while (partition.getResultFile() != null) {
            Thread.sleep(100L);
        }
        Assertions.assertThat(
                        Preconditions.checkNotNull(fileChannelManager.getPaths()[0].list()).length)
                .isEqualTo(0);
    }

    /** Closing the partition destroys its pool and returns every segment to the global pool. */
    @TestTemplate
    void testCloseReleasesAllBuffers() throws Exception {
        int numBuffers = useHashDataBuffer ? 100 : 15;
        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
        SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
        Assertions.assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);

        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 5);
        Assertions.assertThat(bufferPool.bestEffortGetNumOfUsedBuffers())
                .isEqualTo(useHashDataBuffer ? numBuffers : 0);

        partition.close();
        Assertions.assertThat(bufferPool.isDestroyed()).isTrue();
        Assertions.assertThat(globalPool.getNumberOfAvailableMemorySegments())
                .isEqualTo(totalBuffers);
    }

    /** Creating a view on a partition that has not been finished is rejected. */
    @TestTemplate
    void testReadUnfinishedPartition() throws Exception {
        BufferPool bufferPool = globalPool.createBufferPool(10, 10);
        SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
        Assertions.assertThatThrownBy(
                        () ->
                                partition.createSubpartitionView(
                                        new ResultSubpartitionIndexSet(0), listener))
                .isInstanceOf(IllegalStateException.class);
        bufferPool.lazyDestroy();
    }

    /** Creating a view on a released partition is rejected. */
    @TestTemplate
    void testReadReleasedPartition() throws Exception {
        BufferPool bufferPool = globalPool.createBufferPool(10, 10);
        SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
        partition.finish();
        partition.release();
        Assertions.assertThatThrownBy(
                        () ->
                                partition.createSubpartitionView(
                                        new ResultSubpartitionIndexSet(0), listener))
                .isInstanceOf(IllegalStateException.class);
        bufferPool.lazyDestroy();
    }

    @TestTemplate
    void testNumBytesProducedCounterForUnicast() throws IOException {
        testResultPartitionBytesCounter(false);
    }

    @TestTemplate
    void testNumBytesProducedCounterForBroadcast() throws IOException {
        testResultPartitionBytesCounter(true);
    }

    /** Setting up the partition requests the pool's minimum number of buffers up front. */
    @TestTemplate
    void testNetworkBufferReservation() throws IOException {
        int numBuffers = 10;
        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, 2 * numBuffers);
        SortMergeResultPartition partition = createSortMergedPartition(1, bufferPool);
        Assertions.assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);
        partition.finish();
        partition.close();
    }

    /**
     * Two consumers read the same finished partition and re-emit what they read into new
     * partitions whose buffer pools come from the same bounded {@link NetworkBufferPool}. The
     * latches force consumer 2 to create its pool after consumer 1 has read one buffer but
     * before it reads the rest. The test passes when both threads terminate; a deadlock makes
     * {@code join()} (which has no timeout) hang.
     */
    @TestTemplate
    void testNoDeadlockOnSpecificConsumptionOrder() throws Exception {
        int numNetworkBuffers = 8192;
        NetworkBufferPool networkBufferPool = new NetworkBufferPool(numNetworkBuffers, bufferSize);
        // 4 MiB, i.e. room for half of numNetworkBuffers buffers of bufferSize bytes.
        BatchShuffleReadBufferPool localReadBufferPool =
                new BatchShuffleReadBufferPool(0x400000L, bufferSize);
        try {
            BufferPool bufferPool =
                    networkBufferPool.createBufferPool(numNetworkBuffers, numNetworkBuffers);
            SortMergeResultPartition partition =
                    createSortMergedPartition(1, bufferPool, localReadBufferPool);
            for (int i = 0; i < numNetworkBuffers; ++i) {
                partition.emitRecord(ByteBuffer.allocate(bufferSize), 0);
            }
            partition.finish();
            partition.close();

            CountDownLatch condition1 = new CountDownLatch(1);
            CountDownLatch condition2 = new CountDownLatch(1);

            Runnable task1 =
                    () -> {
                        try {
                            ResultSubpartitionView view =
                                    partition.createSubpartitionView(
                                            new ResultSubpartitionIndexSet(0), listener);
                            BufferPool bufferPool1 =
                                    networkBufferPool.createBufferPool(
                                            numNetworkBuffers / 2, numNetworkBuffers);
                            SortMergeResultPartition partition1 =
                                    createSortMergedPartition(1, bufferPool1);
                            readAndEmitData(view, partition1);
                            condition1.countDown();
                            condition2.await();
                            readAndEmitAllData(view, partition1);
                        } catch (Exception ignored) {
                            // Intentionally ignored: this test only checks that both consumers
                            // terminate (no deadlock), not that they succeed.
                        }
                    };
            Runnable task2 =
                    () -> {
                        try {
                            condition1.await();
                            BufferPool bufferPool2 =
                                    networkBufferPool.createBufferPool(
                                            numNetworkBuffers / 2, numNetworkBuffers);
                            condition2.countDown();
                            SortMergeResultPartition partition2 =
                                    createSortMergedPartition(1, bufferPool2);
                            ResultSubpartitionView view =
                                    partition.createSubpartitionView(
                                            new ResultSubpartitionIndexSet(0), listener);
                            readAndEmitAllData(view, partition2);
                        } catch (Exception ignored) {
                            // Intentionally ignored: see task1.
                        }
                    };

            Thread consumer1 = new Thread(task1);
            consumer1.start();
            Thread consumer2 = new Thread(task2);
            consumer2.start();
            consumer1.join();
            consumer2.join();
        } finally {
            // These pools are local to this test, so @AfterEach does not clean them up.
            localReadBufferPool.destroy();
            networkBufferPool.destroy();
        }
    }

    /**
     * Busy-polls {@code view} (without waiting on the listener) for its next buffer, re-emits that
     * buffer's bytes as a record into subpartition 0 of {@code partition}, and recycles it. The
     * end-of-partition event is re-emitted as a record as well.
     *
     * @return {@code true} if the consumed buffer was a data buffer, {@code false} for an event
     */
    private boolean readAndEmitData(ResultSubpartitionView view, SortMergeResultPartition partition)
            throws Exception {
        MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(bufferSize);
        ResultSubpartition.BufferAndBacklog bufferAndBacklog;
        while ((bufferAndBacklog = view.getNextBuffer()) == null) {
            // spin until the next buffer becomes available
        }
        Buffer data = ((CompositeBuffer) bufferAndBacklog.buffer()).getFullBufferData(segment);
        partition.emitRecord(data.getNioBufferReadable(), 0);
        if (!data.isRecycled()) {
            data.recycleBuffer();
        }
        return bufferAndBacklog.buffer().isBuffer();
    }

    /** Repeats {@link #readAndEmitData} until an event is read, then finishes and closes. */
    private void readAndEmitAllData(ResultSubpartitionView view, SortMergeResultPartition partition)
            throws Exception {
        while (readAndEmitData(view, partition)) {
            // keep forwarding data buffers
        }
        partition.finish();
        partition.close();
    }

    /**
     * Checks the per-subpartition byte snapshot and the {@code numBytesOut} counter. The expected
     * values account for 4 extra bytes per written record on top of its payload.
     */
    private void testResultPartitionBytesCounter(boolean isBroadcast) throws IOException {
        int numBuffers = useHashDataBuffer ? 100 : 15;
        int numSubpartitions = 2;
        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
        SortMergeResultPartition partition = createSortMergedPartition(numSubpartitions, bufferPool);

        if (isBroadcast) {
            partition.broadcastRecord(ByteBuffer.allocate(bufferSize));
            partition.finish();

            long[] subpartitionBytes =
                    partition.resultPartitionBytes.createSnapshot().getSubpartitionBytes();
            Assertions.assertThat(subpartitionBytes)
                    .containsExactly(bufferSize + 4L, bufferSize + 4L);
            Assertions.assertThat(partition.numBytesOut.getCount())
                    .isEqualTo(numSubpartitions * (bufferSize + 4L));
        } else {
            partition.emitRecord(ByteBuffer.allocate(bufferSize), 0);
            partition.emitRecord(ByteBuffer.allocate(2 * bufferSize), 1);
            partition.finish();

            long[] subpartitionBytes =
                    partition.resultPartitionBytes.createSnapshot().getSubpartitionBytes();
            Assertions.assertThat(subpartitionBytes)
                    .containsExactly(bufferSize + 4L, 2L * bufferSize + 4L);
            Assertions.assertThat(partition.numBytesOut.getCount())
                    .isEqualTo(3L * bufferSize + numSubpartitions * 4L);
        }
    }

    /** Creates and sets up a partition that reads through the shared {@link #readBufferPool}. */
    private SortMergeResultPartition createSortMergedPartition(
            int numSubpartitions, BufferPool bufferPool) throws IOException {
        return createSortMergedPartition(numSubpartitions, bufferPool, readBufferPool);
    }

    /** Creates and sets up a BLOCKING sort-merge partition writing into a fresh channel file. */
    private SortMergeResultPartition createSortMergedPartition(
            int numSubpartitions, BufferPool bufferPool, BatchShuffleReadBufferPool readBufferPool)
            throws IOException {
        SortMergeResultPartition sortMergedResultPartition =
                new SortMergeResultPartition(
                        "SortMergedResultPartitionTest",
                        0,
                        new ResultPartitionID(),
                        ResultPartitionType.BLOCKING,
                        numSubpartitions,
                        numSubpartitions,
                        readBufferPool,
                        readIOExecutor,
                        new ResultPartitionManager(),
                        fileChannelManager.createChannel().getPath(),
                        null,
                        () -> bufferPool);
        sortMergedResultPartition.setup();
        return sortMergedResultPartition;
    }

    /**
     * Listener that counts availability notifications; {@link #waitForData()} blocks until at
     * least one notification arrived since the previous call, then resets the count.
     */
    private static final class TestBufferAvailabilityListener
            implements BufferAvailabilityListener {
        private int numNotifications;

        private TestBufferAvailabilityListener() {
        }

        public synchronized void notifyDataAvailable(ResultSubpartitionView view) {
            if (numNotifications == 0) {
                notifyAll();
            }
            ++numNotifications;
        }

        /** Blocks (without timeout) until a notification is pending, then clears the count. */
        public synchronized void waitForData() throws InterruptedException {
            // Loop rather than a single check: Object.wait() may return spuriously.
            while (numNotifications == 0) {
                wait();
            }
            numNotifications = 0;
        }
    }
}

