package io.improbable.keanu.vertices.dbl;

import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialsOf;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialsWithRespectTo;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/Differentiator.class */
/**
 * Utility entry points for forward- and reverse-mode automatic differentiation over a graph of
 * {@link Vertex} objects.
 *
 * <p>Both modes walk the graph with a priority queue ordered by {@link Vertex#getId()}: ascending
 * for forward mode (expanding children of the input vertex), descending for reverse mode
 * (expanding parents of the output vertex). NOTE(review): this ordering presumably relies on vertex
 * ids increasing from parents to children, so that a vertex is processed only after every vertex
 * feeding it in the traversal direction — confirm.
 */
public final class Differentiator {

    /**
     * Varargs convenience overload of {@link #forwardModeAutoDiff(Vertex, Collection)}.
     *
     * @param v    the vertex to differentiate with respect to
     * @param vArr the vertices whose partials with respect to {@code v} are wanted
     * @return the partials of each reached vertex in {@code vArr} with respect to {@code v}
     */
    @SafeVarargs // vArr is only copied into a set, never stored or exposed
    public static <V extends Vertex & Differentiable> PartialsWithRespectTo forwardModeAutoDiff(V v, V... vArr) {
        return forwardModeAutoDiff(v, new HashSet<>(Arrays.asList(vArr)));
    }

    /**
     * Forward-mode automatic differentiation starting from {@code v}.
     *
     * <p>Vertices are visited in ascending id order starting at {@code v}. Each visited vertex's
     * partial is computed by {@link Differentiable#forwardModeAutoDifferentiation} from the map of
     * partials gathered so far, then added to that map. A vertex contained in {@code collection}
     * has its partial recorded in the result and is not expanded further; any other vertex
     * enqueues those of its children that are non-probabilistic, differentiable and not already
     * queued.
     *
     * @param v          the vertex to differentiate with respect to
     * @param collection the vertices whose partials with respect to {@code v} are wanted
     * @return partials keyed by the id of every vertex in {@code collection} that the traversal
     *     reached; vertices that were never reached are absent
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // Differentiable's API keys partials by raw Vertex
    public static <V extends Vertex & Differentiable> PartialsWithRespectTo forwardModeAutoDiff(V v, Collection<V> collection) {
        PriorityQueue<Vertex<?>> queue = new PriorityQueue<>(byId());
        queue.add(v);

        Set<Vertex<?>> queued = new HashSet<>();
        queued.add(v);

        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        // Keyed by Vertex#getId(). NOTE(review): declared raw because the id type is not imported
        // in this file; tighten the key type once it is.
        Map ofWrt = new HashMap<>();

        while (!queue.isEmpty()) {
            Vertex<?> visiting = queue.poll();
            PartialDerivative partial = ((Differentiable) visiting).forwardModeAutoDifferentiation(partials);
            partials.put(visiting, partial);

            if (collection.contains(visiting)) {
                // Target reached: record it and do not propagate past it.
                ofWrt.put(visiting.getId(), partial);
                continue;
            }

            for (Vertex<?> child : visiting.getChildren()) {
                if (!child.isProbabilistic() && !queued.contains(child) && child.isDifferentiable()) {
                    queue.offer(child);
                    queued.add(child);
                }
            }
        }

        return new PartialsWithRespectTo(v, ofWrt);
    }

    /**
     * Reverse-mode automatic differentiation of {@code vertex} with respect to the vertices in
     * {@code set}, seeded with {@code Differentiable.withRespectToSelf(vertex.getShape())}.
     *
     * <p>If {@code vertex} is observed, no traversal happens and an empty result is returned.
     *
     * @param vertex the vertex being differentiated
     * @param set    the vertices to differentiate with respect to
     * @return the partials of {@code vertex} keyed by the id of each reached vertex in {@code set}
     */
    public static PartialsOf reverseModeAutoDiff(Vertex vertex, Set<? extends Vertex<?>> set) {
        if (vertex.isObserved()) {
            return new PartialsOf(vertex, Collections.emptyMap());
        }
        return reverseModeAutoDiff(vertex, Differentiable.withRespectToSelf(vertex.getShape()), set);
    }

