/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.state;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.state.BroadcastState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.DefaultOperatorStateBackendBuilder;
import org.apache.flink.runtime.state.OperatorStateBackend;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.PartitionableListState;
import org.apache.flink.runtime.state.SnapshotResult;
import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.testcontainers.utility.ThrowingFunction;

public class OperatorStateRestoreOperationTest {
    private static ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> createOperatorStateBackendFactory(ExecutionConfig cfg, CloseableRegistry cancelStreamRegistry, ClassLoader classLoader) {
        return handles -> new DefaultOperatorStateBackendBuilder(classLoader, cfg, false, handles, cancelStreamRegistry).build();
    }

    private static OperatorStateHandle createOperatorStateHandle(ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> operatorStateBackendFactory, Map<String, List<String>> listStates, Map<String, Map<String, String>> broadcastStates) throws Exception {
        try (OperatorStateBackend operatorStateBackend = (OperatorStateBackend)operatorStateBackendFactory.apply(Collections.emptyList());){
            PartitionableListState state;
            ListStateDescriptor descriptor;
            for (String stateName : listStates.keySet()) {
                descriptor = new ListStateDescriptor(stateName, String.class);
                state = (PartitionableListState)operatorStateBackend.getListState(descriptor);
                state.addAll(listStates.get(stateName));
            }
            for (String stateName : broadcastStates.keySet()) {
                descriptor = new MapStateDescriptor(stateName, String.class, String.class);
                state = operatorStateBackend.getBroadcastState((MapStateDescriptor)descriptor);
                state.putAll(broadcastStates.get(stateName));
            }
            SnapshotResult result = (SnapshotResult)operatorStateBackend.snapshot(1L, 1L, (CheckpointStreamFactory)new MemCheckpointStreamFactory(4096), CheckpointOptions.forCheckpointWithDefaultLocation()).get();
            OperatorStateHandle operatorStateHandle = (OperatorStateHandle)Objects.requireNonNull(result.getJobManagerOwnedSnapshot());
            return operatorStateHandle;
        }
    }

    private static void verifyOperatorStateHandle(ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> operatorStateBackendFactory, Collection<OperatorStateHandle> stateHandles, Map<String, List<String>> listStates, Map<String, Map<String, String>> broadcastStates) throws Exception {
        try (OperatorStateBackend operatorStateBackend = (OperatorStateBackend)operatorStateBackendFactory.apply(stateHandles);){
            PartitionableListState state;
            ListStateDescriptor descriptor;
            for (String stateName : listStates.keySet()) {
                descriptor = new ListStateDescriptor(stateName, String.class);
                state = (PartitionableListState)operatorStateBackend.getListState(descriptor);
                Assertions.assertThat((Iterable)state.get()).containsExactlyElementsOf((Iterable)listStates.get(stateName));
            }
            for (String stateName : listStates.keySet()) {
                descriptor = new ListStateDescriptor(stateName, String.class);
                state = (PartitionableListState)operatorStateBackend.getListState(descriptor);
                Assertions.assertThat((Iterable)state.get()).containsExactlyElementsOf((Iterable)listStates.get(stateName));
            }
            for (String stateName : broadcastStates.keySet()) {
                descriptor = new MapStateDescriptor(stateName, String.class, String.class);
                state = operatorStateBackend.getBroadcastState((MapStateDescriptor)descriptor);
                HashMap content = new HashMap();
                state.iterator().forEachRemaining(e -> {
                    String cfr_ignored_0 = (String)content.put(e.getKey(), e.getValue());
                });
                Assertions.assertThat(content).containsAllEntriesOf(broadcastStates.get(stateName));
            }
        }
    }

    @ParameterizedTest
    @ValueSource(booleans={true, false})
    void testRestoringMixedOperatorState(boolean snapshotCompressionEnabled) throws Exception {
        ExecutionConfig cfg = new ExecutionConfig();
        cfg.setUseSnapshotCompression(snapshotCompressionEnabled);
        ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> operatorStateBackendFactory = OperatorStateRestoreOperationTest.createOperatorStateBackendFactory(cfg, new CloseableRegistry(), this.getClass().getClassLoader());
        HashMap<String, List<String>> listStates = new HashMap<String, List<String>>();
        listStates.put("s1", Arrays.asList("foo1", "foo2", "foo3"));
        listStates.put("s2", Arrays.asList("bar1", "bar2", "bar3"));
        HashMap<String, Map<String, String>> broadcastStates = new HashMap<String, Map<String, String>>();
        broadcastStates.put("a1", Collections.singletonMap("foo", "bar"));
        broadcastStates.put("a2", Collections.singletonMap("bar", "foo"));
        OperatorStateHandle stateHandle = OperatorStateRestoreOperationTest.createOperatorStateHandle(operatorStateBackendFactory, listStates, broadcastStates);
        OperatorStateRestoreOperationTest.verifyOperatorStateHandle(operatorStateBackendFactory, Collections.singletonList(stateHandle), listStates, broadcastStates);
    }

