package io.improbable.keanu.algorithms.particlefiltering;

import io.improbable.keanu.algorithms.graphtraversal.TopologicalSort;
import io.improbable.keanu.vertices.Vertex;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:io/improbable/keanu/algorithms/particlefiltering/LatentIncrementSort.class */
/**
 * Orders the observed vertices of a graph so that latent (probabilistic, unobserved) vertices can be
 * introduced incrementally, one observed vertex at a time.
 *
 * <p>Static utility class; not instantiable.
 */
public final class LatentIncrementSort {

    private LatentIncrementSort() {
    }

    /**
     * Greedily orders the observed vertices among the keys of {@link TopologicalSort#mapDependencies}
     * and assigns every latent dependency to exactly one of them.
     *
     * <p>At each step, the observed vertex with the fewest not-yet-assigned latent dependencies is
     * chosen. Ties go to the first such vertex in {@link HashMap} iteration order. The chosen vertex
     * is mapped to exactly those still-unassigned latents, which are then removed from the pending
     * sets of every observed vertex not yet chosen.
     *
     * @param collection the vertices handed to {@link TopologicalSort#mapDependencies}
     * @return observed vertices in the chosen order (the iteration order of the returned
     *     {@link LinkedHashMap}), each mapped to the latent vertices it is the first to depend on;
     *     every latent dependency of an observed vertex appears in exactly one value set
     */
    public static Map<Vertex, Set<Vertex>> sort(Collection<? extends Vertex> collection) {
        Map<Vertex, Set<Vertex>> remainingDependencies =
            getObservedVertexLatentDependencies(collection);
        Map<Vertex, Set<Vertex>> dependents = mapDependents(remainingDependencies);
        Map<Vertex, Set<Vertex>> ordered = new LinkedHashMap<>();

        while (!remainingDependencies.isEmpty()) {
            Vertex next = getVerticesWithFewestDependencies(remainingDependencies).get(0);
            // Removing `next` before pruning means the loop below never touches its own set, so the
            // set stored in `ordered` stays exactly the latents this vertex introduces.
            Set<Vertex> introducedLatents = remainingDependencies.remove(next);
            ordered.put(next, introducedLatents);
            for (Vertex latent : introducedLatents) {
                removeDependencyFromOtherVertices(latent, dependents, remainingDependencies);
            }
        }
        return ordered;
    }

    /**
     * Maps each observed vertex among the keys of {@link TopologicalSort#mapDependencies} to the
     * latent vertices in its dependency set.
     *
     * @param collection the vertices handed to {@link TopologicalSort#mapDependencies}
     * @return a new mutable map; each value is a new mutable set
     */
    private static Map<Vertex, Set<Vertex>> getObservedVertexLatentDependencies(
            Collection<? extends Vertex> collection) {
        Map<Vertex, Set<Vertex>> allDependencies = TopologicalSort.mapDependencies(collection);
        Map<Vertex, Set<Vertex>> observedDependencies = new HashMap<>();
        for (Map.Entry<Vertex, Set<Vertex>> entry : allDependencies.entrySet()) {
            Vertex vertex = entry.getKey();
            if (vertex.isObserved()) {
                observedDependencies.put(vertex, getLatentDependencies(entry.getValue()));
            }
        }
        return observedDependencies;
    }

    /**
     * Returns a new mutable set holding the vertices of {@code set} that are probabilistic and not
     * observed.
     */
    private static Set<Vertex> getLatentDependencies(Set<Vertex> set) {
        return set.stream()
            .filter(vertex -> vertex.isProbabilistic() && !vertex.isObserved())
            // Collectors.toSet() makes no mutability guarantee, but sort() later calls remove()
            // on this set, so collect into an explicit HashSet.
            .collect(Collectors.toCollection(HashSet::new));
    }

    /**
     * Inverts an observed-to-latents map into a latent-to-observed-dependents map.
     *
     * @param map observed vertex to its latent dependencies
     * @return latent vertex to the observed vertices whose dependency sets contain it
     */
    private static Map<Vertex, Set<Vertex>> mapDependents(Map<Vertex, Set<Vertex>> map) {
        Map<Vertex, Set<Vertex>> dependents = new HashMap<>();
        for (Map.Entry<Vertex, Set<Vertex>> entry : map.entrySet()) {
            Vertex observed = entry.getKey();
            for (Vertex latent : entry.getValue()) {
                // The mapping function must only create the value, never modify the map itself.
                // On Java 9+, HashMap.computeIfAbsent throws ConcurrentModificationException if it does.
                dependents.computeIfAbsent(latent, key -> new HashSet<>()).add(observed);
            }
        }
        return dependents;
    }

    /**
     * Returns every key of {@code map} whose value set has the minimum size, in the map's iteration
     * order. Returns an empty list iff {@code map} is empty.
     */
    private static List<Vertex> getVerticesWithFewestDependencies(Map<Vertex, Set<Vertex>> map) {
        List<Vertex> fewest = new ArrayList<>();
        int fewestCount = Integer.MAX_VALUE;
        for (Map.Entry<Vertex, Set<Vertex>> entry : map.entrySet()) {
            int count = entry.getValue().size();
            if (count < fewestCount) {
                fewestCount = count;
                fewest.clear();
                fewest.add(entry.getKey());
            } else if (count == fewestCount) {
                fewest.add(entry.getKey());
            }
        }
        return fewest;
    }

    /**
     * Removes {@code latent} from the pending dependency set of every observed vertex that depends
     * on it and is still present in {@code remainingDependencies}.
     *
     * @param latent the latent vertex that has just been assigned
     * @param dependents latent vertex to the observed vertices that depend on it
     * @param remainingDependencies observed vertices not yet ordered, each mapped to its
     *     still-unassigned latent dependencies; modified in place
     */
    private static void removeDependencyFromOtherVertices(
            Vertex<?> latent,
            Map<Vertex, Set<Vertex>> dependents,
            Map<Vertex, Set<Vertex>> remainingDependencies) {
        Set<Vertex> observedDependents = dependents.get(latent);
        if (observedDependents == null) {
            return;
        }
        for (Vertex observed : observedDependents) {
            Set<Vertex> pending = remainingDependencies.get(observed);
            // Observed vertices already ordered have been removed from the map and are skipped.
            if (pending != null) {
                pending.remove(latent);
            }
        }
    }
}
