package io.improbable.keanu.vertices.dbl.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.continuous.MultivariateGaussian;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/probabilistic/MultivariateGaussianVertex.class */
/**
 * A probabilistic {@link DoubleVertex} whose value is distributed according to
 * {@link MultivariateGaussian}, parameterised by a mean vertex {@code mu} and a
 * {@code covariance} vertex.
 *
 * <p>On construction {@code mu} must be rank 1 and {@code covariance} must be a square rank-2
 * matrix whose dimension 0 equals the length of {@code mu}; otherwise an
 * {@link IllegalArgumentException} is thrown. Gradients of the log probability
 * ({@code dLogProb}) are not supported and throw {@link UnsupportedOperationException}.
 */
public class MultivariateGaussianVertex extends DoubleVertex implements Differentiable, ProbabilisticDouble, SamplableWithManyScalars<DoubleTensor>, LogProbGraphSupplier {
    private final DoubleVertex mu;
    private final DoubleVertex covariance;
    private static final String MU_NAME = "mu";
    private static final String COVARIANCE_NAME = "covariance";

    /**
     * Creates the vertex with an explicitly supplied value shape. The {@code @Load*} annotations
     * mark this constructor (and its parameters) for use when a saved graph is loaded.
     *
     * @param jArr          shape of this vertex's value
     * @param doubleVertex  mean ({@code mu}); must be rank 1
     * @param doubleVertex2 covariance; must be a square rank-2 matrix whose dimension 0 equals
     *                      the length of {@code mu}
     * @throws IllegalArgumentException if the shapes of {@code mu} and {@code covariance} are
     *                                  incompatible
     */
    public MultivariateGaussianVertex(@LoadShape long[] jArr, @LoadVertexParam(MU_NAME) DoubleVertex doubleVertex, @LoadVertexParam(COVARIANCE_NAME) DoubleVertex doubleVertex2) {
        super(jArr);
        checkValidMultivariateShape(doubleVertex.getShape(), doubleVertex2.getShape());
        this.mu = doubleVertex;
        this.covariance = doubleVertex2;
        setParents(doubleVertex, doubleVertex2);
    }

    /**
     * Creates the vertex, taking its value shape from {@code mu}.
     *
     * @param doubleVertex  mean ({@code mu}); must be rank 1
     * @param doubleVertex2 covariance; must be a square rank-2 matrix whose dimension 0 equals
     *                      the length of {@code mu}
     * @throws IllegalArgumentException if the shapes of {@code mu} and {@code covariance} are
     *                                  incompatible
     */
    @ExportVertexToPythonBindings
    public MultivariateGaussianVertex(DoubleVertex doubleVertex, DoubleVertex doubleVertex2) {
        // Validate before super(...) runs; the validated mu shape becomes this vertex's shape.
        this(checkValidMultivariateShape(doubleVertex.getShape(), doubleVertex2.getShape()), doubleVertex, doubleVertex2);
    }

    /**
     * Creates the vertex with covariance {@code DoubleTensor.eye(n)} scaled by {@code d}, where
     * {@code n} is the length of {@code mu}.
     *
     * @param doubleVertex mean ({@code mu}); must be rank 1
     * @param d            scale factor applied to the {@code eye} covariance
     * @throws IllegalArgumentException if {@code mu} is not rank 1
     */
    public MultivariateGaussianVertex(DoubleVertex doubleVertex, double d) {
        this(doubleVertex, scaledEyeCovariance(doubleVertex, d));
    }

    /**
     * Creates a one-dimensional vertex: {@code mu} is a length-1 vector holding {@code d} and
     * the covariance is a 1x1 matrix holding {@code d2}.
     *
     * @param d  the single mean value
     * @param d2 the single covariance value
     */
    public MultivariateGaussianVertex(double d, double d2) {
        this(new ConstantDoubleVertex(DoubleTensor.vector(d)), oneByOneMatrix(d2));
    }

    /** Wraps {@code d} in a constant vertex reshaped to {@link Tensor#ONE_BY_ONE_SHAPE}. */
    private static DoubleVertex oneByOneMatrix(double d) {
        return new ConstantDoubleVertex((DoubleTensor) DoubleTensor.scalar(d).reshape(Tensor.ONE_BY_ONE_SHAPE));
    }

    /**
     * Builds a constant covariance of {@code DoubleTensor.eye(len(mu))} scaled by {@code scale}.
     * The rank of {@code mu} is checked first so that a non-vector {@code mu} produces the same
     * {@link IllegalArgumentException} as the other constructors, rather than an
     * {@link ArrayIndexOutOfBoundsException} from indexing an empty shape.
     */
    private static DoubleVertex scaledEyeCovariance(DoubleVertex muVertex, double scale) {
        long[] muShape = muVertex.getShape();
        checkMuIsVector(muShape);
        return ConstantVertex.of(DoubleTensor.eye(muShape[0]).times2(scale));
    }

