/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.state;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.api.common.JobID;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.core.memory.MemorySegmentProvider;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.CheckpointType;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.SnapshotType;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriterImpl;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.checkpoint.channel.SequentialChannelStateReaderImpl;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
import org.apache.flink.runtime.io.network.partition.BufferWritingResultPartition;
import org.apache.flink.runtime.io.network.partition.NoOpBufferAvailablityListener;
import org.apache.flink.runtime.io.network.partition.ResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionBuilder;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.ResultSubpartition;
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.CheckpointStorage;
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.function.SupplierWithException;
import org.junit.Assert;
import org.junit.Test;

/**
 * Round-trip test for channel state: in-flight input-channel and result-subpartition data is
 * written through {@link ChannelStateWriterImpl}, read back through {@link
 * SequentialChannelStateReaderImpl}, and compared byte-for-byte with what was written. It also
 * checks what a result partition emits after recovering from an empty state snapshot.
 */
public class ChannelPersistenceITCase {
    /** Payload generator; seeded from the wall clock, so payloads differ between runs. */
    private static final Random RANDOM = new Random(System.currentTimeMillis());
    private static final JobID JOB_ID = new JobID();
    private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
    private static final int SUBTASK_INDEX = 0;

    /**
     * Starting sequence number passed to the writer for every chunk of data.
     *
     * <p>NOTE(review): presumably mirrors an "unknown sequence number" sentinel in {@link
     * ChannelStateWriter} - confirm.
     */
    private static final int SEQUENCE_NUMBER = -2;

    @Test
    public void testUpstreamBlocksAfterRecoveringState() throws Exception {
        upstreamBlocksAfterRecoveringState(ResultPartitionType.PIPELINED);
    }

    @Test
    public void testNotBlocksAfterRecoveringStateForApproximateLocalRecovery() throws Exception {
        upstreamBlocksAfterRecoveringState(ResultPartitionType.PIPELINED_APPROXIMATE);
    }

    /**
     * Writes one input channel, one result subpartition, and one future-completed result
     * subpartition. It then reads all three back and checks that the bytes are unchanged.
     */
    @Test
    public void testReadWritten() throws Exception {
        byte[] inputChannelInfoData = randomBytes(1024);
        byte[] resultSubpartitionInfoData = randomBytes(1024);
        byte[] resultSubpartitionInfoFutureData = randomBytes(1024);
        int partitionIndex = 0;

        SequentialChannelStateReaderImpl reader =
                new SequentialChannelStateReaderImpl(
                        toTaskStateSnapshot(
                                write(
                                        1L,
                                        Collections.singletonMap(
                                                new InputChannelInfo(0, 0), inputChannelInfoData),
                                        Collections.singletonMap(
                                                new ResultSubpartitionInfo(partitionIndex, 0),
                                                resultSubpartitionInfoData),
                                        Collections.singletonMap(
                                                new ResultSubpartitionInfo(partitionIndex, 1),
                                                resultSubpartitionInfoFutureData))));

        NetworkBufferPool networkBufferPool = new NetworkBufferPool(6, 1024);
        try {
            int numChannels = 1;
            SingleInputGate gate = buildGate(networkBufferPool, numChannels);
            reader.readInputData(new InputGate[] {gate});
            Assert.assertArrayEquals(
                    inputChannelInfoData,
                    collectBytes(() -> gate.pollNext(), BufferOrEvent::getBuffer));

            int subpartitions = 2;
            BufferWritingResultPartition resultPartition =
                    buildResultPartition(
                            networkBufferPool,
                            ResultPartitionType.PIPELINED,
                            partitionIndex,
                            subpartitions);
            reader.readOutputData(new BufferWritingResultPartition[] {resultPartition}, false);

            ResultSubpartitionView view =
                    resultPartition.createSubpartitionView(0, new NoOpBufferAvailablityListener());
            Assert.assertArrayEquals(
                    resultSubpartitionInfoData,
                    collectBytes(
                            () -> Optional.ofNullable(view.getNextBuffer()),
                            ResultSubpartition.BufferAndBacklog::buffer));

            ResultSubpartitionView futureView =
                    resultPartition.createSubpartitionView(1, new NoOpBufferAvailablityListener());
            Assert.assertArrayEquals(
                    resultSubpartitionInfoFutureData,
                    collectBytes(
                            () -> Optional.ofNullable(futureView.getNextBuffer()),
                            ResultSubpartition.BufferAndBacklog::buffer));
        } finally {
            networkBufferPool.destroy();
        }
    }

