/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.continuous.MultivariateGaussian;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import io.improbable.keanu.vertices.dbl.probabilistic.ProbabilisticDouble;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;

/**
 * A probabilistic double vertex following a multivariate Gaussian distribution, parameterised by a
 * mean vertex {@code mu} and a {@code covariance} vertex.
 *
 * <p>Parent shapes are validated on construction: {@code mu} must be rank 1, {@code covariance}
 * must be a square rank-2 matrix, and the length of {@code mu} must equal the covariance dimension.
 * Gradients of the log probability ({@link #dLogProb}) are not supported.
 */
public class MultivariateGaussianVertex extends DoubleVertex
        implements Differentiable, ProbabilisticDouble, SamplableWithManyScalars<DoubleTensor>, LogProbGraphSupplier {

    /** Parameter name used for {@code mu} in the load/save annotations. */
    private static final String MU_NAME = "mu";
    /** Parameter name used for {@code covariance} in the load/save annotations. */
    private static final String COVARIANCE_NAME = "covariance";

    private final DoubleVertex mu;
    private final DoubleVertex covariance;

    /**
     * Creates a multivariate Gaussian vertex with an explicitly supplied shape. This constructor is
     * annotated for the vertex loading mechanism ({@code @LoadShape} / {@code @LoadVertexParam}).
     *
     * @param shape      the shape of this vertex; NOTE(review): not checked against mu's shape here
     * @param mu         the mean vector (must be rank 1)
     * @param covariance the covariance matrix (must be square, with side equal to mu's length)
     * @throws IllegalArgumentException if the shapes of mu and covariance are incompatible
     */
    public MultivariateGaussianVertex(@LoadShape long[] shape,
                                      @LoadVertexParam(value = MU_NAME) DoubleVertex mu,
                                      @LoadVertexParam(value = COVARIANCE_NAME) DoubleVertex covariance) {
        super(shape);
        checkValidMultivariateShape(mu.getShape(), covariance.getShape());
        this.mu = mu;
        this.covariance = covariance;
        setParents(mu, covariance);
    }

    /**
     * Creates a multivariate Gaussian vertex whose shape is taken from {@code mu}.
     *
     * @param mu         the mean vector (must be rank 1)
     * @param covariance the covariance matrix (must be square, with side equal to mu's length)
     * @throws IllegalArgumentException if the shapes of mu and covariance are incompatible
     */
    @ExportVertexToPythonBindings
    public MultivariateGaussianVertex(DoubleVertex mu, DoubleVertex covariance) {
        // The validator returns mu's shape, which becomes this vertex's shape.
        this(checkValidMultivariateShape(mu.getShape(), covariance.getShape()), mu, covariance);
    }

    /**
     * Creates a multivariate Gaussian vertex whose covariance is {@code covariance} times the
     * identity matrix sized to the length of {@code mu}.
     *
     * @param mu         the mean vector (must be rank 1)
     * @param covariance scalar multiplier applied to the identity covariance matrix
     * @throws IllegalArgumentException if mu is not a vector
     */
    public MultivariateGaussianVertex(DoubleVertex mu, double covariance) {
        this(mu, scaledIdentityCovariance(mu, covariance));
    }

    /**
     * Creates a one-dimensional multivariate Gaussian vertex from scalar parameters: {@code mu}
     * becomes a length-1 vector and {@code covariance} a 1x1 matrix.
     *
     * @param mu         the single mean value
     * @param covariance the single covariance value
     */
    public MultivariateGaussianVertex(double mu, double covariance) {
        this((DoubleVertex) new ConstantDoubleVertex(DoubleTensor.vector(mu)), oneByOneMatrix(covariance));
    }

    /**
     * Builds a constant identity-times-scalar covariance matching mu's length. Validates mu's rank
     * first so a non-vector mu yields the same IllegalArgumentException as the other constructors,
     * rather than an ArrayIndexOutOfBoundsException from indexing an empty shape.
     */
    private static DoubleVertex scaledIdentityCovariance(DoubleVertex mu, double covariance) {
        long[] muShape = mu.getShape();
        if (muShape.length != 1) {
            throw new IllegalArgumentException("Mu must be vector but was rank " + muShape.length);
        }
        return ConstantVertex.of(DoubleTensor.eye(muShape[0]).times(covariance));
    }

    /** Wraps {@code value} in a constant vertex reshaped to a 1x1 matrix. */
    private static DoubleVertex oneByOneMatrix(double value) {
        return new ConstantDoubleVertex((DoubleTensor) DoubleTensor.scalar(value).reshape(Tensor.ONE_BY_ONE_SHAPE));
    }

    /** @return the mean vertex */
    @SaveVertexParam(value = MU_NAME)
    public DoubleVertex getMu() {
        return this.mu;
    }

    /** @return the covariance vertex */
    @SaveVertexParam(value = COVARIANCE_NAME)
    public DoubleVertex getCovariance() {
        return this.covariance;
    }

    /**
     * Evaluates the log probability of {@code value} using the current values of mu and covariance.
     *
     * @param value the point at which to evaluate the density
     * @return the log probability as a scalar
     */
    @Override
    public double logProb(DoubleTensor value) {
        DoubleTensor muValues = (DoubleTensor) this.mu.getValue();
        DoubleTensor covarianceValues = (DoubleTensor) this.covariance.getValue();
        return (Double) MultivariateGaussian.withParameters(muValues, covarianceValues).logProb(value).scalar();
    }

    /**
     * Builds a log-probability graph over placeholders for this vertex's value, mu and covariance.
     * Each placeholder has the shape of the vertex it stands in for.
     */
    @Override
    public LogProbGraph logProbGraph() {
        DoublePlaceholderVertex xPlaceholder = new DoublePlaceholderVertex(this.getShape());
        DoublePlaceholderVertex muPlaceholder = new DoublePlaceholderVertex(this.mu.getShape());
        DoublePlaceholderVertex covPlaceholder = new DoublePlaceholderVertex(this.covariance.getShape());
        return LogProbGraph.builder()
                .input(this, xPlaceholder)
                .input(this.mu, muPlaceholder)
                .input(this.covariance, covPlaceholder)
                .logProbOutput(MultivariateGaussian.logProbGraph(xPlaceholder, muPlaceholder, covPlaceholder))
                .build();
    }

    /**
     * Not supported for this vertex.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(DoubleTensor value, Set<? extends Vertex> withRespectTo) {
        throw new UnsupportedOperationException(
                "dLogProb is not supported by MultivariateGaussianVertex");
    }

    /**
     * Draws a sample of the given shape using the current values of mu and covariance.
     *
     * @param shape  the requested sample shape
     * @param random the random source
     * @return the sampled tensor
     */
    @Override
    public DoubleTensor sampleWithShape(long[] shape, KeanuRandom random) {
        return (DoubleTensor) MultivariateGaussian
                .withParameters((DoubleTensor) this.mu.getValue(), (DoubleTensor) this.covariance.getValue())
                .sample(shape, random);
    }

    /**
     * Checks that mu is a vector, covariance is a square matrix, and their sizes agree.
     *
     * @param muShape         shape of the mu vertex
     * @param covarianceShape shape of the covariance vertex
     * @return {@code muShape}, so the result can be used directly as this vertex's shape
     * @throws IllegalArgumentException if any of the checks fail
     */
    private static long[] checkValidMultivariateShape(long[] muShape, long[] covarianceShape) {
        if (covarianceShape.length != 2) {
            throw new IllegalArgumentException("Covariance must be matrix but was rank " + covarianceShape.length);
        }
        if (muShape.length != 1) {
            throw new IllegalArgumentException("Mu must be vector but was rank " + muShape.length);
        }
        if (covarianceShape[0] != covarianceShape[1]) {
            throw new IllegalArgumentException("Covariance matrix must be square. Given shape: " + Arrays.toString(covarianceShape));
        }
        if (muShape[0] != covarianceShape[0]) {
            throw new IllegalArgumentException("Dimension 0 of mu must equal dimension 0 of covariance. Given: mu " + muShape[0] + ", covariance " + covarianceShape[0]);
        }
        return muShape;
    }
}

