/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.jobgraph.forwardgroup;

import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.runtime.executiongraph.VertexGroupComputeUtil;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroup;
import org.apache.flink.runtime.jobgraph.forwardgroup.JobVertexForwardGroup;
import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup;
import org.apache.flink.streaming.api.graph.StreamNode;
import org.apache.flink.util.Preconditions;

/**
 * Utilities for computing forward groups: sets of vertices (job vertices or stream nodes) that
 * are transitively connected to each other through forward edges.
 */
public class ForwardGroupComputeUtil {

    /**
     * Computes the forward groups of the given job vertices, then verifies that each vertex whose
     * forward group already has a decided parallelism runs with exactly that parallelism.
     *
     * @param topologicallySortedVertices job vertices in topological order (every forward producer
     *     must appear before its consumers)
     * @return map from job vertex id to the forward group containing it; vertices that are not
     *     part of a group with more than one vertex have no entry
     * @throws IllegalStateException if a vertex's parallelism differs from the decided parallelism
     *     of its forward group, or if a forward producer is encountered after its consumer
     */
    public static Map<JobVertexID, JobVertexForwardGroup> computeForwardGroupsAndCheckParallelism(
            Iterable<JobVertex> topologicallySortedVertices) {
        final Map<JobVertexID, JobVertexForwardGroup> forwardGroupsByJobVertexId =
                computeForwardGroups(
                        topologicallySortedVertices, ForwardGroupComputeUtil::getForwardProducers);

        for (JobVertex jobVertex : topologicallySortedVertices) {
            final JobVertexForwardGroup forwardGroup =
                    forwardGroupsByJobVertexId.get(jobVertex.getID());
            if (forwardGroup != null && forwardGroup.isParallelismDecided()) {
                Preconditions.checkState(
                        jobVertex.getParallelism() == forwardGroup.getParallelism(),
                        "Parallelism %s of job vertex %s does not match the decided parallelism %s"
                                + " of its forward group.",
                        jobVertex.getParallelism(),
                        jobVertex.getID(),
                        forwardGroup.getParallelism());
            }
        }
        return forwardGroupsByJobVertexId;
    }

    /**
     * Groups job vertices that are connected through forward edges.
     *
     * @param topologicallySortedVertices job vertices in topological order
     * @param forwardProducersRetriever returns, for a vertex, the producers it consumes from via
     *     forward edges
     * @return map from job vertex id to its forward group; groups consisting of a single vertex
     *     are skipped, so such vertices have no entry
     * @throws IllegalStateException if a forward producer is encountered after its consumer
     */
    public static Map<JobVertexID, JobVertexForwardGroup> computeForwardGroups(
            Iterable<JobVertex> topologicallySortedVertices,
            Function<JobVertex, Set<JobVertex>> forwardProducersRetriever) {
        final Map<JobVertex, Set<JobVertex>> vertexToGroup =
                computeVertexToGroup(topologicallySortedVertices, forwardProducersRetriever);

        final Map<JobVertexID, JobVertexForwardGroup> ret = new HashMap<>();
        for (Set<JobVertex> vertexGroup : VertexGroupComputeUtil.uniqueVertexGroups(vertexToGroup)) {
            // A lone vertex is not forward-connected to anything, so no group is created for it.
            if (vertexGroup.size() <= 1) {
                continue;
            }
            final JobVertexForwardGroup forwardGroup = new JobVertexForwardGroup(vertexGroup);
            for (JobVertexID jobVertexId : forwardGroup.getVertexIds()) {
                ret.put(jobVertexId, forwardGroup);
            }
        }
        return ret;
    }

    /**
     * Groups stream nodes that are connected through forward edges.
     *
     * <p>Unlike {@code computeForwardGroups}, groups containing a single node are NOT filtered out
     * here; each unique group becomes a {@link StreamNodeForwardGroup}.
     *
     * @param topologicallySortedStreamNodes stream nodes in topological order
     * @param forwardProducersRetriever returns, for a node, the producers it consumes from via
     *     forward edges
     * @return map from the ids reported by each group's {@code getVertexIds()} to that group
     * @throws IllegalStateException if a forward producer is encountered after its consumer
     */
    public static Map<Integer, StreamNodeForwardGroup> computeStreamNodeForwardGroup(
            Iterable<StreamNode> topologicallySortedStreamNodes,
            Function<StreamNode, Set<StreamNode>> forwardProducersRetriever) {
        final Map<StreamNode, Set<StreamNode>> nodeToGroup =
                computeVertexToGroup(topologicallySortedStreamNodes, forwardProducersRetriever);

        final Map<Integer, StreamNodeForwardGroup> ret = new HashMap<>();
        for (Set<StreamNode> nodeGroup : VertexGroupComputeUtil.uniqueVertexGroups(nodeToGroup)) {
            final StreamNodeForwardGroup forwardGroup = new StreamNodeForwardGroup(nodeGroup);
            for (Integer vertexId : forwardGroup.getVertexIds()) {
                ret.put(vertexId, forwardGroup);
            }
        }
        return ret;
    }