    @ParameterizedTest
    @ValueSource(booleans={true, false})
    void testMergeOperatorState(boolean snapshotCompressionEnabled) throws Exception {
        ExecutionConfig cfg = new ExecutionConfig();
        cfg.setUseSnapshotCompression(snapshotCompressionEnabled);
        ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> operatorStateBackendFactory = OperatorStateRestoreOperationTest.createOperatorStateBackendFactory(cfg, new CloseableRegistry(), this.getClass().getClassLoader());
        HashMap<String, List<String>> firstListStates = new HashMap<String, List<String>>();
        firstListStates.put("s1", Arrays.asList("foo1", "foo2", "foo3"));
        firstListStates.put("s2", Arrays.asList("bar1", "bar2", "bar3"));
        HashMap<String, List<String>> secondListStates = new HashMap<String, List<String>>();
        secondListStates.put("s1", Arrays.asList("foo4", "foo5", "foo6"));
        secondListStates.put("s2", Arrays.asList("bar1", "bar2", "bar3"));
        OperatorStateHandle firstStateHandle = OperatorStateRestoreOperationTest.createOperatorStateHandle(operatorStateBackendFactory, firstListStates, Collections.emptyMap());
        OperatorStateHandle secondStateHandle = OperatorStateRestoreOperationTest.createOperatorStateHandle(operatorStateBackendFactory, firstListStates, Collections.emptyMap());
        HashMap<String, List<String>> mergedListStates = new HashMap<String, List<String>>();
        for (String stateName : firstListStates.keySet()) {
            mergedListStates.computeIfAbsent(stateName, k -> new ArrayList()).addAll((Collection)firstListStates.get(stateName));
        }
        for (String stateName : secondListStates.keySet()) {
            mergedListStates.computeIfAbsent(stateName, k -> new ArrayList()).addAll((Collection)firstListStates.get(stateName));
        }
        OperatorStateRestoreOperationTest.verifyOperatorStateHandle(operatorStateBackendFactory, Arrays.asList(firstStateHandle, secondStateHandle), mergedListStates, Collections.emptyMap());
    }

    @ParameterizedTest
    @ValueSource(booleans={true, false})
    void testEmptyPartitionedOperatorState(boolean snapshotCompressionEnabled) throws Exception {
        ExecutionConfig cfg = new ExecutionConfig();
        cfg.setUseSnapshotCompression(snapshotCompressionEnabled);
        ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> operatorStateBackendFactory = OperatorStateRestoreOperationTest.createOperatorStateBackendFactory(cfg, new CloseableRegistry(), this.getClass().getClassLoader());
        HashMap<String, List<String>> listStates = new HashMap<String, List<String>>();
        listStates.put("bufferState", Collections.emptyList());
        listStates.put("offsetState", Collections.singletonList("foo"));
        HashMap<String, Map<String, String>> broadcastStates = new HashMap<String, Map<String, String>>();
        broadcastStates.put("whateverState", Collections.emptyMap());
        OperatorStateHandle stateHandle = OperatorStateRestoreOperationTest.createOperatorStateHandle(operatorStateBackendFactory, listStates, broadcastStates);
        OperatorStateRestoreOperationTest.verifyOperatorStateHandle(operatorStateBackendFactory, Collections.singletonList(stateHandle), listStates, broadcastStates);
    }

    @ParameterizedTest
    @ValueSource(booleans={true, false})
    void testRepartitionOperatorState(boolean snapshotCompressionEnabled) throws Exception {
        ExecutionConfig cfg = new ExecutionConfig();
        cfg.setUseSnapshotCompression(snapshotCompressionEnabled);
        ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> operatorStateBackendFactory = OperatorStateRestoreOperationTest.createOperatorStateBackendFactory(cfg, new CloseableRegistry(), this.getClass().getClassLoader());
        HashMap<String, List<String>> listStates = new HashMap<String, List<String>>();
        listStates.put("bufferState", IntStream.range(0, 10).mapToObj(idx -> "foo" + idx).collect(Collectors.toList()));
        listStates.put("offsetState", IntStream.range(0, 10).mapToObj(idx -> "bar" + idx).collect(Collectors.toList()));
        OperatorStateHandle stateHandle = OperatorStateRestoreOperationTest.createOperatorStateHandle(operatorStateBackendFactory, listStates, Collections.emptyMap());
        for (int newParallelism : Arrays.asList(1, 2, 5, 10)) {
            RoundRobinOperatorStateRepartitioner partitioner = new RoundRobinOperatorStateRepartitioner();
            List repartitioned = partitioner.repartitionState(Collections.singletonList(Collections.singletonList(stateHandle)), 1, newParallelism);
            for (int idx2 = 0; idx2 < newParallelism; ++idx2) {
                OperatorStateRestoreOperationTest.verifyOperatorStateHandle(operatorStateBackendFactory, (Collection)repartitioned.get(idx2), OperatorStateRestoreOperationTest.getExpectedSplit(listStates, newParallelism, idx2), Collections.emptyMap());
            }
        }
    }

    private static Map<String, List<String>> getExpectedSplit(Map<String, List<String>> states, int newParallelism, int idx) {
        HashMap<String, List<String>> newStates = new HashMap<String, List<String>>();
        for (String stateName : states.keySet()) {
            int stateSize = states.get(stateName).size();
            newStates.put(stateName, states.get(stateName).subList(idx * stateSize / newParallelism, (idx + 1) * stateSize / newParallelism));
        }
        return newStates;
    }
}

