/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint.channel;

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.IntStream;
import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.core.fs.Path;
import org.apache.flink.core.fs.local.LocalFileSystem;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateCheckpointWriter;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateSerializer;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateSerializerImpl;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.checkpoint.channel.CloseExceptionOutputStream;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.SubtaskID;
import org.apache.flink.runtime.checkpoint.channel.TestException;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.CheckpointStateOutputStream;
import org.apache.flink.runtime.state.CheckpointedStateScope;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.filesystem.FsCheckpointStreamFactory;
import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
import org.apache.flink.testutils.junit.utils.TempDirUtils;
import org.apache.flink.util.function.RunnableWithException;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

/**
 * Tests for {@link ChannelStateCheckpointWriter}.
 *
 * <p>Covers handle sizes and recorded offsets, that state below the file-size threshold is not
 * written to files, buffer recycling, flushing of the data stream, and how per-subtask results
 * complete, fail, or are released when several subtasks share one {@link
 * CheckpointStateOutputStream}.
 */
class ChannelStateCheckpointWriterTest {
    private static final RunnableWithException NO_OP_RUNNABLE = () -> {};
    private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
    private static final int SUBTASK_INDEX = 0;
    private static final SubtaskID SUBTASK_ID = SubtaskID.of(JOB_VERTEX_ID, SUBTASK_INDEX);

    private final Random random = new Random();

    @TempDir private java.nio.file.Path temporaryFolder;

    /**
     * Writes interleaved buffers to several channels through a file-backed stream and checks that
     * each input channel handle accounts for {@code Integer.BYTES} of overhead (presumably the
     * length prefix) plus the payload for every write.
     */
    @Test
    void testFileHandleSize() throws Exception {
        int numChannels = 3;
        int numWritesPerChannel = 4;
        int numBytesPerWrite = 5;
        ChannelStateWriter.ChannelStateWriteResult result =
                new ChannelStateWriter.ChannelStateWriteResult();
        // Thresholds just below a single write's size -- presumably so the state is file-backed.
        FsCheckpointStreamFactory streamFactory =
                new FsCheckpointStreamFactory(
                        LocalFileSystem.getSharedInstance(),
                        Path.fromLocalFile(TempDirUtils.newFolder(temporaryFolder, "checkpointsDir")),
                        Path.fromLocalFile(TempDirUtils.newFolder(temporaryFolder, "sharedStateDir")),
                        numBytesPerWrite - 1,
                        numBytesPerWrite - 1);
        ChannelStateCheckpointWriter writer =
                createWriter(
                        result,
                        streamFactory.createCheckpointStateOutputStream(
                                CheckpointedStateScope.EXCLUSIVE));

        InputChannelInfo[] channels =
                IntStream.range(0, numChannels)
                        .mapToObj(i -> new InputChannelInfo(0, i))
                        .toArray(InputChannelInfo[]::new);
        // Interleave writes across channels so consecutive writes target different channels.
        for (int call = 0; call < numWritesPerChannel; call++) {
            for (InputChannelInfo channel : channels) {
                write(writer, channel, getData(numBytesPerWrite));
            }
        }
        writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
        writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);