    /**
     * Varargs convenience overload of {@link #reverseModeAutoDiff(Vertex, Set)}.
     *
     * @param vertex    the vertex being differentiated
     * @param vertexArr the vertices to differentiate with respect to
     * @return the partials of {@code vertex} keyed by the id of each reached vertex in
     *     {@code vertexArr}
     */
    public static PartialsOf reverseModeAutoDiff(Vertex vertex, Vertex<?>... vertexArr) {
        return reverseModeAutoDiff(vertex, new HashSet<>(Arrays.asList(vertexArr)));
    }

    /**
     * Reverse-mode automatic differentiation of {@code vertex}, seeded with
     * {@code partialDerivative}.
     *
     * <p>Vertices are visited in descending id order starting at {@code vertex}. A vertex
     * contained in {@code set} has the partial accumulated for it at that point recorded in the
     * result ({@code null} if nothing was accumulated), and is not expanded further. Any other
     * vertex that is non-probabilistic, differentiable and has an accumulated partial is expanded.
     * Expansion pushes that partial through
     * {@link Differentiable#reverseModeAutoDifferentiation}, sums the returned per-vertex partials
     * into the accumulated ones, and enqueues those of the vertex's parents that are differentiable
     * and not already queued.
     *
     * @param vertex            the vertex being differentiated
     * @param partialDerivative the initial partial assigned to {@code vertex}
     * @param set               the vertices to differentiate with respect to
     * @return the partials of {@code vertex} keyed by the id of each reached vertex in {@code set}
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // Differentiable's API keys partials by raw Vertex
    public static PartialsOf reverseModeAutoDiff(Vertex<?> vertex, PartialDerivative partialDerivative, Set<? extends Vertex<?>> set) {
        ensureGraphValuesAndShapesAreSet(vertex);

        PriorityQueue<Vertex<?>> queue = new PriorityQueue<>(byId().reversed());
        queue.add(vertex);

        Set<Vertex<?>> queued = new HashSet<>();
        queued.add(vertex);

        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        partials.put(vertex, partialDerivative);

        // Keyed by Vertex#getId(). NOTE(review): declared raw because the id type is not imported
        // in this file; tighten the key type once it is.
        Map wrtOf = new HashMap<>();

        Vertex<?> visiting;
        while ((visiting = queue.poll()) != null) {
            if (set.contains(visiting)) {
                // Target reached: record its partial as of now and do not propagate past it.
                wrtOf.put(visiting.getId(), partials.get(visiting));
                continue;
            }
            if (visiting.isProbabilistic() || !visiting.isDifferentiable()) {
                continue;
            }
            PartialDerivative partial = partials.get(visiting);
            if (partial == null) {
                // Nothing flowed into this vertex, so there is nothing to push to its parents.
                continue;
            }

            collectPartials(((Differentiable) visiting).reverseModeAutoDifferentiation(partial), partials);

            for (Vertex<?> parent : visiting.getParents()) {
                if (!queued.contains(parent) && parent.isDifferentiable()) {
                    queue.offer(parent);
                    queued.add(parent);
                }
            }
        }

        return new PartialsOf(vertex, wrtOf);
    }

    /** Orders vertices by ascending {@link Vertex#getId()}. */
    private static Comparator<Vertex<?>> byId() {
        return Comparator.comparing(Vertex::getId);
    }

    /**
     * Calls {@code vertex.getValue()} before differentiation begins. NOTE(review): presumably this
     * forces lazy evaluation of the graph so that values and shapes exist — confirm.
     */
    private static void ensureGraphValuesAndShapesAreSet(Vertex<?> vertex) {
        vertex.getValue();
    }

    /**
     * Adds every partial in {@code from} into {@code into}. When {@code into} already holds a
     * non-null partial for the same vertex, the two are summed with {@link PartialDerivative#add}.
     * Otherwise the new partial is stored as-is.
     */
    @SuppressWarnings("rawtypes") // matches Differentiable's raw-Vertex-keyed partial maps
    private static void collectPartials(Map<Vertex, PartialDerivative> from, Map<Vertex, PartialDerivative> into) {
        for (Map.Entry<Vertex, PartialDerivative> entry : from.entrySet()) {
            Vertex key = entry.getKey();
            PartialDerivative incoming = entry.getValue();
            PartialDerivative existing = into.get(key);
            // A present-but-null entry is treated like an absent one instead of throwing NPE.
            into.put(key, existing == null ? incoming : existing.add(incoming));
        }
    }

    private Differentiator() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }
}
