/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl;

import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexId;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialsOf;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialsWithRespectTo;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;

public final class Differentiator {
    public static <V extends Vertex> PartialsWithRespectTo forwardModeAutoDiff(V wrt, V ... of) {
        return Differentiator.forwardModeAutoDiff(wrt, new HashSet<V>(Arrays.asList(of)));
    }

    public static <V extends Vertex> PartialsWithRespectTo forwardModeAutoDiff(V wrt, Collection<V> of) {
        PriorityQueue<Vertex> priorityQueue = new PriorityQueue<Vertex>(Comparator.comparing(Vertex::getId, Comparator.naturalOrder()));
        priorityQueue.add(wrt);
        HashSet<Vertex> alreadyQueued = new HashSet<Vertex>();
        alreadyQueued.add(wrt);
        HashMap<Vertex, PartialDerivative> partials = new HashMap<Vertex, PartialDerivative>();
        HashMap<VertexId, PartialDerivative> ofWrt = new HashMap<VertexId, PartialDerivative>();
        while (!priorityQueue.isEmpty()) {
            Vertex visiting = priorityQueue.poll();
            PartialDerivative partialOfVisiting = ((Differentiable)((Object)visiting)).forwardModeAutoDifferentiation(partials);
            partials.put(visiting, partialOfVisiting);
            if (of.contains(visiting)) {
                ofWrt.put(visiting.getId(), partialOfVisiting);
                continue;
            }
            for (Vertex child : visiting.getChildren()) {
                if (child.isProbabilistic() || alreadyQueued.contains(child) || !child.isDifferentiable()) continue;
                priorityQueue.offer(child);
                alreadyQueued.add(child);
            }
        }
        return new PartialsWithRespectTo(wrt, ofWrt);
    }

    public static PartialsOf reverseModeAutoDiff(Vertex ofVertex, Set<? extends Vertex<?>> wrt) {
        if (ofVertex.isObserved()) {
            return new PartialsOf(ofVertex, Collections.emptyMap());
        }
        return Differentiator.reverseModeAutoDiff(ofVertex, Differentiable.withRespectToSelf(ofVertex.getShape()), wrt);
    }

    public static PartialsOf reverseModeAutoDiff(Vertex ofVertex, Vertex<?> ... wrt) {
        return Differentiator.reverseModeAutoDiff(ofVertex, new HashSet(Arrays.asList(wrt)));
    }

    public static PartialsOf reverseModeAutoDiff(Vertex<?> ofVertex, PartialDerivative dWrtOfVertex, Set<? extends Vertex<?>> wrt) {
        Vertex visiting;
        Differentiator.ensureGraphValuesAndShapesAreSet(ofVertex);
        PriorityQueue<Vertex> priorityQueue = new PriorityQueue<Vertex>(Comparator.comparing(Vertex::getId, Comparator.naturalOrder()).reversed());
        priorityQueue.add(ofVertex);
        HashSet<Vertex> alreadyQueued = new HashSet<Vertex>();
        alreadyQueued.add(ofVertex);
        HashMap<Vertex, PartialDerivative> dwrtOf = new HashMap<Vertex, PartialDerivative>();
        dwrtOf.put(ofVertex, dWrtOfVertex);
        HashMap<VertexId, PartialDerivative> wrtOf = new HashMap<VertexId, PartialDerivative>();
        while ((visiting = priorityQueue.poll()) != null) {
            if (wrt.contains(visiting)) {
                wrtOf.put(visiting.getId(), (PartialDerivative)dwrtOf.get(visiting));
                continue;
            }
            if (visiting.isProbabilistic() || !visiting.isDifferentiable()) continue;
            Differentiable visitingDifferentiable = (Differentiable)((Object)visiting);
            PartialDerivative derivativeOfOutputWrtVisiting = (PartialDerivative)dwrtOf.get(visiting);
            if (derivativeOfOutputWrtVisiting == null) continue;
            Map<Vertex, PartialDerivative> partialDerivatives = visitingDifferentiable.reverseModeAutoDifferentiation(derivativeOfOutputWrtVisiting);
            Differentiator.collectPartials(partialDerivatives, dwrtOf);
            for (Vertex parent : visiting.getParents()) {
                if (alreadyQueued.contains(parent) || !parent.isDifferentiable()) continue;
                priorityQueue.offer(parent);
                alreadyQueued.add(parent);
            }
        }
        return new PartialsOf(ofVertex, wrtOf);
    }

    private static void ensureGraphValuesAndShapesAreSet(Vertex<?> vertex) {
        vertex.getValue();
    }

    private static void collectPartials(Map<Vertex, PartialDerivative> partialDerivatives, Map<Vertex, PartialDerivative> dwrtOf) {
        for (Map.Entry<Vertex, PartialDerivative> v : partialDerivatives.entrySet()) {
            Vertex wrtVertex = v.getKey();
            PartialDerivative dwrtV = v.getValue();
            if (dwrtOf.containsKey(wrtVertex)) {
                dwrtOf.put(wrtVertex, dwrtOf.get(wrtVertex).add(dwrtV));
                continue;
            }
            dwrtOf.put(wrtVertex, dwrtV);
        }
    }

    private Differentiator() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }
}

