package io.improbable.keanu.algorithms.graphtraversal;

import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.Probabilistic;
import io.improbable.keanu.vertices.Vertex;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Deque;
import java.util.HashSet;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Set;

/**
 * Static helpers that push values through a graph of {@link Vertex} instances.
 *
 * <ul>
 *   <li>{@link #cascadeUpdate(Collection)} updates the given vertices and then every
 *       non-probabilistic descendant reachable from them.</li>
 *   <li>{@link #eval(Collection)} evaluates the given vertices after first evaluating their
 *       ancestors, walking upward through non-probabilistic vertices and stopping at
 *       probabilistic ones.</li>
 *   <li>{@link #lazyEval(Collection)} is like {@code eval}, but only walks up to ancestors that do
 *       not yet have a value.</li>
 * </ul>
 *
 * <p>In every strategy a probabilistic vertex is sampled only when it has no value, and an observed
 * non-probabilistic vertex is never recalculated.
 *
 * <p>NOTE(review): the upward walks in {@code eval}/{@code lazyEval} have no cycle detection; a
 * cycle made only of non-probabilistic vertices would presumably loop forever — confirm the graph
 * model forbids such cycles.
 */
public final class VertexValuePropagation {

    private VertexValuePropagation() {
    }

    /**
     * Cascades updates from each of the given vertices to their non-probabilistic descendants.
     *
     * @param vertices the vertices the cascade starts from
     */
    public static void cascadeUpdate(Vertex... vertices) {
        cascadeUpdate(Arrays.asList(vertices));
    }

    /**
     * Cascades updates from a single vertex to its non-probabilistic descendants.
     *
     * @param vertex the vertex the cascade starts from
     */
    public static void cascadeUpdate(Vertex vertex) {
        cascadeUpdate(Collections.singletonList(vertex));
    }

    /**
     * Updates the given vertices, then every non-probabilistic descendant reachable from them.
     *
     * <p>Vertices are processed in ascending {@code getId()} order, and each vertex is updated at
     * most once per call. Traversal does not continue through probabilistic children.
     *
     * <p>NOTE(review): processing by id presumably yields a topological order (parents before
     * children) only if ids increase along edges — confirm against how ids are assigned.
     *
     * @param vertices the vertices the cascade starts from
     */
    public static void cascadeUpdate(Collection<? extends Vertex> vertices) {
        Comparator<Vertex<?>> byId = Comparator.comparing(Vertex::getId);
        PriorityQueue<Vertex<?>> queue = new PriorityQueue<>(byId);
        Set<Vertex<?>> visited = new HashSet<>();

        for (Vertex<?> vertex : vertices) {
            // Queue each input only once, even if the caller passed duplicates.
            if (visited.add(vertex)) {
                queue.offer(vertex);
            }
        }

        while (!queue.isEmpty()) {
            Vertex<?> current = queue.poll();
            updateVertexValue(current);
            for (Vertex<?> child : current.getChildren()) {
                // Set.add returns false for vertices already queued, so each is visited once.
                if (!child.isProbabilistic() && visited.add(child)) {
                    queue.offer(child);
                }
            }
        }
    }

    /**
     * Evaluates the given vertices, first evaluating any ancestors they depend on.
     *
     * @param vertices the vertices to evaluate
     */
    public static void eval(Vertex... vertices) {
        eval(Arrays.asList(vertices));
    }

    /**
     * Evaluates the given vertices after evaluating their ancestors.
     *
     * <p>The walk goes upward through non-probabilistic vertices. A probabilistic vertex is
     * evaluated as soon as it is reached, without visiting its parents. Each vertex is evaluated
     * at most once per call, even if it is reachable along several paths.
     *
     * @param vertices the vertices to evaluate
     */
    public static void eval(Collection<? extends Vertex> vertices) {
        Deque<Vertex<?>> stack = asDeque(vertices);
        Set<Vertex<?>> calculated = new HashSet<>();

        while (!stack.isEmpty()) {
            Vertex<?> head = stack.peek();
            if (calculated.contains(head)) {
                // Duplicate stack entry: already evaluated in this pass, so skip recalculating it.
                stack.pop();
                continue;
            }

            Set<Vertex<?>> pendingParents = parentsThatAreNotCalculated(calculated, head.getParents());
            if (head.isProbabilistic() || pendingParents.isEmpty()) {
                stack.pop();
                updateVertexValue(head);
                calculated.add(head);
            } else {
                // Leave head on the stack; it is revisited once these parents are done.
                for (Vertex<?> parent : pendingParents) {
                    stack.push(parent);
                }
            }
        }
    }