    /** @return the mean ({@code mu}) parent vertex; saved under the name {@value #MU_NAME}. */
    @SaveVertexParam(MU_NAME)
    public DoubleVertex getMu() {
        return this.mu;
    }

    /** @return the covariance parent vertex; saved under the name {@value #COVARIANCE_NAME}. */
    @SaveVertexParam(COVARIANCE_NAME)
    public DoubleVertex getCovariance() {
        return this.covariance;
    }

    /**
     * Evaluates the log probability of {@code doubleTensor} under a {@link MultivariateGaussian}
     * built from the current values of {@code mu} and {@code covariance}.
     *
     * @param doubleTensor the value to evaluate
     * @return the scalar log probability
     */
    @Override
    public double logProb(DoubleTensor doubleTensor) {
        // The explicit (Double) cast is kept because the decompiled generic return type of
        // scalar() could not be inferred precisely.
        return ((Double) MultivariateGaussian.withParameters(this.mu.getValue(), this.covariance.getValue()).logProb(doubleTensor).scalar()).doubleValue();
    }

    /**
     * Builds a log-probability graph with one placeholder each for this vertex's value,
     * {@code mu} and {@code covariance}, wired into
     * {@link MultivariateGaussian#logProbGraph}.
     *
     * @return the log-probability graph
     */
    @Override
    public LogProbGraph logProbGraph() {
        DoublePlaceholderVertex valuePlaceholder = new DoublePlaceholderVertex(getShape());
        DoublePlaceholderVertex muPlaceholder = new DoublePlaceholderVertex(this.mu.getShape());
        DoublePlaceholderVertex covariancePlaceholder = new DoublePlaceholderVertex(this.covariance.getShape());
        return LogProbGraph.builder()
            .input(this, valuePlaceholder)
            .input(this.mu, muPlaceholder)
            .input(this.covariance, covariancePlaceholder)
            .logProbOutput(MultivariateGaussian.logProbGraph(valuePlaceholder, muPlaceholder, covariancePlaceholder))
            .build();
    }

    /**
     * Gradients of the log probability are not supported for this vertex.
     *
     * <p>Named {@code dLogProb2} by the decompiler to avoid clashing with the raw bridge method
     * {@code dLogProb} below, which delegates here.
     *
     * @throws UnsupportedOperationException always
     */
    public Map<Vertex, DoubleTensor> dLogProb2(DoubleTensor doubleTensor, Set<? extends Vertex> set) {
        throw new UnsupportedOperationException("dLogProb is not supported by MultivariateGaussianVertex");
    }

    /**
     * Samples from a {@link MultivariateGaussian} built from the current values of {@code mu}
     * and {@code covariance}.
     *
     * @param jArr        requested sample shape, passed through to the distribution
     * @param keanuRandom source of randomness
     * @return the sample
     */
    @Override
    public DoubleTensor sampleWithShape(long[] jArr, KeanuRandom keanuRandom) {
        return MultivariateGaussian.withParameters(this.mu.getValue(), this.covariance.getValue()).sample(jArr, keanuRandom);
    }

    /**
     * Validates that {@code muShape} describes a vector and {@code covarianceShape} a square
     * matrix whose dimension 0 equals the length of that vector.
     *
     * @param muShape         shape of {@code mu}
     * @param covarianceShape shape of {@code covariance}
     * @return {@code muShape}, unchanged, so the call can be used as a constructor argument
     * @throws IllegalArgumentException if any of the constraints is violated
     */
    private static long[] checkValidMultivariateShape(long[] muShape, long[] covarianceShape) {
        if (covarianceShape.length != 2) {
            throw new IllegalArgumentException("Covariance must be matrix but was rank " + covarianceShape.length);
        }
        checkMuIsVector(muShape);
        if (covarianceShape[0] != covarianceShape[1]) {
            throw new IllegalArgumentException("Covariance matrix must be square. Given shape: " + Arrays.toString(covarianceShape));
        }
        if (muShape[0] != covarianceShape[0]) {
            throw new IllegalArgumentException("Dimension 0 of mu must equal dimension 0 of covariance. Given: mu " + muShape[0] + ", covariance " + covarianceShape[0]);
        }
        return muShape;
    }

    /**
     * @throws IllegalArgumentException if {@code muShape} is not rank 1
     */
    private static void checkMuIsVector(long[] muShape) {
        if (muShape.length != 1) {
            throw new IllegalArgumentException("Mu must be vector but was rank " + muShape.length);
        }
    }

    /**
     * Raw-typed override of {@code Probabilistic#dLogProb}, emitted by the decompiler as a
     * bridge; it forwards to {@link #dLogProb2}, which always throws.
     */
    @Override
    @SuppressWarnings({"rawtypes", "unchecked"}) // raw signature matches the erased interface method
    public Map dLogProb(DoubleTensor doubleTensor, Set set) {
        return dLogProb2(doubleTensor, (Set<? extends Vertex>) set);
    }
}
