package io.improbable.keanu.vertices.dbl.nonprobabilistic.diff;

import com.google.common.base.Preconditions;
import io.improbable.keanu.network.LambdaSection;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.Probabilistic;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexId;
import io.improbable.keanu.vertices.dbl.Differentiator;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/nonprobabilistic/diff/LogProbGradientCalculator.class */
/**
 * Computes the gradient of the log probability of a set of vertices with respect to a set
 * of latent vertices.
 *
 * <p>At construction time two lookups are built:
 * <ul>
 *   <li>{@code parentToLatentLookup}: for every parent of a "log prob of" vertex, the latent
 *       {@link DoubleVertex}s in its upstream lambda section that are also in the
 *       "with respect to" set (parents with no such latents are omitted);</li>
 *   <li>{@code verticesWithNonzeroDiffWrtLatent}: for every "log prob of" vertex, the
 *       {@link DoubleVertex} parents present in the first lookup, plus the vertex itself
 *       when it is not observed.</li>
 * </ul>
 * {@link #getJointLogProbGradientWrtLatents()} then accumulates, per "log prob of" vertex,
 * the partials returned by {@code dLogProbAtValue}, propagating those taken with respect to
 * parents back to the latents via {@link Differentiator#reverseModeAutoDiff}.
 */
public class LogProbGradientCalculator {
    private final Set<? extends Vertex<?>> logProbOfVertices;
    private final Set<? extends Vertex<?>> wrtVertices;
    private final Map<Vertex, Set<DoubleVertex>> parentToLatentLookup;
    private final Map<Vertex, Set<DoubleVertex>> verticesWithNonzeroDiffWrtLatent;

    /**
     * @param logProbOfVertices vertices whose log probability is differentiated; each one must be
     *                          {@link Probabilistic} by the time the gradient is requested
     * @param wrtVertices       vertices the gradient is taken with respect to
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public LogProbGradientCalculator(List<? extends Vertex> logProbOfVertices, List<? extends Vertex<?>> wrtVertices) {
        // The first list is declared over the raw Vertex type, so it is copied through a raw
        // HashSet (unchecked conversion) rather than relying on diamond inference.
        this.logProbOfVertices = new HashSet(logProbOfVertices);
        this.wrtVertices = new HashSet<>(wrtVertices);

        // Order matters: the parent lookup filters on wrtVertices, and the nonzero-diff lookup
        // reads the parent lookup.
        this.parentToLatentLookup = getParentsThatAreConnectedToWrtVertices(this.logProbOfVertices);
        this.verticesWithNonzeroDiffWrtLatent = getVerticesWithNonzeroDiffWrt(this.logProbOfVertices, this.parentToLatentLookup);
    }

    /**
     * Accumulates the log prob gradient contribution of every "log prob of" vertex.
     *
     * @return the accumulated partials, keyed by the id of the vertex they are taken with respect to
     * @throws IllegalArgumentException if any "log prob of" vertex is not {@link Probabilistic}
     */
    public Map<VertexId, DoubleTensor> getJointLogProbGradientWrtLatents() {
        LogProbGradients jointGradients = new LogProbGradients();
        for (Vertex<?> ofVertex : this.logProbOfVertices) {
            // NOTE(review): the result of add(...) is not used here, so this relies on add
            // accumulating into jointGradients in place - confirm against LogProbGradients.
            jointGradients.add(reverseModeLogProbGradientWrtLatents(ofVertex));
        }
        return jointGradients.getPartials();
    }

