/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.streams.processor.internals.assignment;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.UUID;
import java.util.function.BiConsumer;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.TopologyMetadata;
import org.apache.kafka.streams.processor.internals.assignment.ClientState;
import org.apache.kafka.streams.processor.internals.assignment.Graph;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareGraphConstructor;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor;

/**
 * A {@link RackAwareGraphConstructor} whose flow graph limits, per subtopology, how many tasks each
 * client may receive, in proportion to how many tasks that client originally owned.
 *
 * <p>Node layout of the graph produced by {@link #constructTaskGraph}:
 * <ul>
 *   <li>{@code -1}: the flow source, with a capacity-1 edge to every task node;</li>
 *   <li>{@code 0 .. n-1}: one node per task, numbered in ascending subtopology order and, within a
 *       subtopology, in ascending task order (this assumes every task of {@code taskIdList}
 *       belongs to exactly one subtopology -- NOTE(review): not enforced here);</li>
 *   <li>{@link #getClientNodeId}: one first-stage node per (subtopology, client) pair;</li>
 *   <li>{@code getSecondStageClientNodeId}: one second-stage node per client;</li>
 *   <li>{@link #getSinkNodeID}: the flow sink.</li>
 * </ul>
 * Every task node has an edge to each client's first-stage node for the task's subtopology. The
 * first-stage to second-stage edge is capped at the client's proportional share of that
 * subtopology's tasks. The second-stage to sink edge is capped at the number of tasks the client
 * originally owned.
 */
