/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.OperatorStreamStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.util.CollectionUtil;
import org.apache.flink.util.Preconditions;

/**
 * Repartitions operator state handles across subtasks, e.g. when restoring with a changed
 * parallelism.
 *
 * <p>Each state is handled according to its {@link OperatorStateHandle.Mode}:
 *
 * <ul>
 *   <li>{@code SPLIT_DISTRIBUTE}: the partition offsets of a state are pooled across all old
 *       subtasks and dealt out as evenly as possible to the new subtasks. The subtasks that
 *       receive one extra partition are rotated from one state name to the next.
 *   <li>{@code UNION}: every new subtask receives every entry of the state.
 *   <li>{@code BROADCAST}: new subtask {@code i} receives the entry at index {@code i % n}, where
 *       {@code n} is the number of entries collected for that state.
 * </ul>
 *
 * <p>When the parallelism is unchanged, the input is returned as-is unless there are union states
 * or broadcast states that were reported by only some of the subtasks.
 */
@Internal
public class RoundRobinOperatorStateRepartitioner
        implements OperatorStateRepartitioner<OperatorStateHandle> {

    public static final OperatorStateRepartitioner<OperatorStateHandle> INSTANCE =
            new RoundRobinOperatorStateRepartitioner();

    /**
     * Redistributes the given per-subtask state handles over {@code newParallelism} subtasks.
     *
     * @param previousParallelSubtaskStates handles per old subtask; its size must equal {@code
     *     oldParallelism}
     * @param oldParallelism parallelism the states were taken with
     * @param newParallelism target parallelism; must be positive
     * @return one list of handles per new subtask; may be the input list itself when no
     *     repartitioning is required
     * @throws NullPointerException if {@code previousParallelSubtaskStates} is null
     * @throws IllegalArgumentException if {@code newParallelism <= 0} or the list size does not
     *     match {@code oldParallelism}
     */
    @Override
    public List<List<OperatorStateHandle>> repartitionState(
            List<List<OperatorStateHandle>> previousParallelSubtaskStates,
            int oldParallelism,
            int newParallelism) {
        Preconditions.checkNotNull(previousParallelSubtaskStates);
        Preconditions.checkArgument(newParallelism > 0);
        Preconditions.checkArgument(
                previousParallelSubtaskStates.size() == oldParallelism,
                (Object) "This method still depends on the order of the new and old operators");

        final List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList;
        if (newParallelism == oldParallelism) {
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    unionStates = collectUnionStates(previousParallelSubtaskStates);
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    partlyFinishedBroadcastStates =
                            collectPartlyFinishedBroadcastStates(previousParallelSubtaskStates);

            // Nothing has to move between subtasks: keep the existing assignment untouched.
            if (unionStates.isEmpty() && partlyFinishedBroadcastStates.isEmpty()) {
                return previousParallelSubtaskStates;
            }

            mergeMapList = initMergeMapList(previousParallelSubtaskStates);
            repartitionUnionState(unionStates, mergeMapList);
            repartitionBroadcastState(partlyFinishedBroadcastStates, mergeMapList);
        } else {
            GroupByStateNameResults nameToStateByMode =
                    groupByStateMode(previousParallelSubtaskStates);
            mergeMapList = repartition(nameToStateByMode, newParallelism);
        }

        List<List<OperatorStateHandle>> result = new ArrayList<>(mergeMapList.size());
        for (Map<StreamStateHandle, OperatorStateHandle> mergeMap : mergeMapList) {
            result.add(new ArrayList<>(mergeMap.values()));
        }
        return result;
    }

    /**
     * Builds, per subtask, a map from delegate stream handle to the subtask's existing operator
     * state handle. {@code null} handles are skipped, as in {@link #collectStates} and {@link
     * #groupByStateMode}.
     */
    private List<Map<StreamStateHandle, OperatorStateHandle>> initMergeMapList(
            List<List<OperatorStateHandle>> parallelSubtaskStates) {
        List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList =
                new ArrayList<>(parallelSubtaskStates.size());
        for (List<OperatorStateHandle> subtaskState : parallelSubtaskStates) {
            mergeMapList.add(
                    subtaskState.stream()
                            .filter(handle -> handle != null)
                            .collect(
                                    Collectors.toMap(
                                            OperatorStateHandle::getDelegateStateHandle,
                                            Function.identity())));
        }
        return mergeMapList;
    }

    /** Returns all {@code UNION} states, keyed by state name. */
    private Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
            collectUnionStates(List<List<OperatorStateHandle>> parallelSubtaskStates) {
        return collectStates(parallelSubtaskStates, OperatorStateHandle.Mode.UNION)
                .entrySet()
                .stream()
                .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().entries));
    }

    /**
     * Returns the {@code BROADCAST} states that were reported by at least one, but not every,
     * subtask, keyed by state name.
     */
    private Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
            collectPartlyFinishedBroadcastStates(
                    List<List<OperatorStateHandle>> parallelSubtaskStates) {
        return collectStates(parallelSubtaskStates, OperatorStateHandle.Mode.BROADCAST)
                .entrySet()
                .stream()
                .filter(e -> e.getValue().isPartiallyReported())
                .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().entries));
    }

    /**
     * Collects, for every state name whose distribution mode equals {@code mode}, the (delegate,
     * meta info) pairs and the indices of the subtasks that reported it.
     */
    private Map<String, StateEntry> collectStates(
            List<List<OperatorStateHandle>> parallelSubtaskStates, OperatorStateHandle.Mode mode) {
        final int parallelism = parallelSubtaskStates.size();
        Map<String, StateEntry> states = CollectionUtil.newHashMapWithExpectedSize(parallelism);

        for (int subtaskIndex = 0; subtaskIndex < parallelism; ++subtaskIndex) {
            for (OperatorStateHandle operatorStateHandle : parallelSubtaskStates.get(subtaskIndex)) {
                if (operatorStateHandle == null) {
                    continue;
                }
                final Set<Map.Entry<String, OperatorStateHandle.StateMetaInfo>>
                        partitionOffsetEntries =
                                operatorStateHandle.getStateNameToPartitionOffsets().entrySet();
                for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry :
                        partitionOffsetEntries) {
                    OperatorStateHandle.StateMetaInfo metaInfo = entry.getValue();
                    if (metaInfo.getDistributionMode() != mode) {
                        continue;
                    }
                    StateEntry stateEntry =
                            states.computeIfAbsent(
                                    entry.getKey(),
                                    k ->
                                            new StateEntry(
                                                    parallelism * partitionOffsetEntries.size(),
                                                    parallelism));
                    stateEntry.addEntry(
                            subtaskIndex,
                            Tuple2.of(operatorStateHandle.getDelegateStateHandle(), metaInfo));
                }
            }
        }
        return states;
    }

    /**
     * Groups every (delegate, meta info) pair by distribution mode and then by state name. Every
     * mode is present in the result, possibly mapped to an empty map.
     */
    private GroupByStateNameResults groupByStateMode(
            List<List<OperatorStateHandle>> previousParallelSubtaskStates) {
        EnumMap<
                        OperatorStateHandle.Mode,
                        Map<
                                String,
                                List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>>
                nameToStateByMode = new EnumMap<>(OperatorStateHandle.Mode.class);
        for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) {
            nameToStateByMode.put(mode, new HashMap<>());
        }

        for (List<OperatorStateHandle> subtaskState : previousParallelSubtaskStates) {
            for (OperatorStateHandle operatorStateHandle : subtaskState) {
                if (operatorStateHandle == null) {
                    continue;
                }
                final Set<Map.Entry<String, OperatorStateHandle.StateMetaInfo>>
                        partitionOffsetEntries =
                                operatorStateHandle.getStateNameToPartitionOffsets().entrySet();
                for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e :
                        partitionOffsetEntries) {
                    OperatorStateHandle.StateMetaInfo metaInfo = e.getValue();
                    Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                            nameToState = nameToStateByMode.get(metaInfo.getDistributionMode());
                    List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>
                            stateLocations =
                                    nameToState.computeIfAbsent(
                                            e.getKey(),
                                            k ->
                                                    new ArrayList<>(
                                                            previousParallelSubtaskStates.size()
                                                                    * partitionOffsetEntries
                                                                            .size()));
                    stateLocations.add(
                            Tuple2.of(operatorStateHandle.getDelegateStateHandle(), metaInfo));
                }
            }
        }
        return new GroupByStateNameResults(nameToStateByMode);
    }

    /**
     * Builds {@code newParallelism} merge maps and fills them with split-distribute, union and
     * broadcast states, in that order.
     */
    private List<Map<StreamStateHandle, OperatorStateHandle>> repartition(
            GroupByStateNameResults nameToStateByMode, int newParallelism) {
        List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList =
                new ArrayList<>(newParallelism);
        for (int i = 0; i < newParallelism; ++i) {
            mergeMapList.add(new HashMap<>());
        }

        repartitionSplitState(
                nameToStateByMode.getByMode(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE),
                newParallelism,
                mergeMapList);
        repartitionUnionState(
                nameToStateByMode.getByMode(OperatorStateHandle.Mode.UNION), mergeMapList);
        repartitionBroadcastState(
                nameToStateByMode.getByMode(OperatorStateHandle.Mode.BROADCAST), mergeMapList);
        return mergeMapList;
    }

    /**
     * Deals out the partition offsets of every split-distribute state to the new subtasks.
     *
     * <p>Each subtask receives {@code total / newParallelism} offsets. The remaining {@code total %
     * newParallelism} offsets go one each to consecutive subtasks. The starting subtask for the
     * next state name is the first one that did not get an extra offset, so the extras rotate.
     * Each consumed list in {@code nameToDistributeState} is replaced with {@code null}.
     */
    private void repartitionSplitState(
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    nameToDistributeState,
            int newParallelism,
            List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList) {
        int startParallelOp = 0;
        for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                e : nameToDistributeState.entrySet()) {
            List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current =
                    e.getValue();

            int totalPartitions = 0;
            for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithOffsets :
                    current) {
                totalPartitions += handleWithOffsets.f1.getOffsets().length;
            }

            // Cursor into `current` (lstIdx) and into that entry's offsets array (offsetIdx).
            int lstIdx = 0;
            int offsetIdx = 0;
            int baseFraction = totalPartitions / newParallelism;
            int remainder = totalPartitions % newParallelism;
            int newStartParallelOp = startParallelOp;

            for (int i = 0; i < newParallelism; ++i) {
                int parallelOpIdx = (i + startParallelOp) % newParallelism;
                int numberOfPartitionsToAssign = baseFraction;
                if (remainder > 0) {
                    ++numberOfPartitionsToAssign;
                    --remainder;
                } else if (remainder == 0) {
                    // First subtask without an extra offset: the next state starts here.
                    newStartParallelOp = parallelOpIdx;
                    --remainder;
                }

                while (numberOfPartitionsToAssign > 0) {
                    Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithOffsets =
                            current.get(lstIdx);
                    long[] offsets = handleWithOffsets.f1.getOffsets();
                    int remaining = offsets.length - offsetIdx;

                    long[] offs;
                    if (remaining > numberOfPartitionsToAssign) {
                        offs =
                                Arrays.copyOfRange(
                                        offsets, offsetIdx, offsetIdx + numberOfPartitionsToAssign);
                        offsetIdx += numberOfPartitionsToAssign;
                    } else {
                        // Take the rest of this entry and advance to the next one.
                        offs = Arrays.copyOfRange(offsets, offsetIdx, offsets.length);
                        offsetIdx = 0;
                        ++lstIdx;
                    }
                    numberOfPartitionsToAssign -= remaining;

                    OperatorStateHandle operatorStateHandle =
                            getOrCreateHandle(
                                    mergeMapList.get(parallelOpIdx),
                                    handleWithOffsets.f0,
                                    nameToDistributeState.size());
                    operatorStateHandle
                            .getStateNameToPartitionOffsets()
                            .put(
                                    e.getKey(),
                                    new OperatorStateHandle.StateMetaInfo(
                                            offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
                }
            }
            startParallelOp = newStartParallelOp;
            // The list is fully consumed; drop the reference so it is no longer held by the map.
            e.setValue(null);
        }
    }

    /** Adds every entry of every union state to every merge map in {@code mergeMapList}. */
    private void repartitionUnionState(
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    unionState,
            List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList) {
        for (Map<StreamStateHandle, OperatorStateHandle> mergeMap : mergeMapList) {
            for (Map.Entry<
                            String,
                            List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    e : unionState.entrySet()) {
                for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>
                        handleWithMetaInfo : e.getValue()) {
                    getOrCreateHandle(mergeMap, handleWithMetaInfo.f0, unionState.size())
                            .getStateNameToPartitionOffsets()
                            .put(e.getKey(), handleWithMetaInfo.f1);
                }
            }
        }
    }

    /**
     * For each broadcast state, gives merge map {@code i} the entry at index {@code i % n}, where
     * {@code n} is the number of entries collected for that state.
     */
    private void repartitionBroadcastState(
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    broadcastState,
            List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList) {
        int newParallelism = mergeMapList.size();
        for (int i = 0; i < newParallelism; ++i) {
            Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(i);
            for (Map.Entry<
                            String,
                            List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                    e : broadcastState.entrySet()) {
                List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> entries =
                        e.getValue();
                Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithMetaInfo =
                        entries.get(i % entries.size());
                getOrCreateHandle(mergeMap, handleWithMetaInfo.f0, broadcastState.size())
                        .getStateNameToPartitionOffsets()
                        .put(e.getKey(), handleWithMetaInfo.f1);
            }
        }
    }

    /**
     * Returns the handle registered in {@code mergeMap} for {@code delegate}. If none exists, it
     * first registers a new {@link OperatorStreamStateHandle} wrapping {@code delegate}, backed by
     * a map from {@link CollectionUtil#newHashMapWithExpectedSize} sized for {@code
     * expectedStateCount}.
     */
    private static OperatorStateHandle getOrCreateHandle(
            Map<StreamStateHandle, OperatorStateHandle> mergeMap,
            StreamStateHandle delegate,
            int expectedStateCount) {
        return mergeMap.computeIfAbsent(
                delegate,
                d ->
                        new OperatorStreamStateHandle(
                                CollectionUtil.newHashMapWithExpectedSize(expectedStateCount), d));
    }

    /** Entries of one state name plus the set of subtask indices that reported it. */
    private static final class StateEntry {
        final List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> entries;
        final BitSet reportedSubtaskIndices;
        /** Number of subtasks; {@link BitSet#size()} is a capacity and cannot stand in for it. */
        final int parallelism;

        public StateEntry(int estimatedEntrySize, int parallelism) {
            this.entries = new ArrayList<>(estimatedEntrySize);
            this.reportedSubtaskIndices = new BitSet(parallelism);
            this.parallelism = parallelism;
        }

        void addEntry(
                int subtaskIndex,
                Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> entry) {
            this.entries.add(entry);
            this.reportedSubtaskIndices.set(subtaskIndex);
        }

        /** Returns true if at least one, but not every, subtask reported this state. */
        boolean isPartiallyReported() {
            int reported = this.reportedSubtaskIndices.cardinality();
            return reported > 0 && reported < this.parallelism;
        }
    }

    /** Per-mode view of the state-name groupings produced by {@link #groupByStateMode}. */
    private static final class GroupByStateNameResults {
        private final EnumMap<
                        OperatorStateHandle.Mode,
                        Map<
                                String,
                                List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>>
                byMode;

        GroupByStateNameResults(
                EnumMap<
                                OperatorStateHandle.Mode,
                                Map<
                                        String,
                                        List<
                                                Tuple2<
                                                        StreamStateHandle,
                                                        OperatorStateHandle.StateMetaInfo>>>>
                        byMode) {
            this.byMode = Preconditions.checkNotNull(byMode);
        }

        public Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
                getByMode(OperatorStateHandle.Mode mode) {
            return this.byMode.get(mode);
        }
    }
}

