/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.streams.processor.internals.assignment;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.UUID;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.TopologyMetadata;
import org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils;
import org.apache.kafka.streams.processor.internals.assignment.BalanceSubtopologyGraphConstructor;
import org.apache.kafka.streams.processor.internals.assignment.ClientState;
import org.apache.kafka.streams.processor.internals.assignment.Graph;
import org.apache.kafka.streams.processor.internals.assignment.MinTrafficGraphConstructor;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareGraphConstructor;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.Assertions;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(value=Parameterized.class)
public class RackAwareGraphConstructorTest {
    private static final String MIN_COST = "min_cost";
    private static final String BALANCE_SUBTOPOLOGY = "balance_sub_topology";
    private static final int TP_SIZE = 40;
    private static final int PARTITION_SIZE = 3;
    private static final int TOPIC_GROUP_SIZE = 40;
    private static final int CLIENT_SIZE = 20;
    private Graph<Integer> graph;
    private final SortedMap<TaskId, Set<TopicPartition>> taskTopicPartitionMap = AssignmentTestUtils.getTaskTopicPartitionMap(40, 3, false);
    private final SortedSet<TaskId> taskIds = (SortedSet)this.taskTopicPartitionMap.keySet();
    private final List<TaskId> taskIdList = new ArrayList<TaskId>(this.taskIds);
    private final SortedMap<UUID, ClientState> clientStateMap = AssignmentTestUtils.getRandomClientState(20, 40, 3, 1, false, this.taskIds);
    private final List<UUID> clientList = new ArrayList<UUID>(this.clientStateMap.keySet());
    private final Map<TaskId, UUID> taskClientMap = new HashMap<TaskId, UUID>();
    private final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<UUID, Integer>();
    private final Map<TopologyMetadata.Subtopology, Set<TaskId>> tasksForTopicGroup = AssignmentTestUtils.getTasksForTopicGroup(40, 3);
    private RackAwareGraphConstructor constructor;
    @Parameterized.Parameter
    public String constructorType;

    @Parameterized.Parameters(name="constructorType={0}")
    public static Collection<Object[]> getParamStoreType() {
        return Arrays.asList({MIN_COST}, {BALANCE_SUBTOPOLOGY});
    }

    @Before
    public void setUp() {
        this.randomAssignTasksToClient(this.taskIdList, this.clientStateMap);
        if (this.constructorType.equals(MIN_COST)) {
            this.constructor = new MinTrafficGraphConstructor();
        } else if (this.constructorType.equals(BALANCE_SUBTOPOLOGY)) {
            this.constructor = new BalanceSubtopologyGraphConstructor(this.tasksForTopicGroup);
        }
        this.graph = this.constructor.constructTaskGraph(this.clientList, this.taskIdList, this.clientStateMap, this.taskClientMap, this.originalAssignedTaskNumber, ClientState::hasAssignedTask, this::getCost, 10, 1, false, false);
    }

    private int getCost(TaskId taskId, UUID processId, boolean inCurrentAssignment, int trafficCost, int nonOverlapCost, boolean isStandby) {
        return 1;
    }

    @Test
    public void testSubtopologyShouldContainAllTasks() {
        if (this.constructorType.equals(MIN_COST)) {
            return;
        }
        this.taskIdList.add(new TaskId(41, 0));
        Assertions.assertThrows(IllegalStateException.class, () -> {
            this.graph = this.constructor.constructTaskGraph(this.clientList, this.taskIdList, this.clientStateMap, this.taskClientMap, this.originalAssignedTaskNumber, ClientState::hasAssignedTask, this::getCost, 10, 1, false, false);
        });
    }

