/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.jobgraph.forwardgroup;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroup;
import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil;
import org.apache.flink.runtime.jobgraph.forwardgroup.JobVertexForwardGroup;
import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.runtime.util.JobVertexConnectionUtils;
import org.apache.flink.streaming.api.graph.StreamNode;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.util.Preconditions;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Unit tests for {@link ForwardGroupComputeUtil}.
 *
 * <p>Each scenario is exercised twice: once on a {@link JobVertex} topology and once on a
 * {@link StreamNode} topology. The assertions below encode the expected difference between the
 * two: for job vertices only multi-vertex groups are expected in the result, whereas for stream
 * nodes every node is expected to end up in some group (possibly of size one).
 */
class ForwardGroupComputeUtilTest {

    /** Three unconnected job vertices are expected to produce no forward group. */
    @Test
    void testIsolatedVertices() throws Exception {
        JobVertex v1 = new JobVertex("v1");
        JobVertex v2 = new JobVertex("v2");
        JobVertex v3 = new JobVertex("v3");

        Set<ForwardGroup<?>> groups = computeForwardGroups(v1, v2, v3);

        checkGroupSize(groups, 0);
    }

    /** Three unconnected stream nodes are expected to produce three groups of size one. */
    @Test
    void testIsolatedChainedStreamNodeGroups() throws Exception {
        List<StreamNode> topologicallySortedStreamNodes = createStreamNodes(3);
        Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId = Collections.emptyMap();

        Set<ForwardGroup<?>> groups =
                computeForwardGroups(
                        topologicallySortedStreamNodes, forwardProducersByConsumerNodeId);

        checkGroupSize(groups, 3, 1, 1, 1);
    }

    /** Runs the v1 -> v2 -> v3 chain with different combinations of forward edges. */
    @Test
    void testVariousResultPartitionTypesBetweenVertices() throws Exception {
        testThreeVerticesConnectSequentially(false, true, 1, 2);
        testThreeVerticesConnectSequentially(false, false, 0);
        testThreeVerticesConnectSequentially(true, true, 1, 3);
    }