    /**
     * Returns the parents not yet evaluated in the current {@link #eval(Collection)} pass.
     *
     * @param calculated vertices already evaluated in this pass
     * @param parents    the parents to filter
     * @return a new mutable set of the parents absent from {@code calculated}
     */
    private static Set<Vertex<?>> parentsThatAreNotCalculated(Set<Vertex<?>> calculated,
                                                             Collection<? extends Vertex> parents) {
        Set<Vertex<?>> notCalculated = new HashSet<>();
        for (Vertex<?> parent : parents) {
            if (!calculated.contains(parent)) {
                notCalculated.add(parent);
            }
        }
        return notCalculated;
    }

    /**
     * Evaluates the given vertices, first evaluating any ancestors that have no value yet.
     *
     * @param vertices the vertices to evaluate
     */
    public static void lazyEval(Vertex... vertices) {
        lazyEval(Arrays.asList(vertices));
    }

    /**
     * Evaluates the given vertices after evaluating only those ancestors that lack a value.
     *
     * <p>The walk goes upward through non-probabilistic vertices whose parents have no value.
     * A parent that already has a value is treated as up to date and is not revisited. Each
     * vertex is evaluated at most once per call.
     *
     * @param vertices the vertices to evaluate
     */
    public static void lazyEval(Collection<? extends Vertex> vertices) {
        Deque<Vertex<?>> stack = asDeque(vertices);
        Set<Vertex<?>> evaluated = new HashSet<>();

        while (!stack.isEmpty()) {
            Vertex<?> head = stack.peek();
            if (evaluated.contains(head)) {
                // Duplicate stack entry: already evaluated in this pass, so skip recalculating it.
                stack.pop();
                continue;
            }

            Set<Vertex<?>> pendingParents = parentsThatAreNotCalculated(head.getParents());
            if (head.isProbabilistic() || pendingParents.isEmpty()) {
                stack.pop();
                updateVertexValue(head);
                evaluated.add(head);
            } else {
                // Leave head on the stack; it is revisited once these parents have values.
                for (Vertex<?> parent : pendingParents) {
                    stack.push(parent);
                }
            }
        }
    }

    /**
     * Returns the parents that do not currently have a value.
     *
     * @param parents the parents to filter
     * @return a new mutable set of the parents for which {@code hasValue()} is false
     */
    private static Set<Vertex<?>> parentsThatAreNotCalculated(Collection<? extends Vertex> parents) {
        Set<Vertex<?>> withoutValue = new HashSet<>();
        for (Vertex<?> parent : parents) {
            if (!parent.hasValue()) {
                withoutValue.add(parent);
            }
        }
        return withoutValue;
    }

    /**
     * Copies the vertices onto a new stack.
     *
     * <p>Each element is pushed to the front, so the last element of {@code vertices} is on top.
     *
     * @param vertices the vertices to stack
     * @return a new stack containing every element of {@code vertices}
     */
    private static Deque<Vertex<?>> asDeque(Iterable<? extends Vertex> vertices) {
        Deque<Vertex<?>> stack = new ArrayDeque<>();
        for (Vertex<?> vertex : vertices) {
            stack.push(vertex);
        }
        return stack;
    }

    /**
     * Refreshes a single vertex's value.
     *
     * <p>A probabilistic vertex is sampled only if it has no value. A non-probabilistic vertex is
     * recalculated unless it is observed.
     *
     * @param vertex the vertex to refresh
     * @param <T>    the vertex's value type
     */
    // The casts are guarded by isProbabilistic(); the generic argument cannot be checked at runtime.
    @SuppressWarnings("unchecked")
    private static <T> void updateVertexValue(Vertex<T> vertex) {
        if (vertex.isProbabilistic()) {
            if (!vertex.hasValue()) {
                vertex.setValue(((Probabilistic<T>) vertex).sample());
            }
        } else if (!vertex.isObserved()) {
            vertex.setValue(((NonProbabilistic<T>) vertex).calculate());
        }
    }
}