    /**
     * Builds, for each given vertex, the set of vertices its log prob is differentiated with
     * respect to: its {@link DoubleVertex} parents that are keys of {@code parentToLatentLookup},
     * plus the vertex itself when it is not observed.
     */
    private Map<Vertex, Set<DoubleVertex>> getVerticesWithNonzeroDiffWrt(Set<? extends Vertex<?>> ofVertices, Map<Vertex, Set<DoubleVertex>> parentToLatentLookup) {
        Map<Vertex, Set<DoubleVertex>> nonzeroDiffLookup = new HashMap<>();
        for (Vertex<?> ofVertex : ofVertices) {
            // Collected into an explicit HashSet because the set is mutated below and
            // Collectors.toSet() gives no mutability guarantee.
            Set<DoubleVertex> diffWrt = ofVertex.getParents().stream()
                .filter(parent -> parent instanceof DoubleVertex)
                .map(parent -> (DoubleVertex) parent)
                .filter(parentToLatentLookup::containsKey)
                .collect(Collectors.toCollection(HashSet::new));

            if (!ofVertex.isObserved()) {
                diffWrt.add((DoubleVertex) ofVertex);
            }
            nonzeroDiffLookup.put(ofVertex, diffWrt);
        }
        return nonzeroDiffLookup;
    }

    /**
     * Maps each parent of the given vertices to the latent {@link DoubleVertex}s found in the
     * parent's upstream lambda section that are also in the "with respect to" set. Parents for
     * which that set is empty get no entry.
     */
    private Map<Vertex, Set<DoubleVertex>> getParentsThatAreConnectedToWrtVertices(Set<? extends Vertex<?>> ofVertices) {
        Map<Vertex, Set<DoubleVertex>> parentLookup = new HashMap<>();
        for (Vertex<?> ofVertex : ofVertices) {
            for (Vertex parent : ofVertex.getParents()) {
                Set<DoubleVertex> connectedLatents = LambdaSection.getUpstreamLambdaSection(parent, false)
                    .getLatentAndObservedVertices()
                    .stream()
                    .filter(this::isLatentDoubleVertexAndInWrtTo)
                    .map(latent -> (DoubleVertex) latent)
                    .collect(Collectors.toSet());

                if (!connectedLatents.isEmpty()) {
                    parentLookup.put(parent, connectedLatents);
                }
            }
        }
        return parentLookup;
    }

    /** True when the vertex is unobserved, is a {@link DoubleVertex}, and is in the "with respect to" set. */
    private boolean isLatentDoubleVertexAndInWrtTo(Vertex vertex) {
        return !vertex.isObserved() && this.wrtVertices.contains(vertex) && (vertex instanceof DoubleVertex);
    }

    /**
     * Computes the log prob gradient contribution of a single vertex.
     *
     * <p>The partial taken with respect to the vertex itself is stored directly; every other
     * partial is broadcast-corrected and pushed back to the latents through reverse mode
     * auto diff from the vertex it was taken with respect to.
     *
     * @throws IllegalArgumentException if {@code vertex} is not {@link Probabilistic}
     */
    private LogProbGradients reverseModeLogProbGradientWrtLatents(Vertex vertex) {
        Preconditions.checkArgument(vertex instanceof Probabilistic, "Cannot get logProb gradient on non-probabilistic vertex %s", vertex);

        Set<DoubleVertex> diffWrt = this.verticesWithNonzeroDiffWrtLatent.get(vertex);
        // Raw cast kept on purpose: the vertex's value type is not known here.
        @SuppressWarnings({"unchecked", "rawtypes"})
        Map<Vertex, DoubleTensor> dLogProbWrtVertices = ((Probabilistic) vertex).dLogProbAtValue(diffWrt);

        LogProbGradients accumulated = new LogProbGradients();
        for (Map.Entry<Vertex, DoubleTensor> partial : dLogProbWrtVertices.entrySet()) {
            DoubleVertex wrtVertex = (DoubleVertex) partial.getKey();
            DoubleTensor dLogProbWrtVertex = partial.getValue();

            if (wrtVertex.equals(vertex)) {
                // Partial with respect to the vertex itself: nothing upstream to propagate through.
                accumulated.putWithRespectTo(wrtVertex.getId(), dLogProbWrtVertex);
            } else {
                PartialDerivative rawPartial = new PartialDerivative(dLogProbWrtVertex);
                accumulated = accumulated.add(
                    Differentiator.reverseModeAutoDiff(
                        wrtVertex,
                        AutoDiffBroadcast.correctForBroadcastPartialReverse(rawPartial, vertex.getShape(), wrtVertex.getShape()),
                        this.parentToLatentLookup.get(wrtVertex)
                    )
                );
            }
        }
        return accumulated;
    }
}