    /**
     * Recovers a single-subpartition partition of the given type from an empty snapshot, emits
     * one record, and checks what the subpartition view returns.
     *
     * <p>For every type except {@code PIPELINED_APPROXIMATE}, the view must first return a
     * {@code RECOVERY_COMPLETION} buffer. It must then return nothing until {@code
     * resumeConsumption()} is called. Only after that does the emitted record become available.
     *
     * @param type the result partition type under test
     */
    private void upstreamBlocksAfterRecoveringState(ResultPartitionType type) throws Exception {
        NetworkBufferPool networkBufferPool = new NetworkBufferPool(4, 1024);
        byte[] dataAfterRecovery = randomBytes(1024);
        try {
            BufferWritingResultPartition resultPartition =
                    buildResultPartition(networkBufferPool, type, 0, 1);
            new SequentialChannelStateReaderImpl(new TaskStateSnapshot())
                    .readOutputData(new BufferWritingResultPartition[] {resultPartition}, true);
            resultPartition.emitRecord(ByteBuffer.wrap(dataAfterRecovery), 0);

            ResultSubpartitionView view =
                    resultPartition.createSubpartitionView(0, new NoOpBufferAvailablityListener());
            if (type != ResultPartitionType.PIPELINED_APPROXIMATE) {
                ResultSubpartition.BufferAndBacklog completion = view.getNextBuffer();
                Assert.assertNotNull(completion);
                Assert.assertEquals(
                        Buffer.DataType.RECOVERY_COMPLETION, completion.buffer().getDataType());
                completion.buffer().recycleBuffer();
                Assert.assertNull(view.getNextBuffer());
                view.resumeConsumption();
            }

            ResultSubpartition.BufferAndBacklog data = view.getNextBuffer();
            Assert.assertNotNull(data);
            Assert.assertArrayEquals(dataAfterRecovery, collectBytes(data.buffer()));
        } finally {
            networkBufferPool.destroy();
        }
    }

    /**
     * Builds a result partition whose buffer pool is carved from {@code networkBufferPool}, and
     * calls {@code setup()} on it.
     *
     * @param networkBufferPool source of memory segments for the partition's buffer pool
     * @param resultPartitionType partition type to build
     * @param index result partition index
     * @param numberOfSubpartitions number of subpartitions
     * @return the partition, cast to {@link BufferWritingResultPartition}
     */
    private BufferWritingResultPartition buildResultPartition(
            NetworkBufferPool networkBufferPool,
            ResultPartitionType resultPartitionType,
            int index,
            int numberOfSubpartitions)
            throws IOException {
        SupplierWithException<BufferPool, IOException> bufferPoolFactory =
                () ->
                        networkBufferPool.createBufferPool(
                                numberOfSubpartitions,
                                Integer.MAX_VALUE,
                                numberOfSubpartitions,
                                Integer.MAX_VALUE,
                                0);
        ResultPartition resultPartition =
                new ResultPartitionBuilder()
                        .setResultPartitionIndex(index)
                        .setResultPartitionType(resultPartitionType)
                        .setNumberOfSubpartitions(numberOfSubpartitions)
                        .setBufferPoolFactory(bufferPoolFactory)
                        .build();
        resultPartition.setup();
        return (BufferWritingResultPartition) resultPartition;
    }

    /**
     * Builds a {@link SingleInputGate} with remote recovered channels backed by {@code
     * networkBufferPool}, and calls {@code setup()} on it.
     */
    private SingleInputGate buildGate(NetworkBufferPool networkBufferPool, int numberOfChannels)
            throws IOException {
        SingleInputGate gate =
                new SingleInputGateBuilder()
                        .setChannelFactory(InputChannelBuilder::buildRemoteRecoveredChannel)
                        .setBufferPoolFactory(
                                networkBufferPool.createBufferPool(
                                        numberOfChannels, Integer.MAX_VALUE))
                        .setSegmentProvider(networkBufferPool)
                        .setNumberOfChannels(numberOfChannels)
                        .build();
        gate.setup();
        return gate;
    }

