package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.MatrixInverseVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.ReshapeVertex;

/* loaded from: input_file:io/improbable/keanu/distributions/continuous/MultivariateGaussian.class */
public class MultivariateGaussian implements ContinuousDistribution {

    /** Natural log of 2*pi: the per-dimension normalisation term of the Gaussian log density. */
    private static final double LOG_2_PI = Math.log(2.0 * Math.PI);

    private final DoubleTensor mu;
    private final DoubleTensor covariance;

    /**
     * Creates a multivariate Gaussian distribution.
     *
     * @param mu         the mean; its first dimension is used as the number of dimensions k
     * @param covariance the covariance; presumably a k x k positive-definite matrix
     *                   (NOTE(review): not validated here — confirm callers guarantee it)
     * @return the distribution
     */
    public static ContinuousDistribution withParameters(DoubleTensor mu, DoubleTensor covariance) {
        return new MultivariateGaussian(mu, covariance);
    }

    private MultivariateGaussian(DoubleTensor mu, DoubleTensor covariance) {
        this.mu = mu;
        this.covariance = covariance;
    }

    /**
     * Draws a sample as {@code mu + L * z}. L is the Cholesky factor of the covariance, and z is a
     * column of standard-normal draws with as many elements as {@code mu}.
     *
     * @param shape  the requested shape; it is only validated against mu's shape, and the result
     *               always has mu's shape
     * @param random the source of randomness
     * @return a sample with the same shape as mu
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        // The validator takes the proposal shape plus an array of parameter shapes. The previous
        // `new long[]{mu.getShape()}` placed a long[] inside a long[] initializer and did not compile.
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(shape, new long[][]{this.mu.getShape()});

        DoubleTensor choleskyFactor = this.covariance.choleskyDecomposition();

        // Standard-normal draws laid out as a (length x 1) column so they can be matrix-multiplied.
        DoubleTensor standardNormals =
            (DoubleTensor) random.nextGaussian(this.mu.getShape()).reshape(this.mu.getLength(), 1);

        // The univariate case uses element-wise multiplication instead of matrix multiplication.
        DoubleTensor correlated = isUnivariate()
            ? (DoubleTensor) choleskyFactor.times(standardNormals)
            : (DoubleTensor) choleskyFactor.matrixMultiply(standardNormals);

        return (DoubleTensor) ((DoubleTensor) correlated.reshape(this.mu.getShape())).plus(this.mu);
    }

    /**
     * Computes the log density
     * {@code -0.5 * ((x - mu)^T * covariance^-1 * (x - mu) + k * log(2*pi) + log(det(covariance)))}.
     *
     * @param x the point to evaluate; assumed to have mu's shape (TODO confirm)
     * @return the log density as a scalar tensor
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        long k = numberOfDimensions();
        double normalisation = k * LOG_2_PI;
        double logDeterminant = Math.log(this.covariance.determinant().doubleValue());

        // (x - mu) as a (k x 1) column and as a (1 x k) row for the quadratic form.
        DoubleTensor residualColumn = (DoubleTensor) ((DoubleTensor) x.minus(this.mu)).reshape(k, 1);
        DoubleTensor residualRow = (DoubleTensor) residualColumn.reshape(1, k);
        DoubleTensor precision = this.covariance.matrixInverse();

        double squaredMahalanobis = isUnivariate()
            ? ((Double) ((DoubleTensor) ((DoubleTensor) precision.times(residualColumn)).times(residualRow)).scalar()).doubleValue()
            : ((Double) ((DoubleTensor) residualRow.matrixMultiply(precision.matrixMultiply(residualColumn))).scalar()).doubleValue();

        return DoubleTensor.scalar(-0.5d * (squaredMahalanobis + normalisation + logDeterminant));
    }

    /**
     * Not supported for this distribution.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        throw new UnsupportedOperationException("dLogProb is not implemented for MultivariateGaussian");
    }

    /**
     * Builds a vertex graph that computes the same log density as {@link #logProb(DoubleTensor)}.
     *
     * @param xPlaceholder          placeholder for the evaluated point
     * @param muPlaceholder         placeholder for the mean; its first dimension gives k
     * @param covariancePlaceholder placeholder for the covariance matrix
     * @return a vertex holding the log density
     */
    public static DoubleVertex logProbGraph(DoublePlaceholderVertex xPlaceholder, DoublePlaceholderVertex muPlaceholder, DoublePlaceholderVertex covariancePlaceholder) {
        long k = numberOfDimensions(muPlaceholder.getShape());
        double normalisation = k * LOG_2_PI;

        // NOTE(review): `log2`, `plus2` and `times2` look like decompiler-renamed overloads. Confirm
        // that `log2()` is the natural log (its LogVertex return type suggests so); otherwise this
        // graph disagrees with logProb, which uses Math.log.
        LogVertex logDeterminant = covariancePlaceholder.matrixDeterminant().log2();

        ReshapeVertex residualColumn = xPlaceholder.minus((DoubleVertex) muPlaceholder).reshape(k, 1);
        ReshapeVertex residualRow = residualColumn.reshape(1, k);
        MatrixInverseVertex precision = covariancePlaceholder.matrixInverse();

        return (isUnivariate(k)
                ? precision.times((DoubleVertex) residualColumn).times((DoubleVertex) residualRow)
                : residualRow.matrixMultiply(precision.matrixMultiply(residualColumn)))
            .plus2(normalisation)
            .plus((DoubleVertex) logDeterminant)
            .times2(-0.5d);
    }

    /** Returns whether this distribution's mean has exactly one dimension. */
    private boolean isUnivariate() {
        return isUnivariate(numberOfDimensions());
    }

    /** Returns whether a dimension count of {@code k} denotes the univariate case. */
    private static boolean isUnivariate(long k) {
        return k == 1;
    }

    /** Returns the number of dimensions of this distribution, taken from mu's shape. */
    private long numberOfDimensions() {
        return numberOfDimensions(this.mu.getShape());
    }

    /**
     * Returns the size of the first dimension of {@code shape}.
     *
     * @throws IllegalArgumentException if {@code shape} has rank 0 (there is no first dimension)
     */
    private static long numberOfDimensions(long[] shape) {
        if (shape.length == 0) {
            throw new IllegalArgumentException(
                "MultivariateGaussian requires a mean with rank >= 1 but got a rank-0 (scalar) shape");
        }
        return shape[0];
    }
}
