/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.MatrixInverseVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.ReshapeVertex;

/**
 * Multivariate Gaussian (normal) distribution parameterised by a mean {@code mu} and a covariance
 * matrix {@code covariance}.
 *
 * <p>The number of dimensions k is the total number of elements in {@code mu}. The covariance is
 * used as a k x k matrix (NOTE(review): its shape is not validated here — confirm callers enforce
 * it). When k == 1 the matrix products degenerate to element-wise products.
 */
public class MultivariateGaussian
implements ContinuousDistribution {
    /** Precomputed log(2 * pi), used in the normalising constant of the density. */
    private static final double LOG_2_PI = Math.log(Math.PI * 2);
    private final DoubleTensor mu;
    private final DoubleTensor covariance;

    /**
     * Creates a multivariate Gaussian with the given parameters.
     *
     * @param mu         the mean of the distribution
     * @param covariance the covariance matrix of the distribution
     * @return a distribution backed by the given tensors (they are stored, not copied)
     */
    public static ContinuousDistribution withParameters(DoubleTensor mu, DoubleTensor covariance) {
        return new MultivariateGaussian(mu, covariance);
    }

    private MultivariateGaussian(DoubleTensor mu, DoubleTensor covariance) {
        this.mu = mu;
        this.covariance = covariance;
    }

    /**
     * Draws a single variate as {@code mu + L * z}.
     *
     * <p>Here L is the Cholesky factor of the covariance and z is a vector of independent
     * standard-normal draws.
     *
     * <p>{@code shape} is only checked for compatibility with the shape of {@code mu}. The returned
     * tensor always has the shape of {@code mu}.
     *
     * @param shape  requested sample shape, validated against the shape of {@code mu}
     * @param random source of standard-normal draws
     * @return one sample with the same shape as {@code mu}
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(shape, new long[][]{this.mu.getShape()});
        DoubleTensor choleskyCov = (DoubleTensor)this.covariance.choleskyDecomposition();
        // Work with a (k x 1) column vector so the Cholesky factor can left-multiply it.
        DoubleTensor variateSamples = (DoubleTensor)random.nextGaussian(this.mu.getShape()).reshape(this.mu.getLength(), 1L);
        DoubleTensor covTimesVariates = this.isUnivariate()
            ? choleskyCov.times(variateSamples)
            : choleskyCov.matrixMultiply(variateSamples);
        return ((DoubleTensor)covTimesVariates.reshape(this.mu.getShape())).plus(this.mu);
    }

    /**
     * Computes the log density at {@code x}.
     *
     * <p>The value is {@code -0.5 * ((x - mu)^T * inv(cov) * (x - mu) + k * log(2 pi) + log(det(cov)))}.
     *
     * <p>A non-positive covariance determinant makes {@code Math.log} return NaN (negative) or
     * -Infinity (zero) rather than throwing.
     *
     * @param x the point at which to evaluate the density; its element count must match {@code mu}
     * @return the log density as a scalar tensor
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        long dimensions = this.numberOfDimensions();
        double kLog2Pi = (double)dimensions * LOG_2_PI;
        double logCovDet = Math.log((Double)this.covariance.determinant());
        // Column (k x 1) and row (1 x k) views of the deviation from the mean.
        DoubleTensor xMinusMu = (DoubleTensor)x.minus(this.mu).reshape(dimensions, 1L);
        DoubleTensor xMinusMuT = (DoubleTensor)xMinusMu.reshape(1L, dimensions);
        DoubleTensor covInv = (DoubleTensor)this.covariance.matrixInverse();
        // Mahalanobis term (x - mu)^T * inv(cov) * (x - mu); element-wise when k == 1.
        double scalar = this.isUnivariate()
            ? ((Double)covInv.times(xMinusMu).times(xMinusMuT).scalar()).doubleValue()
            : ((Double)xMinusMuT.matrixMultiply(covInv.matrixMultiply(xMinusMu)).scalar()).doubleValue();
        return DoubleTensor.scalar(-0.5 * (scalar + kLog2Pi + logCovDet));
    }

    /**
     * Gradient of the log density. This is not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        throw new UnsupportedOperationException();
    }

    /**
     * Builds the same log-density computation as {@link #logProb(DoubleTensor)} as a vertex graph
     * over placeholder inputs.
     *
     * <p>The number of dimensions is fixed from the current shape of the {@code mu} placeholder.
     *
     * <p>NOTE(review): in the multivariate case the quadratic term comes from a (1 x k)(k x k)(k x 1)
     * matrix product, so the result presumably has shape (1 x 1). {@code logProb} instead returns a
     * rank-0 scalar. Confirm callers expect this.
     *
     * @param x          placeholder for the evaluation point
     * @param mu         placeholder for the mean
     * @param covariance placeholder for the covariance matrix
     * @return a vertex computing the log density
     */
    public static DoubleVertex logProbGraph(DoublePlaceholderVertex x, DoublePlaceholderVertex mu, DoublePlaceholderVertex covariance) {
        long dimensions = MultivariateGaussian.numberOfDimensions(mu.getShape());
        double kLog2Pi = (double)dimensions * LOG_2_PI;
        LogVertex logCovDet = covariance.matrixDeterminant().log();
        ReshapeVertex xMinusMu = x.minus(mu).reshape(dimensions, 1L);
        ReshapeVertex xMinusMuT = xMinusMu.reshape(1L, dimensions);
        MatrixInverseVertex covInv = covariance.matrixInverse();
        DoubleVertex scalar = MultivariateGaussian.isUnivariate(dimensions)
            ? covInv.times(xMinusMu).times(xMinusMuT)
            : xMinusMuT.matrixMultiply(covInv.matrixMultiply(xMinusMu));
        return scalar.plus(kLog2Pi).plus(logCovDet).times(-0.5);
    }

    /** Returns true when this distribution's mean has exactly one element. */
    private boolean isUnivariate() {
        return MultivariateGaussian.isUnivariate(this.numberOfDimensions());
    }

    private static boolean isUnivariate(long numberOfDimensions) {
        return numberOfDimensions == 1L;
    }

    /** Returns the number of dimensions k of this distribution (the element count of {@code mu}). */
    private long numberOfDimensions() {
        return MultivariateGaussian.numberOfDimensions(this.mu.getShape());
    }

    /**
     * Returns the number of dimensions k for a mean of shape {@code muShape}.
     *
     * <p>k is the total element count, which matches {@code sample}'s use of {@code mu.getLength()}.
     * A rank-0 scalar mean (shape {@code []}) is therefore univariate. Row and column vector means
     * ({@code [1, k]} and {@code [k, 1]}) are counted the same as {@code [k]}. Taking
     * {@code muShape[0]} instead throws for rank-0 means and counts a row vector as one dimension.
     *
     * @param muShape shape of the mean
     * @return the product of all entries of {@code muShape} (1 for an empty shape)
     */
    private static long numberOfDimensions(long[] muShape) {
        long length = 1L;
        for (long dimensionLength : muShape) {
            length *= dimensionLength;
        }
        return length;
    }
}