    @Test
    public void testMinCostGraphConstructor() {
        if (this.constructorType.equals(BALANCE_SUBTOPOLOGY)) {
            return;
        }
        Assertions.assertEquals((long)this.taskIdList.size(), (long)this.graph.flow());
        Assertions.assertEquals((int)(this.taskIdList.size() + this.clientList.size() + 2), (int)this.graph.nodes().size());
        SortedMap edges = this.graph.edges((Comparable)Integer.valueOf(-1));
        for (Graph.Edge edge : edges.values()) {
            Assertions.assertEquals((int)1, (int)edge.flow);
            Assertions.assertEquals((int)1, (int)edge.capacity);
            Assertions.assertEquals((int)0, (int)edge.residualFlow);
            Assertions.assertEquals((int)0, (int)edge.cost);
            Assertions.assertTrue((boolean)edge.forwardEdge);
        }
        for (int taskNodeId = 0; taskNodeId < this.taskIdList.size(); ++taskNodeId) {
            edges = this.graph.edges((Comparable)Integer.valueOf(taskNodeId));
            Assertions.assertEquals((int)this.clientList.size(), (int)edges.size());
            int assignedClient = 0;
            for (Graph.Edge edge : edges.values()) {
                int flow = edge.flow;
                if (flow == 1) {
                    ++assignedClient;
                }
                Assertions.assertEquals((int)1, (int)edge.capacity);
                Assertions.assertEquals((int)(flow == 1 ? 0 : 1), (int)edge.residualFlow);
                Assertions.assertEquals((int)1, (int)edge.cost);
                Assertions.assertTrue((boolean)edge.forwardEdge);
            }
            Assertions.assertEquals((int)1, (int)assignedClient);
            ++taskNodeId;
        }
        int sinkId = this.clientList.size() + this.taskIdList.size();
        int totalFlow = 0;
        for (int i = 0; i < this.clientList.size(); ++i) {
            UUID clientId = this.clientList.get(i);
            int originalAssignedCount = this.originalAssignedTaskNumber.get(clientId);
            int clientNodeId = i + this.taskIdList.size();
            edges = this.graph.edges((Comparable)Integer.valueOf(clientNodeId));
            Assertions.assertEquals((int)1, (int)edges.size());
            for (Map.Entry nodeEdge : edges.entrySet()) {
                Integer nodeId = (Integer)nodeEdge.getKey();
                Assertions.assertEquals((int)sinkId, (Integer)nodeId);
                totalFlow += ((Graph.Edge)nodeEdge.getValue()).flow;
                Assertions.assertEquals((int)originalAssignedCount, (int)((Graph.Edge)nodeEdge.getValue()).capacity);
                Assertions.assertTrue((boolean)((Graph.Edge)nodeEdge.getValue()).forwardEdge);
            }
        }
        Assertions.assertEquals((int)this.taskIdList.size(), (int)totalFlow);
    }

