/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.particlefiltering;

import io.improbable.keanu.algorithms.graphtraversal.TopologicalSort;
import io.improbable.keanu.vertices.Vertex;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Orders observed vertices by how many latent vertices each one depends on.
 *
 * <p>A vertex is treated as <em>latent</em> here when it is probabilistic and not observed.
 * The ordering is built greedily: at every step the remaining observed vertex with the fewest
 * latent dependencies that have not already been claimed by an earlier vertex is chosen next.
 *
 * <p>This is a static utility class and cannot be instantiated.
 */
public final class LatentIncrementSort {
    private LatentIncrementSort() {
    }

    /**
     * Produces the greedy ordering of the observed vertices found by
     * {@code TopologicalSort.mapDependencies(vertices)}.
     *
     * <p>Ties between vertices with the same number of outstanding latent dependencies are broken
     * by the iteration order of an internal {@link HashMap}, so the relative order of tied
     * vertices is not specified.
     *
     * @param vertices the vertices handed to {@code TopologicalSort.mapDependencies}
     * @return an insertion-ordered map whose keys are the observed vertices in the chosen order;
     *         each value is the set of latent vertices that key depends on and that no earlier key
     *         in the map depended on
     */
    public static Map<Vertex, Set<Vertex>> sort(Collection<? extends Vertex> vertices) {
        Map<Vertex, Set<Vertex>> dependencies = getObservedVertexLatentDependencies(vertices);
        // Reverse index (latent -> observed vertices that depend on it), built once up front.
        Map<Vertex, Set<Vertex>> dependants = mapDependents(dependencies);
        Map<Vertex, Set<Vertex>> observedVertexOrder = new LinkedHashMap<>();

        while (!dependencies.isEmpty()) {
            Vertex vertex = getVertexWithFewestDependencies(dependencies);
            Set<Vertex> vertexDependencies = dependencies.remove(vertex);
            observedVertexOrder.put(vertex, vertexDependencies);

            // The chosen vertex has already been removed from 'dependencies', so the set being
            // iterated here is never modified by the calls below.
            for (Vertex upstreamVertex : vertexDependencies) {
                removeDependencyFromOtherVertices(upstreamVertex, dependants, dependencies);
            }
        }
        return observedVertexOrder;
    }

    /**
     * Maps each observed vertex to the latent vertices among its dependencies.
     *
     * @param vertices the vertices handed to {@code TopologicalSort.mapDependencies}
     * @return a map from every observed vertex to the latent subset of its dependencies
     */
    private static Map<Vertex, Set<Vertex>> getObservedVertexLatentDependencies(Collection<? extends Vertex> vertices) {
        Map<Vertex, Set<Vertex>> allDependencies = TopologicalSort.mapDependencies(vertices);
        Map<Vertex, Set<Vertex>> observedVertexLatentDependencies = new HashMap<>();
        for (Map.Entry<Vertex, Set<Vertex>> entry : allDependencies.entrySet()) {
            Vertex vertex = entry.getKey();
            if (vertex.isObserved()) {
                observedVertexLatentDependencies.put(vertex, getLatentDependencies(entry.getValue()));
            }
        }
        return observedVertexLatentDependencies;
    }

    /**
     * Filters a dependency set down to its latent members.
     *
     * @param dependencies the dependencies to filter
     * @return a new set holding only the vertices that are probabilistic and not observed
     */
    private static Set<Vertex> getLatentDependencies(Set<Vertex> dependencies) {
        return dependencies.stream()
            .filter(v -> v.isProbabilistic() && !v.isObserved())
            .collect(Collectors.toSet());
    }

    /**
     * Inverts a dependency map.
     *
     * @param dependencies a map from a vertex to the vertices it depends on
     * @return a map from each depended-upon vertex to the set of keys of {@code dependencies}
     *         that depend on it
     */
    private static Map<Vertex, Set<Vertex>> mapDependents(Map<Vertex, Set<Vertex>> dependencies) {
        Map<Vertex, Set<Vertex>> dependants = new HashMap<>();
        for (Map.Entry<Vertex, Set<Vertex>> entry : dependencies.entrySet()) {
            Vertex dependant = entry.getKey();
            for (Vertex vertex : entry.getValue()) {
                // The mapping function must only supply the new value; mutating the map from
                // inside it makes HashMap.computeIfAbsent throw ConcurrentModificationException
                // on Java 9 and later.
                dependants.computeIfAbsent(vertex, key -> new HashSet<>()).add(dependant);
            }
        }
        return dependants;
    }

    /**
     * Finds the key whose dependency set is smallest.
     *
     * @param dependencies a non-empty map from a vertex to its outstanding dependencies
     * @return the first key, in the map's iteration order, with the smallest dependency set, or
     *         {@code null} when the map is empty
     */
    private static Vertex getVertexWithFewestDependencies(Map<Vertex, Set<Vertex>> dependencies) {
        Vertex best = null;
        int minDependencies = Integer.MAX_VALUE;
        for (Map.Entry<Vertex, Set<Vertex>> entry : dependencies.entrySet()) {
            int dependsOn = entry.getValue().size();
            // Strict comparison keeps the first vertex seen among ties.
            if (dependsOn < minDependencies) {
                minDependencies = dependsOn;
                best = entry.getKey();
            }
        }
        return best;
    }

    /**
     * Removes {@code vertex} from the outstanding dependency set of every vertex that depends on
     * it and is still waiting to be ordered.
     *
     * @param vertex       the dependency to strike off; must be a key of {@code dependants}
     * @param dependants   map from a dependency to the vertices that depend on it
     * @param dependencies map from each not-yet-ordered vertex to its outstanding dependencies
     */
    private static void removeDependencyFromOtherVertices(Vertex<?> vertex, Map<Vertex, Set<Vertex>> dependants, Map<Vertex, Set<Vertex>> dependencies) {
        for (Vertex dependant : dependants.get(vertex)) {
            // Vertices that have already been ordered are no longer present in 'dependencies'.
            Set<Vertex> outstanding = dependencies.get(dependant);
            if (outstanding != null) {
                outstanding.remove(vertex);
            }
        }
    }
}

