package io.improbable.keanu.algorithms.mcmc.proposal;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.distributions.continuous.Gaussian;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.Probabilistic;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.nd4j.base.Preconditions;

/**
 * A {@link ProposalDistribution} for continuous ({@link DoubleTensor}-valued) variables.
 *
 * <p>New values are drawn with {@code KeanuRandom#nextGaussian}, parameterised by each variable's
 * current value and a per-variable sigma. The sigma map is copied on construction, so later changes
 * to the caller's map have no effect on this distribution.
 */
public class GaussianProposalDistribution implements ProposalDistribution {

    /** Sigma for each variable this distribution may be asked about; unmodifiable copy. */
    private final Map<? extends Variable, DoubleTensor> sigmas;

    /** Used to notify the supplied listeners when a proposal is created or rejected. */
    private final ProposalNotifier proposalNotifier;

    /**
     * Creates a distribution that uses the same sigma for every given variable and has no listeners.
     *
     * @param variables the variables to propose values for; must not be empty
     * @param sigma     the sigma shared by all {@code variables}
     * @throws IllegalArgumentException if {@code variables} is empty
     */
    public GaussianProposalDistribution(List<? extends Variable> variables, DoubleTensor sigma) {
        this(variables, sigma, Collections.<ProposalListener>emptyList());
    }

    /**
     * Creates a distribution that uses the same sigma for every given variable.
     *
     * @param variables the variables to propose values for; must not be empty
     * @param sigma     the sigma shared by all {@code variables}
     * @param listeners listeners notified of created and rejected proposals
     * @throws IllegalArgumentException if {@code variables} is empty
     */
    public GaussianProposalDistribution(List<? extends Variable> variables, DoubleTensor sigma, List<ProposalListener> listeners) {
        this(toSigmasMap(variables, sigma), listeners);
    }

    /** Builds a map assigning the same {@code sigma} to every variable in {@code variables}. */
    private static Map<Variable, DoubleTensor> toSigmasMap(Collection<? extends Variable> variables, DoubleTensor sigma) {
        Map<Variable, DoubleTensor> result = new HashMap<>();
        for (Variable variable : variables) {
            result.put(variable, sigma);
        }
        return result;
    }

    /**
     * Creates a distribution with an explicit sigma per variable and no listeners.
     *
     * @param sigmas sigma for each variable; must not be empty
     * @throws IllegalArgumentException if {@code sigmas} is empty
     */
    public GaussianProposalDistribution(Map<? extends Variable, DoubleTensor> sigmas) {
        this(sigmas, Collections.<ProposalListener>emptyList());
    }

    /**
     * Creates a distribution with an explicit sigma per variable.
     *
     * @param sigmas    sigma for each variable; must not be empty. The map is copied, so later
     *                  modifications by the caller are not observed.
     * @param listeners listeners notified of created and rejected proposals
     * @throws IllegalArgumentException if {@code sigmas} is empty
     */
    public GaussianProposalDistribution(Map<? extends Variable, DoubleTensor> sigmas, List<ProposalListener> listeners) {
        Preconditions.checkArgument(!sigmas.isEmpty(), "Gaussian proposal requires at least one sigma");
        // Defensive copy: the caller keeps a reference to its own map and could otherwise
        // change which sigmas this distribution uses after construction.
        this.sigmas = Collections.unmodifiableMap(new HashMap<Variable, DoubleTensor>(sigmas));
        this.proposalNotifier = new ProposalNotifier(listeners);
    }

    /**
     * Proposes a new value for every variable in {@code variables}.
     *
     * <p>Each value is sampled with {@code keanuRandom.nextGaussian(shape, currentValue, sigma)}.
     * After all values are set, the listeners are notified of the created proposal.
     *
     * @param variables   the variables to propose new values for
     * @param keanuRandom source of randomness
     * @return the proposal holding one new value per variable
     * @throws IllegalStateException if a variable's current value is not a {@link DoubleTensor}, or
     *                               no (non-null) sigma was configured for it
     */
    @Override
    public Proposal getProposal(Set<? extends Variable> variables, KeanuRandom keanuRandom) {
        Proposal proposal = new Proposal();
        for (Variable variable : variables) {
            Object currentValue = variable.getValue();
            if (!(currentValue instanceof DoubleTensor)) {
                throw new IllegalStateException("Gaussian proposal function cannot be used for discrete variable " + variable);
            }
            DoubleTensor sigma = sigmas.get(variable);
            if (sigma == null) {
                throw new IllegalStateException("A sigma was not specified for variable " + variable);
            }
            proposal.setProposal(variable, keanuRandom.nextGaussian(variable.getShape(), (DoubleTensor) currentValue, sigma));
        }
        proposalNotifier.notifyProposalCreated(proposal);
        return proposal;
    }

    /**
     * Returns the summed element-wise log density of {@code evaluatedValue}.
     *
     * <p>The density is a {@link Gaussian} built from {@code meanValue} and the variable's sigma
     * (presumably mean and standard deviation).
     *
     * @param variable       the variable whose configured sigma is used
     * @param meanValue      the value the Gaussian is centred on; must be a {@link DoubleTensor}
     * @param evaluatedValue the value whose log density is computed; must be a {@link DoubleTensor}
     * @return the sum of the element-wise log densities
     * @throws ClassCastException    if either value is {@code null} or not a {@link DoubleTensor}
     * @throws IllegalStateException if no (non-null) sigma was configured for {@code variable}
     */
    @Override
    public <T> double logProb(Probabilistic<T> variable, T meanValue, T evaluatedValue) {
        DoubleTensor mu = requireDoubleTensor(meanValue);
        DoubleTensor sigma = sigmas.get(variable);
        if (sigma == null) {
            throw new IllegalStateException("A sigma was not specified for variable " + variable);
        }
        DoubleTensor x = requireDoubleTensor(evaluatedValue);
        Gaussian proposalDistribution = (Gaussian) Gaussian.withParameters(mu, sigma);
        return ((Double) proposalDistribution.logProb(x).sum()).doubleValue();
    }

    /**
     * Casts {@code value} to a {@link DoubleTensor}, failing with a descriptive message otherwise.
     *
     * @throws ClassCastException if {@code value} is {@code null} or not a {@link DoubleTensor}
     */
    private static DoubleTensor requireDoubleTensor(Object value) {
        if (value instanceof DoubleTensor) {
            return (DoubleTensor) value;
        }
        String actualType = value == null ? "null" : value.getClass().getSimpleName();
        throw new ClassCastException("Only DoubleTensor values are supported - not " + actualType);
    }

    /** Notifies the listeners that the most recent proposal was rejected. */
    @Override
    public void onProposalRejected() {
        proposalNotifier.notifyProposalRejected();
    }
}