    /**
     * Drains {@code entrySupplier} until it yields an empty {@link Optional}. Returns the
     * concatenated readable bytes of every extracted buffer whose data type reports {@code
     * isBuffer()}.
     *
     * <p>Entries whose extractor returns {@code null} are skipped. Extracted buffers that are not
     * data buffers are recycled immediately. Data buffers are recycled after their bytes have
     * been copied.
     *
     * @param entrySupplier polled repeatedly until it returns an empty Optional
     * @param bufferExtractor maps an entry to its buffer; may return {@code null}
     * @return the concatenated payload bytes
     */
    private <T> byte[] collectBytes(
            SupplierWithException<Optional<T>, Exception> entrySupplier,
            Function<T, Buffer> bufferExtractor)
            throws Exception {
        List<Buffer> dataBuffers = new ArrayList<>();
        for (Optional<T> entry = entrySupplier.get();
                entry.isPresent();
                entry = entrySupplier.get()) {
            Buffer buffer = bufferExtractor.apply(entry.get());
            if (buffer == null) {
                continue;
            }
            if (buffer.getDataType().isBuffer()) {
                dataBuffers.add(buffer);
            } else {
                // Not part of the payload, but the consumer still owns it: release it.
                buffer.recycleBuffer();
            }
        }

        ByteBuffer result =
                ByteBuffer.allocate(dataBuffers.stream().mapToInt(Buffer::getSize).sum());
        for (Buffer buffer : dataBuffers) {
            result.put(buffer.getNioBufferReadable());
            buffer.recycleBuffer();
        }
        return result.array();
    }

    /**
     * Copies the readable bytes of {@code buffer} into a new array, then recycles the buffer.
     *
     * @param buffer buffer to drain; recycled even if copying fails
     * @return the buffer's readable bytes
     */
    private byte[] collectBytes(Buffer buffer) {
        try {
            ByteBuffer readable = buffer.getNioBufferReadable();
            // Size by the readable window, not the backing capacity.
            byte[] bytes = new byte[readable.remaining()];
            readable.get(bytes);
            return bytes;
        } finally {
            buffer.recycleBuffer();
        }
    }

    /** Returns {@code size} random bytes drawn from {@link #RANDOM}. */
    private byte[] randomBytes(int size) {
        byte[] bytes = new byte[size];
        RANDOM.nextBytes(bytes);
        return bytes;
    }