    /**
     * Builds the chain v1 -> v2 -> v3 and verifies the resulting forward groups.
     *
     * @param isForward1 forward flag passed for the v1 -> v2 connection
     * @param isForward2 forward flag passed for the v2 -> v3 connection
     * @param numOfGroups expected number of forward groups
     * @param groupSizes expected size of each forward group, in any order
     */
    private void testThreeVerticesConnectSequentially(
            boolean isForward1, boolean isForward2, int numOfGroups, Integer... groupSizes)
            throws Exception {
        JobVertex v1 = new JobVertex("v1");
        JobVertex v2 = new JobVertex("v2");
        JobVertex v3 = new JobVertex("v3");

        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v2,
                v1,
                DistributionPattern.ALL_TO_ALL,
                ResultPartitionType.BLOCKING,
                false,
                isForward1);
        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v3,
                v2,
                DistributionPattern.POINTWISE,
                ResultPartitionType.BLOCKING,
                false,
                isForward2);

        Set<ForwardGroup<?>> groups = computeForwardGroups(v1, v2, v3);

        checkGroupSize(groups, numOfGroups, groupSizes);
    }

    /** Runs the three-node stream chain with different combinations of forward edges. */
    @Test
    void testVariousConnectTypesBetweenChainedStreamNodeGroup() throws Exception {
        testThreeChainedStreamNodeGroupsConnectSequentially(false, true, 2, 1, 2);
        testThreeChainedStreamNodeGroupsConnectSequentially(false, false, 3, 1, 1, 1);
        testThreeChainedStreamNodeGroupsConnectSequentially(true, true, 1, 3);
    }

    /**
     * Builds a chain of three stream nodes and verifies the resulting forward groups.
     *
     * @param isForward1 whether the first node is registered as forward producer of the second
     * @param isForward2 whether the second node is registered as forward producer of the third
     * @param numOfGroups expected number of forward groups
     * @param groupSizes expected size of each forward group, in any order
     */
    private void testThreeChainedStreamNodeGroupsConnectSequentially(
            boolean isForward1, boolean isForward2, int numOfGroups, Integer... groupSizes)
            throws Exception {
        List<StreamNode> topologicallySortedStreamNodes = createStreamNodes(3);
        Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId = new HashMap<>();

        if (isForward1) {
            addForwardEdge(
                    forwardProducersByConsumerNodeId,
                    topologicallySortedStreamNodes.get(0),
                    topologicallySortedStreamNodes.get(1));
        }
        if (isForward2) {
            addForwardEdge(
                    forwardProducersByConsumerNodeId,
                    topologicallySortedStreamNodes.get(1),
                    topologicallySortedStreamNodes.get(2));
        }

        Set<ForwardGroup<?>> groups =
                computeForwardGroups(
                        topologicallySortedStreamNodes, forwardProducersByConsumerNodeId);

        checkGroupSize(groups, numOfGroups, groupSizes);
    }

    /**
     * v1 and v2 both feed v3 via forward edges; v3 feeds v4 via a non-forward edge. A single
     * group of size three is expected.
     */
    @Test
    void testTwoInputsMergesIntoOne() throws Exception {
        JobVertex v1 = new JobVertex("v1");
        JobVertex v2 = new JobVertex("v2");
        JobVertex v3 = new JobVertex("v3");
        JobVertex v4 = new JobVertex("v4");

        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v3, v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING, false, true);
        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v3, v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING, false, true);
        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v4, v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);

        Set<ForwardGroup<?>> groups = computeForwardGroups(v1, v2, v3, v4);

        checkGroupSize(groups, 1, 3);
    }

    /**
     * Stream-node variant of the merge scenario: the first two nodes are forward producers of
     * the third and the fourth node is unconnected. Groups of size three and one are expected.
     */
    @Test
    void testTwoInputsMergesIntoOneForStreamNodeForwardGroup() throws Exception {
        List<StreamNode> topologicallySortedStreamNodes = createStreamNodes(4);
        Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId = new HashMap<>();

        addForwardEdge(
                forwardProducersByConsumerNodeId,
                topologicallySortedStreamNodes.get(0),
                topologicallySortedStreamNodes.get(2));
        addForwardEdge(
                forwardProducersByConsumerNodeId,
                topologicallySortedStreamNodes.get(1),
                topologicallySortedStreamNodes.get(2));

        Set<ForwardGroup<?>> groups =
                computeForwardGroups(
                        topologicallySortedStreamNodes, forwardProducersByConsumerNodeId);

        checkGroupSize(groups, 2, 3, 1);
    }

    /**
     * v1 feeds v2 via a non-forward edge; v2 feeds both v3 and v4 via forward edges. A single
     * group of size three is expected.
     */
    @Test
    void testOneInputSplitsIntoTwo() throws Exception {
        JobVertex v1 = new JobVertex("v1");
        JobVertex v2 = new JobVertex("v2");
        JobVertex v3 = new JobVertex("v3");
        JobVertex v4 = new JobVertex("v4");

        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v2, v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v3, v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING, false, true);
        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v4, v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING, false, true);

        Set<ForwardGroup<?>> groups = computeForwardGroups(v1, v2, v3, v4);

        checkGroupSize(groups, 1, 3);
    }

    /**
     * Stream-node variant of the split scenario: the second node is forward producer of both the
     * third and fourth nodes and the first node is unconnected. Groups of size three and one are
     * expected.
     */
    @Test
    void testOneInputSplitsIntoTwoForStreamNodeForwardGroup() throws Exception {
        List<StreamNode> topologicallySortedStreamNodes = createStreamNodes(4);
        Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId = new HashMap<>();

        addForwardEdge(
                forwardProducersByConsumerNodeId,
                topologicallySortedStreamNodes.get(1),
                topologicallySortedStreamNodes.get(3));
        addForwardEdge(
                forwardProducersByConsumerNodeId,
                topologicallySortedStreamNodes.get(1),
                topologicallySortedStreamNodes.get(2));

        Set<ForwardGroup<?>> groups =
                computeForwardGroups(
                        topologicallySortedStreamNodes, forwardProducersByConsumerNodeId);

        checkGroupSize(groups, 2, 3, 1);
    }

    /**
     * Assigns {@link NoOpInvokable} to every vertex and returns the distinct forward groups
     * computed for them.
     *
     * @param vertices job vertices, passed to the utility in the given order
     * @return the distinct forward groups found in the computed map's values
     */
    private static Set<ForwardGroup<?>> computeForwardGroups(JobVertex... vertices) {
        List<JobVertex> vertexList = Arrays.asList(vertices);
        vertexList.forEach(vertex -> vertex.setInvokableClass(NoOpInvokable.class));
        // Several keys may map to the same group instance; the set collapses those duplicates.
        return new HashSet<>(
                ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism(vertexList)
                        .values());
    }

    /**
     * Asserts the number of groups and their sizes.
     *
     * <p>The sizes are compared as a multiset: every expected size must be matched by a distinct
     * group, and no group may be left unmatched. Order is irrelevant because {@code groups} is
     * an unordered set.
     *
     * @param groups the computed forward groups
     * @param numOfGroups expected number of groups
     * @param sizes expected size of each group, in any order
     */
    private static void checkGroupSize(
            Set<ForwardGroup<?>> groups, int numOfGroups, Integer... sizes) {
        Assertions.assertThat(groups.size()).isEqualTo(numOfGroups);

        List<Integer> actualSizes =
                groups.stream()
                        .map(ForwardGroupComputeUtilTest::sizeOf)
                        .collect(Collectors.toList());
        Assertions.assertThat(actualSizes).containsExactlyInAnyOrder(sizes);
    }

    /**
     * Returns the size of the given group, dispatching on its concrete type. Anything that is
     * not a {@link JobVertexForwardGroup} is cast to {@link StreamNodeForwardGroup}.
     */
    private static int sizeOf(ForwardGroup<?> group) {
        if (group instanceof JobVertexForwardGroup) {
            return ((JobVertexForwardGroup) group).size();
        }
        return ((StreamNodeForwardGroup) group).size();
    }

    /** Registers {@code producer} as a forward producer of {@code consumer} in the given map. */
    private static void addForwardEdge(
            Map<StreamNode, Set<StreamNode>> forwardProducersByConsumer,
            StreamNode producer,
            StreamNode consumer) {
        forwardProducersByConsumer.computeIfAbsent(consumer, k -> new HashSet<>()).add(producer);
    }

    /**
     * Creates a stream node with the given id; all other constructor arguments are {@code null}.
     */
    private static StreamNode createStreamNode(int id) {
        // The typed null cast selects the constructor overload that takes a StreamOperator.
        return new StreamNode(Integer.valueOf(id), null, null, (StreamOperator<?>) null, null, null);
    }

    /**
     * Creates {@code count} stream nodes with ids {@code 1..count}, in ascending id order.
     *
     * @param count number of nodes to create
     * @return a mutable list of the created nodes
     */
    private static List<StreamNode> createStreamNodes(int count) {
        List<StreamNode> streamNodes = new ArrayList<>();
        for (int id = 1; id <= count; ++id) {
            streamNodes.add(createStreamNode(id));
        }
        return streamNodes;
    }

    /**
     * Computes the distinct forward groups for a stream-node topology.
     *
     * @param topologicallySortedStreamNodes the stream nodes, in the order they are to be visited
     * @param forwardProducersByConsumerNodeId forward producers keyed by consumer node; a node
     *     missing from the map is treated as having no forward producers
     * @return the distinct forward groups found in the computed map's values
     */
    private static Set<ForwardGroup<?>> computeForwardGroups(
            List<StreamNode> topologicallySortedStreamNodes,
            Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId) {
        Map<Integer, StreamNodeForwardGroup> groupsByNodeId =
                computeStreamNodeForwardGroupAndCheckParallelism(
                        topologicallySortedStreamNodes,
                        node ->
                                forwardProducersByConsumerNodeId.getOrDefault(
                                        node, Collections.emptySet()));
        return new HashSet<>(groupsByNodeId.values());
    }

    /**
     * Computes stream-node forward groups and then checks that, for every node mapped to a group
     * whose parallelism is decided, the node's parallelism equals the group's parallelism.
     *
     * @param topologicallySortedStreamNodes the stream nodes, in the order they are to be visited
     * @param forwardProducersRetriever returns the forward producers of a given node
     * @return the map returned by {@link ForwardGroupComputeUtil#computeStreamNodeForwardGroup}
     * @throws IllegalStateException if a node's parallelism differs from that of its group
     */
    public static Map<Integer, StreamNodeForwardGroup>
            computeStreamNodeForwardGroupAndCheckParallelism(
                    Iterable<StreamNode> topologicallySortedStreamNodes,
                    Function<StreamNode, Set<StreamNode>> forwardProducersRetriever) {
        Map<Integer, StreamNodeForwardGroup> forwardGroupsByStartNodeId =
                ForwardGroupComputeUtil.computeStreamNodeForwardGroup(
                        topologicallySortedStreamNodes, forwardProducersRetriever);

        topologicallySortedStreamNodes.forEach(
                startNode -> {
                    StreamNodeForwardGroup forwardGroup =
                            forwardGroupsByStartNodeId.get(startNode.getId());
                    if (forwardGroup != null && forwardGroup.isParallelismDecided()) {
                        Preconditions.checkState(
                                startNode.getParallelism() == forwardGroup.getParallelism());
                    }
                });

        return forwardGroupsByStartNodeId;
    }
}

