package io.improbable.keanu.network;

import io.improbable.keanu.algorithms.ProbabilisticModelWithGradient;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.LogProbGradientCalculator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:io/improbable/keanu/network/KeanuProbabilisticModelWithGradient.class */
/**
 * A {@link KeanuProbabilisticModel} that additionally implements
 * {@link ProbabilisticModelWithGradient}.
 *
 * <p>Two {@link LogProbGradientCalculator}s are created at construction time, both sharing the
 * network's continuous latent vertices as their second constructor argument:
 * <ul>
 *   <li>one built from {@code getLatentOrObservedVertices()}, used by {@code logProbGradients};</li>
 *   <li>one built from {@code getObservedVertices()}, used by {@code logLikelihoodGradients}.</li>
 * </ul>
 */
public class KeanuProbabilisticModelWithGradient extends KeanuProbabilisticModel implements ProbabilisticModelWithGradient {
    /** Calculator built from the network's latent-or-observed vertices. */
    private final LogProbGradientCalculator logProbGradientCalculator;
    /** Calculator built from the network's observed vertices only. */
    private final LogProbGradientCalculator logLikelihoodGradientCalculator;

    /**
     * Creates the model for the given network and prepares both gradient calculators.
     *
     * @param bayesianNetwork the network handed to the superclass and queried for its
     *                        latent-or-observed, observed and continuous latent vertices
     */
    public KeanuProbabilisticModelWithGradient(BayesianNetwork bayesianNetwork) {
        super(bayesianNetwork);
        List<Vertex<DoubleTensor>> continuousLatentVertices = bayesianNetwork.getContinuousLatentVertices();
        this.logProbGradientCalculator = new LogProbGradientCalculator(bayesianNetwork.getLatentOrObservedVertices(), continuousLatentVertices);
        this.logLikelihoodGradientCalculator = new LogProbGradientCalculator(bayesianNetwork.getObservedVertices(), continuousLatentVertices);
    }

    /**
     * Convenience constructor: wraps the vertices in a new {@link BayesianNetwork} and delegates
     * to {@link #KeanuProbabilisticModelWithGradient(BayesianNetwork)}.
     *
     * @param set the vertices to build the network from
     */
    public KeanuProbabilisticModelWithGradient(Set<Vertex> set) {
        // NOTE(review): the explicit cast is kept in case BayesianNetwork has overloaded
        // constructors that are not visible from this file.
        this(new BayesianNetwork((Set<? extends Vertex>) set));
    }

    /**
     * Applies the supplied values (if any) and returns the gradients produced by the calculator
     * built from the latent-or-observed vertices.
     *
     * @param map values to pass to {@code cascadeValues} before the gradient is computed;
     *            {@code null} or empty skips that step
     * @return the calculator's joint log-prob gradient map
     */
    @Override // io.improbable.keanu.algorithms.ProbabilisticModelWithGradient
    public Map<VariableReference, DoubleTensor> logProbGradients(Map<VariableReference, ?> map) {
        return gradients(map, this.logProbGradientCalculator);
    }

    /**
     * Same as {@link #logProbGradients(Map)} with no values supplied, so {@code cascadeValues}
     * is not invoked before the gradient is computed.
     *
     * @return the calculator's joint log-prob gradient map
     */
    @Override // io.improbable.keanu.algorithms.ProbabilisticModelWithGradient
    public Map<VariableReference, DoubleTensor> logProbGradients() {
        return logProbGradients(null);
    }

    /**
     * Applies the supplied values (if any) and returns the gradients produced by the calculator
     * built from the observed vertices only.
     *
     * @param map values to pass to {@code cascadeValues} before the gradient is computed;
     *            {@code null} or empty skips that step
     * @return the calculator's joint log-prob gradient map
     */
    @Override // io.improbable.keanu.algorithms.ProbabilisticModelWithGradient
    public Map<VariableReference, DoubleTensor> logLikelihoodGradients(Map<VariableReference, ?> map) {
        return gradients(map, this.logLikelihoodGradientCalculator);
    }

    /**
     * Same as {@link #logLikelihoodGradients(Map)} with no values supplied, so
     * {@code cascadeValues} is not invoked before the gradient is computed.
     *
     * @return the calculator's joint log-prob gradient map
     */
    @Override // io.improbable.keanu.algorithms.ProbabilisticModelWithGradient
    public Map<VariableReference, DoubleTensor> logLikelihoodGradients() {
        return logLikelihoodGradients(null);
    }

    /**
     * Shared implementation behind all four public gradient methods.
     *
     * @param inputs     values forwarded to {@code cascadeValues} when non-null and non-empty
     * @param calculator the calculator to query once the values have been applied
     * @return the result of {@code calculator.getJointLogProbGradientWrtLatents()}
     */
    private Map<VariableReference, DoubleTensor> gradients(Map<VariableReference, ?> inputs, LogProbGradientCalculator calculator) {
        boolean hasInputs = inputs != null && !inputs.isEmpty();
        if (hasInputs) {
            cascadeValues(inputs);
        }
        // Previously this value was returned through a raw Map, so the same unchecked conversion
        // happened silently in every caller. The cast goes through Map<?, ?> so it does not depend
        // on the calculator's declared parameterization; it is a no-op after erasure.
        // NOTE(review): assumes the calculator's keys are VariableReference instances - confirm.
        @SuppressWarnings("unchecked")
        Map<VariableReference, DoubleTensor> result =
            (Map<VariableReference, DoubleTensor>) (Map<?, ?>) calculator.getJointLogProbGradientWrtLatents();
        return result;
    }
}