    /*
     * WARNING - void declaration
     */
    @Test
    public void testBalanceSubtopologyGraphConstructor() {
        void var5_13;
        if (this.constructorType.equals(MIN_COST)) {
            return;
        }
        Assertions.assertEquals((long)this.taskIdList.size(), (long)this.graph.flow());
        Assertions.assertEquals((int)(this.taskIdList.size() + 40 * this.clientList.size() + this.clientList.size() + 2), (int)this.graph.nodes().size());
        SortedMap edges = this.graph.edges((Comparable)Integer.valueOf(-1));
        for (Object edge : edges.values()) {
            Assertions.assertEquals((int)1, (int)((Graph.Edge)edge).flow);
            Assertions.assertEquals((int)1, (int)((Graph.Edge)edge).capacity);
            Assertions.assertEquals((int)0, (int)((Graph.Edge)edge).residualFlow);
            Assertions.assertEquals((int)0, (int)((Graph.Edge)edge).cost);
            Assertions.assertTrue((boolean)((Graph.Edge)edge).forwardEdge);
        }
        int taskNodeId = 0;
        for (Set set : this.tasksForTopicGroup.values()) {
            for (int i = 0; i < set.size(); ++i) {
                edges = this.graph.edges((Comparable)Integer.valueOf(taskNodeId));
                Assertions.assertEquals((int)this.clientList.size(), (int)edges.size());
                int assignedClient = 0;
                for (Graph.Edge edge : edges.values()) {
                    int flow = edge.flow;
                    if (flow == 1) {
                        ++assignedClient;
                    }
                    Assertions.assertEquals((int)1, (int)edge.capacity);
                    Assertions.assertEquals((int)(flow == 1 ? 0 : 1), (int)edge.residualFlow);
                    Assertions.assertEquals((int)1, (int)edge.cost);
                    Assertions.assertTrue((boolean)edge.forwardEdge);
                }
                Assertions.assertEquals((int)1, (int)assignedClient);
                ++taskNodeId;
            }
        }
        int topicGroupIndex = 0;
        for (Set<TaskId> set : this.tasksForTopicGroup.values()) {
            int taskCount = set.size();
            for (int j = 0; j < this.clientList.size(); ++j) {
                UUID clientId = this.clientList.get(j);
                int originalAssignedCount = this.originalAssignedTaskNumber.get(clientId);
                int expectedCapacity = (int)Math.ceil((double)originalAssignedCount * 1.0 / (double)this.taskIdList.size() * (double)taskCount);
                int clientNodeId = topicGroupIndex * this.clientList.size() + this.taskIdList.size() + j;
                edges = this.graph.edges((Comparable)Integer.valueOf(clientNodeId));
                Assertions.assertEquals((int)1, (int)edges.size());
                for (Map.Entry nodeEdge : edges.entrySet()) {
                    Integer nodeId = (Integer)nodeEdge.getKey();
                    Assertions.assertEquals((int)(this.clientList.size() * this.tasksForTopicGroup.size() + this.taskIdList.size() + j), (Integer)nodeId);
                    Graph.Edge edge = (Graph.Edge)nodeEdge.getValue();
                    Assertions.assertEquals((int)expectedCapacity, (int)edge.capacity);
                    Assertions.assertEquals((int)0, (int)edge.cost);
                    Assertions.assertTrue((boolean)edge.forwardEdge);
                }
            }
            ++topicGroupIndex;
        }
        int n = this.clientList.size() + this.tasksForTopicGroup.size() * this.clientList.size() + this.taskIdList.size();
        boolean bl = false;
        for (int i = 0; i < this.clientList.size(); ++i) {
            UUID clientId = this.clientList.get(i);
            int originalAssignedCount = this.originalAssignedTaskNumber.get(clientId);
            int clientNodeId = i + this.tasksForTopicGroup.size() * this.clientList.size() + this.taskIdList.size();
            edges = this.graph.edges((Comparable)Integer.valueOf(clientNodeId));
            Assertions.assertEquals((int)1, (int)edges.size());
            for (Map.Entry nodeEdge : edges.entrySet()) {
                Integer nodeId = (Integer)nodeEdge.getKey();
                Assertions.assertEquals((int)n, (Integer)nodeId);
                var5_13 += ((Graph.Edge)nodeEdge.getValue()).flow;
                Assertions.assertEquals((int)originalAssignedCount, (int)((Graph.Edge)nodeEdge.getValue()).capacity);
                Assertions.assertTrue((boolean)((Graph.Edge)nodeEdge.getValue()).forwardEdge);
            }
        }
        Assertions.assertEquals((int)this.taskIdList.size(), (int)var5_13);
    }

    @Test
    public void testAssignTaskFromMinCostFlow() {
        this.graph.solveMinCostFlow();
        this.constructor.assignTaskFromMinCostFlow(this.graph, this.clientList, this.taskIdList, this.clientStateMap, this.originalAssignedTaskNumber, this.taskClientMap, ClientState::assignActive, ClientState::unassignActive, ClientState::hasAssignedTask);
        AssignmentTestUtils.assertValidAssignment(0, this.taskIds, Collections.emptySet(), this.clientStateMap, new StringBuilder());
        if (this.constructorType.equals(BALANCE_SUBTOPOLOGY)) {
            AssignmentTestUtils.assertBalancedTasks(this.clientStateMap);
        }
    }

    private void randomAssignTasksToClient(List<TaskId> taskIdList, SortedMap<UUID, ClientState> clientStateMap) {
        int totalAssigned = 0;
        for (ClientState clientState : clientStateMap.values()) {
            clientState.assignActive(taskIdList.get(totalAssigned++));
            clientState.assignActive(taskIdList.get(totalAssigned++));
        }
        block1: while (totalAssigned < taskIdList.size()) {
            for (ClientState clientState : clientStateMap.values()) {
                if (AssignmentTestUtils.getRandom().nextInt(3) != 0) continue;
                clientState.assignActive(taskIdList.get(totalAssigned));
                if (++totalAssigned < taskIdList.size()) continue;
                continue block1;
            }
        }
    }
}