    /**
     * Writes the given channel data as checkpoint {@code checkpointId} into a {@link
     * JobManagerCheckpointStorage} and returns the write result.
     *
     * <p>Input data is added first and input is finished. Output data is then added in two
     * ways: through already-registered futures that are completed afterwards ({@code
     * rsFutureMap}), and directly ({@code rsMap}). Output is then finished.
     *
     * @param checkpointId checkpoint to write
     * @param icMap payload per input channel
     * @param rsMap payload per result subpartition, added directly
     * @param rsFutureMap payload per result subpartition, added through a future
     * @return the write result, after its result-subpartition handles have completed
     */
    private ChannelStateWriter.ChannelStateWriteResult write(
            long checkpointId,
            Map<InputChannelInfo, byte[]> icMap,
            Map<ResultSubpartitionInfo, byte[]> rsMap,
            Map<ResultSubpartitionInfo, byte[]> rsFutureMap)
            throws Exception {
        // NOTE(review): the extra 24 bytes presumably cover per-state header overhead
        // (3 * Long.BYTES?) - confirm.
        int maxStateSize =
                sizeOfBytes(icMap) + sizeOfBytes(rsMap) + sizeOfBytes(rsFutureMap) + 24;
        Map<InputChannelInfo, Buffer> icBuffers = wrapWithBuffers(icMap);
        Map<ResultSubpartitionInfo, Buffer> rsBuffers = wrapWithBuffers(rsMap);
        Map<ResultSubpartitionInfo, Buffer> rsFutureBuffers = wrapWithBuffers(rsFutureMap);

        try (ChannelStateWriterImpl writer =
                new ChannelStateWriterImpl(
                        JOB_VERTEX_ID,
                        "test",
                        SUBTASK_INDEX,
                        (CheckpointStorage) new JobManagerCheckpointStorage(maxStateSize),
                        new ChannelStateWriteRequestExecutorFactory(JOB_ID),
                        5)) {
            writer.start(
                    checkpointId,
                    new CheckpointOptions(
                            CheckpointType.CHECKPOINT,
                            new CheckpointStorageLocationReference(
                                    "poly".getBytes(StandardCharsets.UTF_8))));

            for (Map.Entry<InputChannelInfo, Buffer> entry : icBuffers.entrySet()) {
                writer.addInputData(
                        checkpointId,
                        entry.getKey(),
                        SEQUENCE_NUMBER,
                        CloseableIterator.ofElements(Buffer::recycleBuffer, entry.getValue()));
            }
            writer.finishInput(checkpointId);

            for (Map.Entry<ResultSubpartitionInfo, Buffer> entry : rsFutureBuffers.entrySet()) {
                CompletableFuture<List<Buffer>> dataFuture = new CompletableFuture<>();
                // Register the future first, then complete it.
                writer.addOutputDataFuture(
                        checkpointId, entry.getKey(), SEQUENCE_NUMBER, dataFuture);
                dataFuture.complete(Collections.singletonList(entry.getValue()));
            }
            for (Map.Entry<ResultSubpartitionInfo, Buffer> entry : rsBuffers.entrySet()) {
                writer.addOutputData(
                        checkpointId,
                        entry.getKey(),
                        SEQUENCE_NUMBER,
                        new Buffer[] {entry.getValue()});
            }
            writer.finishOutput(checkpointId);

            ChannelStateWriter.ChannelStateWriteResult result =
                    writer.getAndRemoveWriteResult(checkpointId);
            // Block until the subpartition handles are complete; this runs before
            // try-with-resources closes the writer.
            result.getResultSubpartitionStateHandles().join();
            return result;
        }
    }

    /**
     * Wraps a write result into a {@link TaskStateSnapshot} holding a single operator (with a
     * fresh {@link OperatorID}) whose subtask state contains the written input-channel and
     * result-subpartition handles. Blocks on both handle futures.
     */
    private TaskStateSnapshot toTaskStateSnapshot(ChannelStateWriter.ChannelStateWriteResult t)
            throws Exception {
        return new TaskStateSnapshot(
                Collections.singletonMap(
                        new OperatorID(),
                        OperatorSubtaskState.builder()
                                .setInputChannelState(
                                        new StateObjectCollection<>(
                                                t.getInputChannelStateHandles().get()))
                                .setResultSubpartitionState(
                                        new StateObjectCollection<>(
                                                t.getResultSubpartitionStateHandles().get()))
                                .build()));
    }

    /** Returns the total length of all byte arrays stored as values in {@code map}. */
    private static int sizeOfBytes(Map<?, byte[]> map) {
        return map.values().stream().mapToInt(bytes -> bytes.length).sum();
    }

    /** Returns a new map with the same keys, each value wrapped by {@link #wrapWithBuffer}. */
    private <K> Map<K, Buffer> wrapWithBuffers(Map<K, byte[]> data) {
        return data.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, e -> wrapWithBuffer(e.getValue())));
    }

    /**
     * Copies {@code data} into a {@link NetworkBuffer} backed by an unpooled segment of exactly
     * {@code data.length} bytes, recycled through {@link FreeingBufferRecycler}.
     */
    private static Buffer wrapWithBuffer(byte[] data) {
        NetworkBuffer buffer =
                new NetworkBuffer(
                        MemorySegmentFactory.allocateUnpooledSegment(data.length, null),
                        FreeingBufferRecycler.INSTANCE);
        buffer.writeBytes(data);
        return buffer;
    }
}

