/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.streaming.runtime.tasks;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.runtime.state.TaskStateManager;
import org.apache.flink.runtime.state.TestTaskStateManager;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
import org.apache.flink.streaming.runtime.tasks.OneInputStreamTaskTestHarness;
import org.apache.flink.streaming.runtime.tasks.StreamMockEnvironment;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.streaming.util.TestHarnessUtil;
import org.apache.flink.util.TestLogger;
import org.apache.flink.util.function.FunctionWithException;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Tests state restore for a chain of two operators running inside a {@link OneInputStreamTask}.
 *
 * <p>Every test first runs a head/tail operator chain from scratch, feeds it three records and
 * takes a checkpoint. The reported checkpoint is then handed to a second run of a (possibly
 * differently identified) chain, and the test asserts which operators observed
 * {@link StateInitializationContext#isRestored()} during that second run.
 */
public class RestoreStreamTaskTest extends TestLogger {

    /**
     * IDs of all operators that reported a restored state context during initialization.
     *
     * <p>Static because the operator instances record into it from inside the task; it is
     * cleared before every test in {@link #setup()}.
     */
    private static final Set<OperatorID> RESTORED_OPERATORS = ConcurrentHashMap.newKeySet();

    /** Resets the shared record of restored operators so tests do not see each other's results. */
    @Before
    public void setup() {
        RESTORED_OPERATORS.clear();
    }

    /** Both operators keep their IDs across runs, so both must report being restored. */
    @Test
    public void testRestore() throws Exception {
        OperatorID headOperatorID = new OperatorID(42L, 42L);
        OperatorID tailOperatorID = new OperatorID(44L, 44L);

        JobManagerTaskRestore restore =
                createRunAndCheckpointOperatorChain(
                        headOperatorID,
                        new CounterOperator(),
                        tailOperatorID,
                        new CounterOperator(),
                        Optional.empty());

        TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot();
        Assert.assertEquals(2, stateHandles.getSubtaskStateMappings().size());

        createRunAndCheckpointOperatorChain(
                headOperatorID,
                new CounterOperator(),
                tailOperatorID,
                new CounterOperator(),
                Optional.of(restore));

        Assert.assertEquals(
                new HashSet<>(Arrays.asList(headOperatorID, tailOperatorID)), RESTORED_OPERATORS);
    }

    /** The head operator gets a new ID for the second run, so only the tail must be restored. */
    @Test
    public void testRestoreHeadWithNewId() throws Exception {
        OperatorID tailOperatorID = new OperatorID(44L, 44L);

        JobManagerTaskRestore restore =
                createRunAndCheckpointOperatorChain(
                        new OperatorID(42L, 42L),
                        new CounterOperator(),
                        tailOperatorID,
                        new CounterOperator(),
                        Optional.empty());

        TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot();
        Assert.assertEquals(2, stateHandles.getSubtaskStateMappings().size());

        createRunAndCheckpointOperatorChain(
                new OperatorID(4242L, 4242L),
                new CounterOperator(),
                tailOperatorID,
                new CounterOperator(),
                Optional.of(restore));

        Assert.assertEquals(Collections.singleton(tailOperatorID), RESTORED_OPERATORS);
    }

    /** The tail operator gets a new ID for the second run, so only the head must be restored. */
    @Test
    public void testRestoreTailWithNewId() throws Exception {
        OperatorID headOperatorID = new OperatorID(42L, 42L);

        JobManagerTaskRestore restore =
                createRunAndCheckpointOperatorChain(
                        headOperatorID,
                        new CounterOperator(),
                        new OperatorID(44L, 44L),
                        new CounterOperator(),
                        Optional.empty());

        TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot();
        Assert.assertEquals(2, stateHandles.getSubtaskStateMappings().size());

        createRunAndCheckpointOperatorChain(
                headOperatorID,
                new CounterOperator(),
                new OperatorID(4444L, 4444L),
                new CounterOperator(),
                Optional.of(restore));

        Assert.assertEquals(Collections.singleton(headOperatorID), RESTORED_OPERATORS);
    }

    /**
     * Replaces the head operator's checkpointed state with an empty {@link OperatorSubtaskState}
     * before restoring. Both operators are still expected to report being restored.
     */
    @Test
    public void testRestoreAfterScaleUp() throws Exception {
        OperatorID headOperatorID = new OperatorID(42L, 42L);
        OperatorID tailOperatorID = new OperatorID(44L, 44L);

        JobManagerTaskRestore restore =
                createRunAndCheckpointOperatorChain(
                        headOperatorID,
                        new CounterOperator(),
                        tailOperatorID,
                        new CounterOperator(),
                        Optional.empty());

        TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot();
        Assert.assertEquals(2, stateHandles.getSubtaskStateMappings().size());

        // Overwrite the head operator's entry in the snapshot that `restore` still refers to.
        OperatorSubtaskState emptyHeadOperatorState = OperatorSubtaskState.builder().build();
        stateHandles.putSubtaskStateByOperatorID(headOperatorID, emptyHeadOperatorState);

        createRunAndCheckpointOperatorChain(
                headOperatorID,
                new CounterOperator(),
                tailOperatorID,
                new CounterOperator(),
                Optional.of(restore));

        Assert.assertEquals(
                new HashSet<>(Arrays.asList(headOperatorID, tailOperatorID)), RESTORED_OPERATORS);
    }

    /**
     * Uses a head operator that writes nothing in {@code snapshotState}. The checkpoint is still
     * expected to contain two subtask state mappings and both operators must report being
     * restored.
     */
    @Test
    public void testRestoreWithoutState() throws Exception {
        OperatorID headOperatorID = new OperatorID(42L, 42L);
        OperatorID tailOperatorID = new OperatorID(44L, 44L);

        JobManagerTaskRestore restore =
                createRunAndCheckpointOperatorChain(
                        headOperatorID,
                        new StatelessOperator(),
                        tailOperatorID,
                        new CounterOperator(),
                        Optional.empty());

        TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot();
        Assert.assertEquals(2, stateHandles.getSubtaskStateMappings().size());

        createRunAndCheckpointOperatorChain(
                headOperatorID,
                new StatelessOperator(),
                tailOperatorID,
                new CounterOperator(),
                Optional.of(restore));

        Assert.assertEquals(
                new HashSet<>(Arrays.asList(headOperatorID, tailOperatorID)), RESTORED_OPERATORS);
    }

    /**
     * Builds a two-operator chain, runs it in a test harness, processes three records, triggers a
     * checkpoint and returns what the task reported for that checkpoint.
     *
     * @param headId operator ID assigned to the head of the chain
     * @param headOperator operator instance used as head of the chain
     * @param tailId operator ID assigned to the chained tail operator
     * @param tailOperator operator instance chained behind the head
     * @param restore if present, its checkpoint ID and state snapshot are installed on the harness
     *     before the task is invoked
     * @return the reported checkpoint ID together with the last task state snapshot reported to
     *     the harness' task state manager
     * @throws Exception if the harness, the task or any of the assertions fail
     */
    private JobManagerTaskRestore createRunAndCheckpointOperatorChain(
            OperatorID headId,
            OneInputStreamOperator<String, String> headOperator,
            OperatorID tailId,
            OneInputStreamOperator<String, String> tailOperator,
            Optional<JobManagerTaskRestore> restore)
            throws Exception {

        // Let the compiler infer the factory type instead of forcing the constructor reference
        // through a raw functional interface.
        OneInputStreamTaskTestHarness<String, String> testHarness =
                new OneInputStreamTaskTestHarness<>(
                        OneInputStreamTask::new,
                        1,
                        1,
                        BasicTypeInfo.STRING_TYPE_INFO,
                        BasicTypeInfo.STRING_TYPE_INFO);

        testHarness
                .setupOperatorChain(headId, headOperator)
                .chain(tailId, tailOperator, StringSerializer.INSTANCE)
                .finish();

        if (restore.isPresent()) {
            JobManagerTaskRestore taskRestore = restore.get();
            testHarness.setTaskStateSnapshot(
                    taskRestore.getRestoreCheckpointId(), taskRestore.getTaskStateSnapshot());
        }

        StreamMockEnvironment environment =
                new StreamMockEnvironment(
                        testHarness.jobConfig,
                        testHarness.taskConfig,
                        testHarness.executionConfig,
                        testHarness.memorySize,
                        new MockInputSplitProvider(),
                        testHarness.bufferSize,
                        testHarness.taskStateManager);

        testHarness.invoke(environment);
        testHarness.waitForTaskRunning();

        OneInputStreamTask<String, String> streamTask = testHarness.getTask();

        processRecords(testHarness);
        triggerCheckpoint(testHarness, streamTask);

        // Capture the reported checkpoint before the task is shut down.
        TestTaskStateManager taskStateManager = testHarness.taskStateManager;
        JobManagerTaskRestore jobManagerTaskRestore =
                new JobManagerTaskRestore(
                        taskStateManager.getReportedCheckpointId(),
                        taskStateManager.getLastJobManagerTaskStateSnapshot());

        testHarness.endInput();
        testHarness.waitForTaskCompletion();
        return jobManagerTaskRestore;
    }

    /**
     * Triggers checkpoint {@code 1} on the task, blocks until the harness' task state manager has
     * received the report, and asserts that the reported checkpoint ID matches.
     *
     * @param testHarness harness whose task state manager receives the checkpoint report
     * @param streamTask task on which the checkpoint is triggered
     * @throws Exception if triggering or waiting for the checkpoint fails
     */
    private void triggerCheckpoint(
            OneInputStreamTaskTestHarness<String, String> testHarness,
            OneInputStreamTask<String, String> streamTask)
            throws Exception {

        long checkpointId = 1L;
        CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 1L);

        // The latch must be installed before triggering so the report cannot be missed.
        testHarness.taskStateManager.setWaitForReportLatch(new OneShotLatch());

        streamTask.triggerCheckpointAsync(
                checkpointMetaData, CheckpointOptions.forCheckpointWithDefaultLocation());

        testHarness.taskStateManager.getWaitForReportLatch().await();
        long reportedCheckpointId = testHarness.taskStateManager.getReportedCheckpointId();

        Assert.assertEquals(checkpointId, reportedCheckpointId);
    }

    /**
     * Sends the records {@code "10"}, {@code "20"} and {@code "30"} through the chain and asserts
     * that exactly these records come out again.
     *
     * @param testHarness harness driving the task under test
     * @throws Exception if processing fails or the output does not match
     */
    private void processRecords(OneInputStreamTaskTestHarness<String, String> testHarness)
            throws Exception {
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.processElement(new StreamRecord<>("10"), 0, 0);
        testHarness.processElement(new StreamRecord<>("20"), 0, 0);
        testHarness.processElement(new StreamRecord<>("30"), 0, 0);

        testHarness.waitForInputProcessing();

        expectedOutput.add(new StreamRecord<>("10"));
        expectedOutput.add(new StreamRecord<>("20"));
        expectedOutput.add(new StreamRecord<>("30"));
        TestHarnessUtil.assertOutputEquals(
                "Output was not correct.", expectedOutput, testHarness.getOutput());
    }

    /** Operator that forwards every record unchanged and writes nothing when snapshotting. */
    private static class StatelessOperator extends RestoreWatchOperator<String, String> {

        private static final long serialVersionUID = 2048954179291813244L;

        @Override
        public void processElement(StreamRecord<String> element) throws Exception {
            output.collect(element);
        }

        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            // Intentionally empty: this operator has no state of its own.
        }
    }

    /** Operator that forwards every record and keeps a count of processed records in state. */
    private static class CounterOperator extends RestoreWatchOperator<String, String> {

        private static final long serialVersionUID = 2048954179291813243L;

        /** Operator list state named {@code "counter-state"} holding the persisted counts. */
        private ListState<Long> counterState;

        /** Number of records processed, including any count read back on restore. */
        private long counter = 0L;

        @Override
        public void processElement(StreamRecord<String> element) throws Exception {
            counter++;
            output.collect(element);
        }

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            // Lets the base class record this operator as restored, if applicable.
            super.initializeState(context);

            counterState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "counter-state", LongSerializer.INSTANCE));

            if (context.isRestored()) {
                // Fold all restored entries into the in-memory counter, then drop them so the
                // next snapshot does not contain them a second time.
                for (Long value : counterState.get()) {
                    counter += value;
                }
                counterState.clear();
            }
        }

        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            counterState.add(counter);
        }
    }

    /**
     * Base operator that adds its own operator ID to {@link #RESTORED_OPERATORS} when it is
     * initialized with a restored state context.
     *
     * @param <IN> type of the input records
     * @param <OUT> type of the output records
     */
    private abstract static class RestoreWatchOperator<IN, OUT> extends AbstractStreamOperator<OUT>
            implements OneInputStreamOperator<IN, OUT> {

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            if (context.isRestored()) {
                RESTORED_OPERATORS.add(getOperatorID());
            }
        }
    }
}

