package io.improbable.keanu.algorithms.variational;

import com.google.common.collect.Iterables;
import io.improbable.keanu.algorithms.NetworkSamples;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.network.NetworkState;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import io.improbable.keanu.vertices.dbl.probabilistic.ProbabilisticDouble;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

/**
 * Computes the Kullback-Leibler divergence {@code KL(P || Q)} over a set of network samples.
 *
 * <p>P is given by the samples themselves: each sample {@code i} contributes the log master
 * probability reported by {@link NetworkSamples#getLogOfMasterP(int)}. Q is supplied either as a
 * {@link QDistribution} or as a single {@link ProbabilisticDouble}. The returned value is
 *
 * <pre>{@code
 * sum over samples i whose P log probability is not impossible:
 *     exp(logP(i)) * (logP(i) - logQ(i))
 * }</pre>
 */
public class KLDivergence {

    /**
     * Computes {@code KL(P || Q)} where Q is a {@link QDistribution}.
     *
     * @param qDistribution  the distribution Q, evaluated on the network state of every sample
     * @param networkSamples the samples that define P
     * @return the divergence summed over every sample whose P log probability is not impossible
     * @throws NullPointerException     if either argument is {@code null}
     * @throws IllegalArgumentException if Q reports an impossible log probability for a sample
     *                                  whose P log probability is possible
     */
    public static double compute(QDistribution qDistribution, NetworkSamples networkSamples) {
        Objects.requireNonNull(qDistribution, "qDistribution");
        Objects.requireNonNull(networkSamples, "networkSamples");
        Function<NetworkState, Double> qLogProbCalculator = qDistribution::getLogOfMasterP;
        return compute(networkSamples, qLogProbCalculator);
    }

    /**
     * Computes {@code KL(P || Q)} where Q is the density of a single {@link ProbabilisticDouble}.
     *
     * <p>Every sampled network state must contain exactly one variable; its value is passed to
     * {@link ProbabilisticDouble#logProb} to obtain {@code logQ}.
     *
     * @param probabilisticDouble the distribution Q over the single sampled variable
     * @param networkSamples      the samples that define P
     * @return the divergence summed over every sample whose P log probability is not impossible
     * @throws NullPointerException     if either argument is {@code null}
     * @throws IllegalArgumentException if a sampled network state does not hold exactly one
     *                                  variable, or if Q reports an impossible log probability
     *                                  for a sample whose P log probability is possible
     */
    public static double compute(ProbabilisticDouble probabilisticDouble, NetworkSamples networkSamples) {
        Objects.requireNonNull(probabilisticDouble, "probabilisticDouble");
        Objects.requireNonNull(networkSamples, "networkSamples");
        Function<NetworkState, Double> qLogProbCalculator = networkState -> {
            Set<VariableReference> variableReferences = networkState.getVariableReferences();
            if (variableReferences.size() != 1) {
                throw new IllegalArgumentException("A NetworkState does not contain exactly 1 variable and ProbabilisticDouble can only compute the log probability of one value. Try computing KL divergence against a QDistribution instead.");
            }
            VariableReference onlyVariable = Iterables.getOnlyElement(variableReferences);
            return probabilisticDouble.logProb(networkState.get(onlyVariable));
        };
        return compute(networkSamples, qLogProbCalculator);
    }

    /**
     * Sums {@code exp(logP) * (logP - logQ)} over all samples whose P log probability is possible.
     *
     * @param networkSamples     the samples that define P
     * @param qLogProbCalculator maps a sample's network state to {@code logQ} for that state
     * @return the accumulated divergence, or 0.0 when there are no contributing samples
     * @throws IllegalArgumentException if Q is impossible where P is possible
     */
    private static double compute(NetworkSamples networkSamples, Function<NetworkState, Double> qLogProbCalculator) {
        double divergence = 0.0;
        for (int i = 0; i < networkSamples.size(); i++) {
            double pLogProb = networkSamples.getLogOfMasterP(i);
            // Q is evaluated even for samples that P rules out. This keeps the original ordering,
            // so Q's own validation (e.g. the single-variable check) runs on every sample.
            double qLogProb = qLogProbCalculator.apply(networkSamples.getNetworkState(i));
            if (ProbabilityCalculator.isImpossibleLogProb(pLogProb)) {
                // Samples P deems impossible (presumably zero probability) contribute nothing,
                // since p * log(p / q) -> 0 as p -> 0.
                continue;
            }
            if (ProbabilityCalculator.isImpossibleLogProb(qLogProb)) {
                throw new IllegalArgumentException("Q cannot have smaller support than P.");
            }
            divergence += Math.exp(pLogProb) * (pLogProb - qLogProb);
        }
        return divergence;
    }
}