    /**
     * Assigns every vertex to the group of vertices it is forward-connected with.
     *
     * <p>Each vertex starts in its own singleton group and is then merged with the group of each
     * of its forward producers. Because producers are looked up in the map built so far, the input
     * must be topologically sorted.
     *
     * @param topologicallySortedVertices vertices in topological order
     * @param forwardProducersRetriever returns the forward producers of a vertex
     * @return identity-keyed map from each vertex to the (shared) set representing its group
     * @throws IllegalStateException if a producer has not been visited before its consumer
     */
    private static <T> Map<T, Set<T>> computeVertexToGroup(
            Iterable<T> topologicallySortedVertices, Function<T, Set<T>> forwardProducersRetriever) {
        // Vertices are tracked by reference identity in the map (the group sets themselves are
        // ordinary HashSets).
        final Map<T, Set<T>> vertexToGroup = new IdentityHashMap<>();
        for (T vertex : topologicallySortedVertices) {
            Set<T> currentGroup = new HashSet<>();
            currentGroup.add(vertex);
            vertexToGroup.put(vertex, currentGroup);

            for (T producerVertex : forwardProducersRetriever.apply(vertex)) {
                final Set<T> producerGroup = vertexToGroup.get(producerVertex);
                if (producerGroup == null) {
                    throw new IllegalStateException(
                            "Producer task "
                                    + producerVertex
                                    + " forward group is null while calculating forward group for"
                                    + " the consumer task "
                                    + vertex
                                    + ". This should be a forward group building bug.");
                }
                // Already in the same group (e.g. reached through another producer): nothing to do.
                if (currentGroup == producerGroup) {
                    continue;
                }
                // NOTE(review): mergeVertexGroups presumably unions both groups and re-points the
                // map entries of the merged vertices to the result - confirm in
                // VertexGroupComputeUtil.
                currentGroup =
                        VertexGroupComputeUtil.mergeVertexGroups(
                                currentGroup, producerGroup, vertexToGroup);
            }
        }
        return vertexToGroup;
    }

    /**
     * Checks whether {@code forwardGroupToMerge} may be merged into {@code sourceForwardGroup}.
     *
     * <p>The merge is allowed when:
     *
     * <ul>
     *   <li>both groups are non-null, and
     *   <li>they are the same instance, or
     *   <li>the source parallelism is undecided, or
     *   <li>the source parallelism is decided, does not conflict with a decided parallelism of the
     *       other group, and does not exceed the other group's decided max parallelism.
     * </ul>
     *
     * @param sourceForwardGroup the group to merge into; may be {@code null}
     * @param forwardGroupToMerge the group to be merged; may be {@code null}
     * @return {@code true} if the merge is allowed by the rules above
     */
    public static boolean canTargetMergeIntoSourceForwardGroup(
            ForwardGroup<?> sourceForwardGroup, ForwardGroup<?> forwardGroupToMerge) {
        if (sourceForwardGroup == null || forwardGroupToMerge == null) {
            return false;
        }
        if (sourceForwardGroup == forwardGroupToMerge) {
            return true;
        }
        // Two decided, different parallelisms can never be reconciled.
        if (sourceForwardGroup.isParallelismDecided()
                && forwardGroupToMerge.isParallelismDecided()
                && sourceForwardGroup.getParallelism() != forwardGroupToMerge.getParallelism()) {
            return false;
        }
        // A decided source parallelism must fit within the merged group's decided max parallelism.
        return !sourceForwardGroup.isParallelismDecided()
                || !forwardGroupToMerge.isMaxParallelismDecided()
                || sourceForwardGroup.getParallelism() <= forwardGroupToMerge.getMaxParallelism();
    }

    /**
     * Returns the job vertices that produce the forward inputs of {@code jobVertex}.
     *
     * @param jobVertex the consumer vertex
     * @return the producers of all input edges of {@code jobVertex} that are forward edges
     */
    static Set<JobVertex> getForwardProducers(JobVertex jobVertex) {
        return jobVertex.getInputs().stream()
                .filter(JobEdge::isForward)
                .map(JobEdge::getSource)
                .map(IntermediateDataSet::getProducer)
                .collect(Collectors.toSet());
    }
}