public class BalanceSubtopologyGraphConstructor
implements RackAwareGraphConstructor {
    /** Node id used as the flow source of the constructed graph. */
    private static final int SOURCE_NODE_ID = -1;

    private final Map<TopologyMetadata.Subtopology, Set<TaskId>> tasksForTopicGroup;

    /**
     * @param tasksForTopicGroup all tasks grouped by subtopology; subtopology and task ordering
     *                           (natural order) determines node numbering in the graph
     */
    public BalanceSubtopologyGraphConstructor(Map<TopologyMetadata.Subtopology, Set<TaskId>> tasksForTopicGroup) {
        this.tasksForTopicGroup = tasksForTopicGroup;
    }

    /** Returns the sink id, which is the first id after all task and client nodes. */
    @Override
    public int getSinkNodeID(List<TaskId> taskIdList, List<UUID> clientList, Map<TopologyMetadata.Subtopology, Set<TaskId>> tasksForTopicGroup) {
        return clientList.size() + taskIdList.size() + clientList.size() * tasksForTopicGroup.size();
    }

    /** Returns the first-stage node id of client {@code clientIndex} for subtopology {@code topicGroupIndex}. */
    @Override
    public int getClientNodeId(int clientIndex, List<TaskId> taskIdList, List<UUID> clientList, int topicGroupIndex) {
        return taskIdList.size() + clientList.size() * topicGroupIndex + clientIndex;
    }

    /** Inverse of {@link #getClientNodeId}: recovers the client index from a first-stage node id. */
    @Override
    public int getClientIndex(int clientNodeId, List<TaskId> taskIdList, List<UUID> clientList, int topicGroupIndex) {
        return clientNodeId - taskIdList.size() - clientList.size() * topicGroupIndex;
    }

    /** Returns the second-stage node id of client {@code clientIndex}, placed after all first-stage nodes. */
    private static int getSecondStageClientNodeId(List<TaskId> taskIdList, List<UUID> clientList, Map<TopologyMetadata.Subtopology, Set<TaskId>> tasksForTopicGroup, int clientIndex) {
        return taskIdList.size() + clientList.size() * tasksForTopicGroup.size() + clientIndex;
    }

    /**
     * Builds the flow graph for {@code taskIdList} and checks that every task can be routed.
     *
     * <p>Side effects: for every task currently assigned to a client (per {@code hasAssignedTask}),
     * increments that client's entry in {@code originalAssignedTaskNumber}. Current owners are
     * recorded in {@code taskClientMap}.
     *
     * @return the constructed graph, after its max flow has been calculated
     * @throws IllegalStateException if a task is not in {@code tasksForTopicGroup}, or if the max
     *         flow does not equal the number of tasks
     * @throws IllegalArgumentException if {@code hasReplica} is false and a task is currently
     *         assigned to more than one client
     */
    @Override
    public Graph<Integer> constructTaskGraph(List<UUID> clientList, List<TaskId> taskIdList, Map<UUID, ClientState> clientStates, Map<TaskId, UUID> taskClientMap, Map<UUID, Integer> originalAssignedTaskNumber, BiPredicate<ClientState, TaskId> hasAssignedTask, RackAwareTaskAssignor.CostFunction costFunction, int trafficCost, int nonOverlapCost, boolean hasReplica, boolean isStandby) {
        this.validateTasks(taskIdList);
        Graph<Integer> graph = new Graph<Integer>();
        for (TaskId taskId : taskIdList) {
            for (Map.Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
                if (hasAssignedTask.test(clientState.getValue(), taskId)) {
                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
                }
            }
        }
        this.constructEdges(graph, taskIdList, clientList, clientStates, taskClientMap, originalAssignedTaskNumber, hasAssignedTask, costFunction, trafficCost, nonOverlapCost, hasReplica, isStandby);
        long maxFlow = graph.calculateMaxFlow();
        if (maxFlow != (long) taskIdList.size()) {
            throw new IllegalStateException("max flow calculated: " + maxFlow + " doesn't match taskSize: " + taskIdList.size());
        }
        return graph;
    }

    /**
     * Walks the task nodes in the same order used by {@link #constructTaskGraph} and applies the
     * min-cost-flow result to the clients through {@code assignTaskToClient}. It then validates the
     * outcome with {@code validateAssignedTask}.
     *
     * @return true if any task was moved to a different client
     */
    @Override
    public boolean assignTaskFromMinCostFlow(Graph<Integer> graph, List<UUID> clientList, List<TaskId> taskIdList, Map<UUID, ClientState> clientStates, Map<UUID, Integer> originalAssignedTaskNumber, Map<TaskId, UUID> taskClientMap, BiConsumer<ClientState, TaskId> assignTask, BiConsumer<ClientState, TaskId> unAssignTask, BiPredicate<ClientState, TaskId> hasAssignedTask) {
        Set<TaskId> taskIdSet = new HashSet<>(taskIdList);
        int taskNodeId = 0;
        int topicGroupIndex = 0;
        int tasksAssigned = 0;
        boolean taskMoved = false;
        for (List<TaskId> subtopologyTasks : this.validTasksBySubtopology(taskIdSet)) {
            for (TaskId taskId : subtopologyTasks) {
                KeyValue<Boolean, Integer> movedAndAssigned = this.assignTaskToClient(graph, taskId, taskNodeId, topicGroupIndex, clientStates, clientList, taskIdList, taskClientMap, assignTask, unAssignTask);
                taskMoved |= movedAndAssigned.key;
                tasksAssigned += movedAndAssigned.value;
                ++taskNodeId;
            }
            // Advanced for every subtopology, even one without candidate tasks, so that the index
            // matches the numbering used by getClientNodeId during construction.
            ++topicGroupIndex;
        }
        this.validateAssignedTask(taskIdList, tasksAssigned, clientStates, originalAssignedTaskNumber, hasAssignedTask);
        return taskMoved;
    }

    /**
     * Returns one list per subtopology, in ascending subtopology order. Each list holds that
     * subtopology's tasks that are also in {@code taskIdSet}, in ascending task order. Subtopologies
     * with no such task still get an empty list, so list positions equal topic-group indices.
     * This is the single source of truth for task-node numbering.
     */
    private List<List<TaskId>> validTasksBySubtopology(Set<TaskId> taskIdSet) {
        TreeMap<TopologyMetadata.Subtopology, Set<TaskId>> sortedTasksForTopicGroup = new TreeMap<>(this.tasksForTopicGroup);
        return sortedTasksForTopicGroup.values().stream()
            .map(tasks -> new TreeSet<>(tasks).stream()
                .filter(taskIdSet::contains)
                .collect(Collectors.toList()))
            .collect(Collectors.toList());
    }

    /**
     * Ensures every task to assign belongs to some subtopology.
     *
     * @throws IllegalStateException naming the first task not found in {@code tasksForTopicGroup}
     */
    private void validateTasks(List<TaskId> taskIdList) {
        Set<TaskId> tasksInSubtopology = this.tasksForTopicGroup.values().stream()
            .flatMap(Collection::stream)
            .collect(Collectors.toSet());
        for (TaskId taskId : taskIdList) {
            if (!tasksInSubtopology.contains(taskId)) {
                throw new IllegalStateException("Task " + taskId + " not in tasksForTopicGroup");
            }
        }
    }

    /**
     * Adds all edges to {@code graph} and sets its source and sink. The edges are:
     * task to first-stage client (capacity 1, cost from {@code costFunction});
     * first-stage to second-stage client (capacity is the proportional share);
     * source to task (capacity 1);
     * second-stage client to sink (capacity is the original task count).
     * It also records current task owners in {@code taskClientMap}.
     */
    private void constructEdges(Graph<Integer> graph, List<TaskId> taskIdList, List<UUID> clientList, Map<UUID, ClientState> clientStates, Map<TaskId, UUID> taskClientMap, Map<UUID, Integer> originalAssignedTaskNumber, BiPredicate<ClientState, TaskId> hasAssignedTask, RackAwareTaskAssignor.CostFunction costFunction, int trafficCost, int nonOverlapCost, boolean hasReplica, boolean isStandby) {
        Set<TaskId> taskIdSet = new HashSet<>(taskIdList);
        // Subtopology task lists may contain tasks we are not assigning (e.g. stateful tasks when
        // only stateless ones are being assigned); the helper filters those out.
        List<List<TaskId>> tasksBySubtopology = this.validTasksBySubtopology(taskIdSet);
        int sinkId = this.getSinkNodeID(taskIdList, clientList, this.tasksForTopicGroup);

        // Task -> first-stage client edges, then first-stage -> second-stage client edges.
        int taskNodeId = 0;
        int topicGroupIndex = 0;
        for (List<TaskId> subtopologyTasks : tasksBySubtopology) {
            for (int clientIndex = 0; clientIndex < clientList.size(); ++clientIndex) {
                UUID processId = clientList.get(clientIndex);
                int clientNodeId = this.getClientNodeId(clientIndex, taskIdList, clientList, topicGroupIndex);
                int currentTaskNodeId = taskNodeId;
                for (TaskId taskId : subtopologyTasks) {
                    boolean inCurrentAssignment = hasAssignedTask.test(clientStates.get(processId), taskId);
                    graph.addEdge(currentTaskNodeId, clientNodeId, 1, costFunction.getCost(taskId, processId, inCurrentAssignment, trafficCost, nonOverlapCost, isStandby), 0);
                    ++currentTaskNodeId;
                    if (inCurrentAssignment) {
                        if (!hasReplica && taskClientMap.containsKey(taskId)) {
                            throw new IllegalArgumentException("Task " + taskId + " assigned to multiple clients " + processId + ", " + taskClientMap.get(taskId));
                        }
                        taskClientMap.put(taskId, processId);
                    }
                }
                if (!subtopologyTasks.isEmpty()) {
                    int secondStageClientNodeId = getSecondStageClientNodeId(taskIdList, clientList, this.tasksForTopicGroup, clientIndex);
                    int capacity = originalAssignedTaskNumber.containsKey(processId)
                        ? proportionalCapacity(originalAssignedTaskNumber.get(processId), subtopologyTasks.size(), taskIdList.size())
                        : 0;
                    graph.addEdge(clientNodeId, secondStageClientNodeId, capacity, 0, 0);
                }
            }
            taskNodeId += subtopologyTasks.size();
            ++topicGroupIndex;
        }

        // Source -> task edges; task nodes are exactly 0 .. taskNodeCount-1 (numbered above).
        int taskNodeCount = taskNodeId;
        for (int nodeId = 0; nodeId < taskNodeCount; ++nodeId) {
            graph.addEdge(SOURCE_NODE_ID, nodeId, 1, 0, 0);
        }

        // Second-stage client -> sink edges, capped at each client's original task count.
        for (int clientIndex = 0; clientIndex < clientList.size(); ++clientIndex) {
            UUID processId = clientList.get(clientIndex);
            int capacity = originalAssignedTaskNumber.getOrDefault(processId, 0);
            int secondStageClientNodeId = getSecondStageClientNodeId(taskIdList, clientList, this.tasksForTopicGroup, clientIndex);
            graph.addEdge(secondStageClientNodeId, sinkId, capacity, 0, 0);
        }
        graph.setSourceNode(SOURCE_NODE_ID);
        graph.setSinkNode(sinkId);
    }

    /**
     * Returns {@code ceil(assignedTaskCount * subtopologyTaskCount / totalTaskCount)} using exact
     * integer arithmetic. A floating-point version can overshoot by one when the exact result is
     * an integer; for example {@code 14.0 / 100 * 50} evaluates to {@code 7.000000000000001}.
     * The product is computed in {@code long} to avoid int overflow.
     *
     * @param assignedTaskCount    tasks originally owned by the client (non-negative)
     * @param subtopologyTaskCount candidate tasks in the subtopology (non-negative)
     * @param totalTaskCount       total number of tasks being assigned (must be positive)
     */
    private static int proportionalCapacity(int assignedTaskCount, int subtopologyTaskCount, int totalTaskCount) {
        long numerator = (long) assignedTaskCount * subtopologyTaskCount;
        return (int) ((numerator + totalTaskCount - 1) / totalTaskCount);
    }
}

