/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.TieredStorageTestUtils;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.SortBufferAccumulator;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManager;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManagerImpl;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemorySpec;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link SortBufferAccumulator}.
 *
 * <p>A fresh global {@link NetworkBufferPool} is created before and destroyed after every test.
 * Each test carves a dedicated {@link BufferPool} out of it via {@link
 * #createStorageMemoryManager(int)}, feeds records into a {@link SortBufferAccumulator}, and
 * inspects the buffers delivered to the callback registered with {@code setup}.
 */
class SortBufferAccumulatorTest {

    /** Number of buffers in the global {@link NetworkBufferPool}. */
    private static final int NUM_TOTAL_BUFFERS = 1000;

    /** Size in bytes of each network buffer; also the buffer size handed to the accumulator. */
    private static final int BUFFER_SIZE_BYTES = 1024;

    /** Ratio passed as the first argument of {@link TieredStorageMemoryManagerImpl}. */
    private static final float NUM_BUFFERS_TRIGGER_FLUSH_RATIO = 0.6f;

    private NetworkBufferPool globalPool;

    @BeforeEach
    void before() {
        globalPool = new NetworkBufferPool(NUM_TOTAL_BUFFERS, BUFFER_SIZE_BYTES);
    }

    @AfterEach
    void after() {
        globalPool.destroy();
    }

    @Test
    void testAccumulateRecordsAndGenerateBuffers() throws IOException {
        testAccumulateRecordsAndGenerateBuffers(
                true,
                Arrays.asList(
                        Buffer.DataType.DATA_BUFFER, Buffer.DataType.DATA_BUFFER_WITH_CLEAR_END));
    }

    @Test
    void testAccumulateRecordsAndGenerateBuffersWithPartialRecordUnallowed() throws IOException {
        testAccumulateRecordsAndGenerateBuffers(
                false, Collections.singletonList(Buffer.DataType.DATA_BUFFER_WITH_CLEAR_END));
    }

    /**
     * Feeds many random-sized records (randomly broadcast or not) into an accumulator and checks
     * both the data types and the number of buffers delivered to the flush callback.
     *
     * <p>The expected buffer count is derived from a simple model kept in lock-step with the
     * records. A new buffer is counted whenever the next record plus a 16-byte index entry would
     * overflow the current buffer, or whenever the broadcast flag changes from the previous
     * record. The random source is seeded, so the model's result is deterministic.
     *
     * @param isPartialRecordAllowed forwarded to the accumulator's constructor
     * @param expectedDataTypes every delivered buffer's data type must be one of these
     */
    private void testAccumulateRecordsAndGenerateBuffers(
            boolean isPartialRecordAllowed, Collection<Buffer.DataType> expectedDataTypes)
            throws IOException {
        int numBuffers = 10;
        int numRecords = 1000;
        // Per-record overhead used by the expectation model below.
        int indexEntrySize = 16;
        TieredStorageSubpartitionId subpartitionId = new TieredStorageSubpartitionId(0);
        Random random = new Random(1234L);
        TieredStorageMemoryManagerImpl memoryManager = createStorageMemoryManager(numBuffers);

        int numExpectBuffers = 0;
        int currentBufferWrittenBytes = 0;
        AtomicInteger numReceivedFinishedBuffer = new AtomicInteger(0);

        try (SortBufferAccumulator bufferAccumulator =
                new SortBufferAccumulator(
                        1, 2, BUFFER_SIZE_BYTES, 0L, memoryManager, isPartialRecordAllowed)) {
            bufferAccumulator.setup(
                    (subpartition, buffer, numRemainingBuffers) -> {
                        Assertions.assertThat(buffer.getDataType()).isIn(expectedDataTypes);
                        numReceivedFinishedBuffer.incrementAndGet();
                        buffer.recycleBuffer();
                    });

            boolean isBroadcastForPreviousRecord = false;
            for (int i = 0; i < numRecords; ++i) {
                // NOTE: the order of random draws below must not change, or the seeded
                // sequence (and therefore the expected counts) would differ.
                int numBytes = random.nextInt(BUFFER_SIZE_BYTES) + 1;
                ByteBuffer record = TieredStorageTestUtils.generateRandomData(numBytes, random);
                boolean isBroadcast = random.nextBoolean();
                bufferAccumulator.receive(
                        record, subpartitionId, Buffer.DataType.DATA_BUFFER, isBroadcast);

                boolean overflowsCurrentBuffer =
                        currentBufferWrittenBytes + numBytes + indexEntrySize > BUFFER_SIZE_BYTES;
                boolean broadcastFlagChanged =
                        i > 0 && isBroadcastForPreviousRecord != isBroadcast;
                if (overflowsCurrentBuffer || broadcastFlagChanged) {
                    ++numExpectBuffers;
                    currentBufferWrittenBytes = 0;
                }
                isBroadcastForPreviousRecord = isBroadcast;
                currentBufferWrittenBytes += numBytes + indexEntrySize;
            }
        }

        Assertions.assertThat(currentBufferWrittenBytes).isLessThan(BUFFER_SIZE_BYTES);
        // Bytes left over after the last boundary account for one more, final buffer.
        if (currentBufferWrittenBytes != 0) {
            ++numExpectBuffers;
        }
        Assertions.assertThat(numReceivedFinishedBuffer).hasValue(numExpectBuffers);
    }

    @Test
    void testWriteLargeRecord() throws IOException {
        testWriteLargeRecord(true);
    }

    @Test
    void testWriteLargeRecordWithPartialRecordUnallowed() throws IOException {
        testWriteLargeRecord(false);
    }

    /**
     * Writes a single record spanning exactly {@code numBuffers} buffers and checks that exactly
     * that many buffers reach the flush callback before the accumulator is closed.
     *
     * @param isPartialRecordAllowed forwarded to the accumulator's constructor
     */
    private void testWriteLargeRecord(boolean isPartialRecordAllowed) throws IOException {
        int numBuffers = 15;
        // Seeded so that a failure can be reproduced.
        Random random = new Random(1234L);
        TieredStorageMemoryManagerImpl memoryManager = createStorageMemoryManager(numBuffers);
        try (SortBufferAccumulator bufferAccumulator =
                new SortBufferAccumulator(
                        1, 2, BUFFER_SIZE_BYTES, 0L, memoryManager, isPartialRecordAllowed)) {
            AtomicInteger numReceivedBuffers = new AtomicInteger(0);
            bufferAccumulator.setup(
                    (subpartitionIndex, buffer, numRemainingBuffers) -> {
                        numReceivedBuffers.incrementAndGet();
                        buffer.recycleBuffer();
                    });
            ByteBuffer largeRecord =
                    TieredStorageTestUtils.generateRandomData(
                            BUFFER_SIZE_BYTES * numBuffers, random);
            bufferAccumulator.receive(
                    largeRecord,
                    new TieredStorageSubpartitionId(0),
                    Buffer.DataType.DATA_BUFFER,
                    false);
            Assertions.assertThat(numReceivedBuffers).hasValue(numBuffers);
        }
    }

    /**
     * Checks that {@code receive} throws {@link IllegalArgumentException} when the accumulator is
     * constructed with {@code 1} as its second argument (the other tests use {@code 2}).
     */
    @Test
    void testNoBuffersForSort() throws IOException {
        int numBuffers = 10;
        Random random = new Random(1111L);
        TieredStorageSubpartitionId subpartitionId = new TieredStorageSubpartitionId(0);
        TieredStorageMemoryManagerImpl memoryManager = createStorageMemoryManager(numBuffers);
        try (SortBufferAccumulator bufferAccumulator =
                new SortBufferAccumulator(1, 1, BUFFER_SIZE_BYTES, 0L, memoryManager, true)) {
            bufferAccumulator.setup((subpartitionIndex, buffers, numRemainingBuffers) -> {});
            Assertions.assertThatThrownBy(
                            () ->
                                    bufferAccumulator.receive(
                                            TieredStorageTestUtils.generateRandomData(1, random),
                                            subpartitionId,
                                            Buffer.DataType.DATA_BUFFER,
                                            false))
                    .isInstanceOf(IllegalArgumentException.class);
        }
    }

    /**
     * Checks that after receiving one small record the accumulator owns 2 requested buffers, and
     * that {@code close()} releases all of them. Try-with-resources is deliberately not used here
     * because {@code close()} is the call under test.
     */
    @Test
    void testCloseWithUnFinishedBuffers() throws IOException {
        int numBuffers = 10;
        TieredStorageMemoryManagerImpl tieredStorageMemoryManager =
                createStorageMemoryManager(numBuffers);
        SortBufferAccumulator bufferAccumulator =
                new SortBufferAccumulator(
                        1, 2, BUFFER_SIZE_BYTES, 0L, tieredStorageMemoryManager, true);
        bufferAccumulator.setup(
                (subpartition, buffer, numRemainingBuffers) -> buffer.recycleBuffer());
        bufferAccumulator.receive(
                TieredStorageTestUtils.generateRandomData(1, new Random(1234L)),
                new TieredStorageSubpartitionId(0),
                Buffer.DataType.DATA_BUFFER,
                false);
        Assertions.assertThat(tieredStorageMemoryManager.numOwnerRequestedBuffer(bufferAccumulator))
                .isEqualTo(2);
        bufferAccumulator.close();
        Assertions.assertThat(tieredStorageMemoryManager.numOwnerRequestedBuffer(bufferAccumulator))
                .isZero();
    }

    /**
     * Creates a memory manager backed by a fresh buffer pool from {@link #globalPool}.
     *
     * @param numBuffersInBufferPool value passed as all three size arguments of {@code
     *     createBufferPool}
     * @return a memory manager set up with a single {@link TieredStorageMemorySpec} whose owner
     *     is this test instance and whose second argument is {@code 1}
     * @throws IOException if the buffer pool cannot be created
     */
    private TieredStorageMemoryManagerImpl createStorageMemoryManager(int numBuffersInBufferPool)
            throws IOException {
        BufferPool bufferPool =
                globalPool.createBufferPool(
                        numBuffersInBufferPool, numBuffersInBufferPool, numBuffersInBufferPool);
        TieredStorageMemoryManagerImpl storageMemoryManager =
                new TieredStorageMemoryManagerImpl(NUM_BUFFERS_TRIGGER_FLUSH_RATIO, true);
        storageMemoryManager.setup(
                bufferPool, Collections.singletonList(new TieredStorageMemorySpec(this, 1)));
        return storageMemoryManager;
    }
}