        for (InputChannelStateHandle handle : result.inputChannelStateHandles.get()) {
            Assertions.assertThat(handle.getStateSize())
                    .isEqualTo((long) (Integer.BYTES + numBytesPerWrite) * numWritesPerChannel);
        }
    }

    /**
     * Writes a buffer smaller than the file-state threshold and checks that the result completes
     * while neither the checkpoint nor the shared-state directory receives any file.
     */
    @Test
    void testSmallFilesNotWritten() throws Exception {
        int threshold = 100;
        File checkpointsDir = TempDirUtils.newFolder(temporaryFolder, "checkpointsDir");
        File sharedStateDir = TempDirUtils.newFolder(temporaryFolder, "sharedStateDir");
        FsCheckpointStreamFactory checkpointStreamFactory =
                new FsCheckpointStreamFactory(
                        LocalFileSystem.getSharedInstance(),
                        Path.fromLocalFile(checkpointsDir),
                        Path.fromLocalFile(sharedStateDir),
                        threshold,
                        threshold);
        ChannelStateWriter.ChannelStateWriteResult result =
                new ChannelStateWriter.ChannelStateWriteResult();
        ChannelStateCheckpointWriter writer =
                createWriter(
                        result,
                        checkpointStreamFactory.createCheckpointStateOutputStream(
                                CheckpointedStateScope.EXCLUSIVE));
        NetworkBuffer buffer =
                new NetworkBuffer(
                        MemorySegmentFactory.allocateUnpooledSegment(threshold / 2),
                        FreeingBufferRecycler.INSTANCE);
        writer.writeInput(JOB_VERTEX_ID, SUBTASK_INDEX, new InputChannelInfo(1, 2), buffer);
        writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
        writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);

        Assertions.assertThat(result.isDone()).isTrue();
        Assertions.assertThat(checkpointsDir).isEmptyDirectory();
        Assertions.assertThat(sharedStateDir).isEmptyDirectory();
    }

    /**
     * Completes a subtask without writing anything and checks that the stream is closed without
     * {@code closeAndGetHandle()} ever being invoked.
     */
    @Test
    void testEmptyState() throws Exception {
        MemCheckpointStreamFactory.MemoryCheckpointOutputStream stream =
                new MemCheckpointStreamFactory.MemoryCheckpointOutputStream(1000) {
                    @Override
                    public StreamStateHandle closeAndGetHandle() {
                        Assertions.fail("closeAndGetHandle shouldn't be called for empty channel state");
                        return null;
                    }
                };
        ChannelStateCheckpointWriter writer =
                createWriter(new ChannelStateWriter.ChannelStateWriteResult(), stream);
        writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
        writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
        Assertions.assertThat(stream.isClosed()).isTrue();
    }

    /** Checks that a buffer handed to {@code writeInput} has been recycled once the call returns. */
    @Test
    void testRecyclingBuffers() {
        ChannelStateCheckpointWriter writer =
                createWriter(new ChannelStateWriter.ChannelStateWriteResult());
        NetworkBuffer buffer =
                new NetworkBuffer(
                        MemorySegmentFactory.allocateUnpooledSegment(10),
                        FreeingBufferRecycler.INSTANCE);
        writer.writeInput(JOB_VERTEX_ID, SUBTASK_INDEX, new InputChannelInfo(1, 2), buffer);
        Assertions.assertThat(buffer.isRecycled()).isTrue();
    }

    /** Checks that completing the only subtask flushes the writer's data output stream. */
    @Test
    void testFlush() throws Exception {
        /** A {@link DataOutputStream} that records whether {@link #flush()} was called. */
        class FlushRecorder extends DataOutputStream {
            private boolean flushed;

            FlushRecorder() {
                super(new ByteArrayOutputStream());
                this.flushed = false;
            }

            @Override
            public void flush() throws IOException {
                this.flushed = true;
                super.flush();
            }
        }

        FlushRecorder dataStream = new FlushRecorder();
        ChannelStateCheckpointWriter writer =
                new ChannelStateCheckpointWriter(
                        Collections.singleton(SUBTASK_ID),
                        1L,
                        new ChannelStateSerializerImpl(),
                        NO_OP_RUNNABLE,
                        new MemCheckpointStreamFactory.MemoryCheckpointOutputStream(42),
                        dataStream);
        writer.registerSubtaskResult(SUBTASK_ID, new ChannelStateWriter.ChannelStateWriteResult());
        writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
        writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
        Assertions.assertThat(dataStream.flushed).isTrue();
    }

    /** Runs the multi-subtask completion scenario for 1 to 9 subtasks sharing one stream. */
    @Test
    void testResultCompletion() throws Exception {
        for (int maxSubtasksPerChannelStateFile = 1;
                maxSubtasksPerChannelStateFile < 10;
                maxSubtasksPerChannelStateFile++) {
            testMultiTaskCompletionAndAssertResult(maxSubtasksPerChannelStateFile);
        }
    }

    /**
     * Registers the given number of subtasks on one writer and completes them one by one.
     *
     * <p>Asserts that no result is done and the stream stays open until the last subtask
     * completes its output. After that, the stream is closed and every result completed normally.
     */
    private void testMultiTaskCompletionAndAssertResult(int maxSubtasksPerChannelStateFile)
            throws Exception {
        Map<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> subtasks =
                createSubtaskResults(maxSubtasksPerChannelStateFile);
        MemCheckpointStreamFactory.MemoryCheckpointOutputStream stream =
                new MemCheckpointStreamFactory.MemoryCheckpointOutputStream(1000);
        ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());
        for (Map.Entry<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> entry :
                subtasks.entrySet()) {
            writer.registerSubtaskResult(entry.getKey(), entry.getValue());
        }

        for (SubtaskID subtaskID : subtasks.keySet()) {
            ChannelStateWriteResultUtil.assertAllSubtaskNotDone(subtasks.values());
            Assertions.assertThat(stream.isClosed()).isFalse();
            writer.completeInput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
            ChannelStateWriteResultUtil.assertAllSubtaskNotDone(subtasks.values());
            Assertions.assertThat(stream.isClosed()).isFalse();
            writer.completeOutput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
        }

        Assertions.assertThat(stream.isClosed()).isTrue();
        ChannelStateWriteResultUtil.assertAllSubtaskDoneNormally(subtasks.values());
    }

    /** Runs the unregister-one-subtask scenario for several subtask counts. */
    @Test
    void testTaskUnregister() throws Exception {
        testTaskUnregisterAndAssertResult(2);
        testTaskUnregisterAndAssertResult(3);
        testTaskUnregisterAndAssertResult(5);
        testTaskUnregisterAndAssertResult(10);
    }

    /**
     * Leaves one expected subtask unregistered and completes all the others.
     *
     * <p>Asserts that nothing is done and the stream stays open until the missing subtask is
     * released via {@code releaseSubtask}. After that, the stream is closed and the remaining
     * results completed normally.
     */
    private void testTaskUnregisterAndAssertResult(int maxSubtasksPerChannelStateFile)
            throws Exception {
        Map<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> subtasks =
                createSubtaskResults(maxSubtasksPerChannelStateFile);
        MemCheckpointStreamFactory.MemoryCheckpointOutputStream stream =
                new MemCheckpointStreamFactory.MemoryCheckpointOutputStream(1000);
        // The writer expects every subtask, including the one that will never register.
        ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());

        // Pick the first subtask (in map iteration order) as the one that never registers.
        // Removing it from the map leaves the iteration order of the others unchanged.
        SubtaskID unregisterSubtask = subtasks.keySet().iterator().next();
        subtasks.remove(unregisterSubtask);
        for (Map.Entry<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> entry :
                subtasks.entrySet()) {
            writer.registerSubtaskResult(entry.getKey(), entry.getValue());
        }

        for (SubtaskID subtaskID : subtasks.keySet()) {
            writer.completeInput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
            writer.completeOutput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
        }
        ChannelStateWriteResultUtil.assertAllSubtaskNotDone(subtasks.values());
        Assertions.assertThat(stream.isClosed()).isFalse();

        writer.releaseSubtask(unregisterSubtask);
        Assertions.assertThat(stream.isClosed()).isTrue();
        ChannelStateWriteResultUtil.assertAllSubtaskDoneNormally(subtasks.values());
    }

    /** Runs the fail-one-subtask scenario for several subtask counts. */
    @Test
    void testTaskFailThenCompleteOtherTask() {
        testTaskFailAfterAllTaskRegisteredAndAssertResult(2);
        testTaskFailAfterAllTaskRegisteredAndAssertResult(3);
        testTaskFailAfterAllTaskRegisteredAndAssertResult(5);
        testTaskFailAfterAllTaskRegisteredAndAssertResult(10);
    }

    /**
     * Registers all subtasks and then fails one of them.
     *
     * <p>Asserts that the stream is closed. The failed subtask's result carries the
     * {@link TestException} cause. Every other result fails with {@code
     * CHANNEL_STATE_SHARED_STREAM_EXCEPTION}.
     */
    private void testTaskFailAfterAllTaskRegisteredAndAssertResult(
            int maxSubtasksPerChannelStateFile) {
        Map<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> subtasks =
                createSubtaskResults(maxSubtasksPerChannelStateFile);
        MemCheckpointStreamFactory.MemoryCheckpointOutputStream stream =
                new MemCheckpointStreamFactory.MemoryCheckpointOutputStream(1000);
        ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());

        SubtaskID firstSubtask = subtasks.keySet().iterator().next();
        for (Map.Entry<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> entry :
                subtasks.entrySet()) {
            writer.registerSubtaskResult(entry.getKey(), entry.getValue());
        }
        Assertions.assertThat(stream.isClosed()).isFalse();

        writer.fail(
                firstSubtask.getJobVertexID(),
                firstSubtask.getSubtaskIndex(),
                new TestException());
        Assertions.assertThat(stream.isClosed()).isTrue();

        for (Map.Entry<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> entry :
                subtasks.entrySet()) {
            if (firstSubtask.equals(entry.getKey())) {
                ChannelStateWriteResultUtil.assertHasSpecialCause(
                        entry.getValue(), TestException.class);
            } else {
                ChannelStateWriteResultUtil.assertCheckpointFailureReason(
                        entry.getValue(),
                        CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION);
            }
        }
    }

    /**
     * Uses a stream whose {@code closeAndGetHandle()} throws and completes all subtasks.
     *
     * <p>Checks that both handle futures of every result fail with the stream's {@link
     * IOException} as the cause.
     */
    @Test
    void testCloseGetHandleThrowException() throws Exception {
        Map<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> subtasks =
                createSubtaskResults(5);
        CloseExceptionOutputStream stream = new CloseExceptionOutputStream();
        ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());
        for (Map.Entry<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> entry :
                subtasks.entrySet()) {
            SubtaskID subtaskID = entry.getKey();
            writer.registerSubtaskResult(subtaskID, entry.getValue());
            NetworkBuffer buffer =
                    new NetworkBuffer(
                            MemorySegmentFactory.allocateUnpooledSegment(10),
                            FreeingBufferRecycler.INSTANCE);
            writer.writeInput(
                    subtaskID.getJobVertexID(),
                    subtaskID.getSubtaskIndex(),
                    new InputChannelInfo(1, 2),
                    buffer);
        }

        for (SubtaskID subtaskID : subtasks.keySet()) {
            ChannelStateWriteResultUtil.assertAllSubtaskNotDone(subtasks.values());
            Assertions.assertThat(stream.isClosed()).isFalse();
            writer.completeInput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
            ChannelStateWriteResultUtil.assertAllSubtaskNotDone(subtasks.values());
            Assertions.assertThat(stream.isClosed()).isFalse();
            writer.completeOutput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
        }
        Assertions.assertThat(stream.isClosed()).isTrue();

        for (Map.Entry<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> entry :
                subtasks.entrySet()) {
            Assertions.assertThatThrownBy(
                            () -> entry.getValue().getInputChannelStateHandles().get())
                    .cause()
                    .isInstanceOf(IOException.class)
                    .hasMessage("Test closeAndGetHandle exception.");
            Assertions.assertThatThrownBy(
                            () -> entry.getValue().getResultSubpartitionStateHandles().get())
                    .cause()
                    .isInstanceOf(IOException.class)
                    .hasMessage("Test closeAndGetHandle exception.");
        }
    }

    /**
     * Fails the writer before any subtask registers and checks that later registrations are
     * rejected with an {@link IllegalStateException}.
     */
    @Test
    void testRegisterSubtaskAfterWriterDone() {
        Map<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> subtasks = new HashMap<>();
        SubtaskID subtask0 = SubtaskID.of(JOB_VERTEX_ID, 0);
        SubtaskID subtask1 = SubtaskID.of(JOB_VERTEX_ID, 1);
        subtasks.put(subtask0, new ChannelStateWriter.ChannelStateWriteResult());
        subtasks.put(subtask1, new ChannelStateWriter.ChannelStateWriteResult());
        MemCheckpointStreamFactory.MemoryCheckpointOutputStream stream =
                new MemCheckpointStreamFactory.MemoryCheckpointOutputStream(1000);
        ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());

        writer.fail(new JobVertexID(), 0, new TestException());

        Assertions.assertThatThrownBy(
                        () ->
                                writer.registerSubtaskResult(
                                        subtask0, new ChannelStateWriter.ChannelStateWriteResult()))
                .isInstanceOf(IllegalStateException.class)
                .hasMessage("The write is done.");
        Assertions.assertThatThrownBy(
                        () ->
                                writer.registerSubtaskResult(
                                        subtask1, new ChannelStateWriter.ChannelStateWriteResult()))
                .isInstanceOf(IllegalStateException.class)
                .hasMessage("The write is done.");
    }

    /**
     * Writes a varying number of consecutive buffers per channel and checks each handle's
     * recorded offsets and delegate size.
     *
     * <p>The assertions encode the expected layout per handle: a 4-byte header, then a single
     * 4-byte length field, then all bytes written to that channel. The single recorded offset
     * points just past the header. Every written channel must produce exactly one handle.
     */
    @Test
    void testRecordingOffsets() throws Exception {
        Map<InputChannelInfo, Integer> offsetCounts = new HashMap<>();
        offsetCounts.put(new InputChannelInfo(1, 1), 1);
        offsetCounts.put(new InputChannelInfo(1, 2), 2);
        offsetCounts.put(new InputChannelInfo(1, 3), 5);
        int numBytes = 100;
        ChannelStateWriter.ChannelStateWriteResult result =
                new ChannelStateWriter.ChannelStateWriteResult();
        ChannelStateCheckpointWriter writer = createWriter(result);
        for (Map.Entry<InputChannelInfo, Integer> entry : offsetCounts.entrySet()) {
            for (int i = 0; i < entry.getValue(); i++) {
                write(writer, entry.getKey(), getData(numBytes));
            }
        }
        writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
        writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);

        for (InputChannelStateHandle handle : result.inputChannelStateHandles.get()) {
            int headerSize = 4;
            int lengthSize = 4;
            Assertions.assertThat(handle.getOffsets())
                    .isEqualTo(Collections.singletonList(Long.valueOf(headerSize)));
            // remove() doubles as a check that each channel is reported exactly once.
            Assertions.assertThat(handle.getDelegate().getStateSize())
                    .isEqualTo(
                            (long)
                                    (headerSize
                                            + lengthSize
                                            + numBytes * offsetCounts.remove(handle.getInfo())));
        }
        Assertions.assertThat(offsetCounts).isEmpty();
    }

    /**
     * Creates one fresh result per subtask. Each subtask has a new {@link JobVertexID} and an
     * index in {@code [0, count)}.
     */
    private static Map<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> createSubtaskResults(
            int count) {
        Map<SubtaskID, ChannelStateWriter.ChannelStateWriteResult> subtasks = new HashMap<>();
        for (int i = 0; i < count; i++) {
            subtasks.put(
                    SubtaskID.of(new JobVertexID(), i),
                    new ChannelStateWriter.ChannelStateWriteResult());
        }
        return subtasks;
    }

    /** Returns {@code len} random bytes. */
    private byte[] getData(int len) {
        byte[] bytes = new byte[len];
        random.nextBytes(bytes);
        return bytes;
    }

    /** Wraps {@code data} in a full data buffer and writes it as input of the default subtask. */
    private void write(ChannelStateCheckpointWriter writer, InputChannelInfo channelInfo, byte[] data) {
        MemorySegment segment = MemorySegmentFactory.wrap(data);
        NetworkBuffer buffer =
                new NetworkBuffer(
                        segment,
                        FreeingBufferRecycler.INSTANCE,
                        Buffer.DataType.DATA_BUFFER,
                        segment.size());
        writer.writeInput(JOB_VERTEX_ID, SUBTASK_INDEX, channelInfo, buffer);
    }

    /** Creates a writer for the default subtask backed by an in-memory stream. */
    private ChannelStateCheckpointWriter createWriter(
            ChannelStateWriter.ChannelStateWriteResult result) {
        return createWriter(
                result, new MemCheckpointStreamFactory.MemoryCheckpointOutputStream(1000));
    }

    /** Creates a writer for the default subtask on {@code stream} and registers {@code result}. */
    private ChannelStateCheckpointWriter createWriter(
            ChannelStateWriter.ChannelStateWriteResult result, CheckpointStateOutputStream stream) {
        ChannelStateCheckpointWriter writer =
                createWriter(stream, Collections.singleton(SUBTASK_ID));
        writer.registerSubtaskResult(SUBTASK_ID, result);
        return writer;
    }

    /**
     * Creates a writer for checkpoint 1 that expects the given subtasks and writes to {@code
     * stream}.
     */
    private ChannelStateCheckpointWriter createWriter(
            CheckpointStateOutputStream stream, Set<SubtaskID> subtasks) {
        return new ChannelStateCheckpointWriter(
                subtasks, 1L, stream, new ChannelStateSerializerImpl(), NO_OP_RUNNABLE);
    }
}

