/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.continuous.LogNormal;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.probabilistic.ProbabilisticDouble;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
 * A probabilistic {@link DoubleVertex} whose value is modelled by the {@link LogNormal}
 * distribution, parameterised by the {@code mu} and {@code sigma} parent vertices.
 *
 * <p>Log probability, its partial derivatives and samples are all obtained by handing the
 * parents' current values to {@code LogNormal.withParameters(mu, sigma)}.
 * NOTE(review): the exact meaning of mu/sigma (e.g. location/scale of the underlying normal)
 * is defined by {@link LogNormal}, not here — confirm there.
 */
public class LogNormalVertex extends DoubleVertex
    implements Differentiable, ProbabilisticDouble, SamplableWithManyScalars<DoubleTensor>, LogProbGraphSupplier {

    /** Parameter name under which the {@code mu} parent is saved and loaded. */
    private static final String MU_NAME = "mu";
    /** Parameter name under which the {@code sigma} parent is saved and loaded. */
    private static final String SIGMA_NAME = "sigma";

    private final DoubleVertex mu;
    private final DoubleVertex sigma;

    /**
     * Creates a log-normal vertex with an explicit value shape.
     *
     * <p>The shapes of {@code mu} and {@code sigma} are validated against {@code tensorShape}
     * by {@code TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne}.
     *
     * @param tensorShape the shape of this vertex's value
     * @param mu          the vertex supplying the mu parameter
     * @param sigma       the vertex supplying the sigma parameter
     */
    public LogNormalVertex(@LoadShape long[] tensorShape,
                           @LoadVertexParam(MU_NAME) DoubleVertex mu,
                           @LoadVertexParam(SIGMA_NAME) DoubleVertex sigma) {
        super(tensorShape);
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(
            tensorShape, mu.getShape(), sigma.getShape());
        this.mu = mu;
        this.sigma = sigma;
        setParents(mu, sigma);
    }

    /** Explicit-shape variant taking a constant {@code sigma}. */
    public LogNormalVertex(long[] tensorShape, DoubleVertex mu, double sigma) {
        this(tensorShape, mu, (DoubleVertex) ConstantVertex.of(sigma));
    }

    /** Explicit-shape variant taking a constant {@code mu}. */
    public LogNormalVertex(long[] tensorShape, double mu, DoubleVertex sigma) {
        this(tensorShape, (DoubleVertex) ConstantVertex.of(mu), sigma);
    }

    /** Explicit-shape variant taking constant {@code mu} and {@code sigma}. */
    public LogNormalVertex(long[] tensorShape, double mu, double sigma) {
        this(tensorShape, (DoubleVertex) ConstantVertex.of(mu), (DoubleVertex) ConstantVertex.of(sigma));
    }

    /**
     * Creates a log-normal vertex whose shape is derived from its parents via
     * {@code TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne}.
     *
     * @param mu    the vertex supplying the mu parameter
     * @param sigma the vertex supplying the sigma parameter
     */
    @ExportVertexToPythonBindings
    public LogNormalVertex(DoubleVertex mu, DoubleVertex sigma) {
        this(TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne(mu.getShape(), sigma.getShape()), mu, sigma);
    }

    /** Inferred-shape variant taking a constant {@code mu}. */
    public LogNormalVertex(double mu, DoubleVertex sigma) {
        this((DoubleVertex) ConstantVertex.of(mu), sigma);
    }

    /** Inferred-shape variant taking a constant {@code sigma}. */
    public LogNormalVertex(DoubleVertex mu, double sigma) {
        this(mu, (DoubleVertex) ConstantVertex.of(sigma));
    }

    /** Inferred-shape variant taking constant {@code mu} and {@code sigma}. */
    public LogNormalVertex(double mu, double sigma) {
        this((DoubleVertex) ConstantVertex.of(mu), (DoubleVertex) ConstantVertex.of(sigma));
    }

    /** @return the parent vertex supplying the mu parameter */
    @SaveVertexParam(MU_NAME)
    public DoubleVertex getMu() {
        return mu;
    }

    /** @return the parent vertex supplying the sigma parameter */
    @SaveVertexParam(SIGMA_NAME)
    public DoubleVertex getSigma() {
        return sigma;
    }

    /**
     * Computes the total log probability of {@code value}: the per-element log probabilities
     * returned by the distribution are summed into one number.
     */
    @Override
    public double logProb(DoubleTensor value) {
        DoubleTensor logProbPerElement = LogNormal
            .withParameters((DoubleTensor) mu.getValue(), (DoubleTensor) sigma.getValue())
            .logProb(value);
        return (Double) logProbPerElement.sum();
    }

    /**
     * Builds a symbolic log-probability graph in which this vertex and both parents are
     * replaced by placeholders of matching shape.
     */
    @Override
    public LogProbGraph logProbGraph() {
        DoublePlaceholderVertex valueInput = new DoublePlaceholderVertex(getShape());
        DoublePlaceholderVertex muInput = new DoublePlaceholderVertex(mu.getShape());
        DoublePlaceholderVertex sigmaInput = new DoublePlaceholderVertex(sigma.getShape());

        return LogProbGraph.builder()
            .input(this, valueInput)
            .input(mu, muInput)
            .input(sigma, sigmaInput)
            .logProbOutput(LogNormal.logProbOutput(valueInput, muInput, sigmaInput))
            .build();
    }

    /**
     * Returns partial derivatives of the log probability of {@code value}, keyed by vertex,
     * for whichever of {@code mu}, {@code sigma} and this vertex appear in {@code withRespectTo}.
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(DoubleTensor value, Set<? extends Vertex> withRespectTo) {
        Diffs partials = LogNormal
            .withParameters((DoubleTensor) mu.getValue(), (DoubleTensor) sigma.getValue())
            .dLogProb(value);

        Map<Vertex, DoubleTensor> gradients = new HashMap<>();
        // The checks run in this fixed order; if mu and sigma are the same vertex, the sigma
        // partial overwrites the mu partial in the map.
        if (withRespectTo.contains(mu)) {
            gradients.put(mu, partials.get(Diffs.MU).getValue());
        }
        if (withRespectTo.contains(sigma)) {
            gradients.put(sigma, partials.get(Diffs.SIGMA).getValue());
        }
        if (withRespectTo.contains(this)) {
            gradients.put(this, partials.get(Diffs.X).getValue());
        }
        return gradients;
    }

    /** Draws a sample of the requested shape using the parents' current values. */
    @Override
    public DoubleTensor sampleWithShape(long[] shape, KeanuRandom random) {
        DoubleTensor muValue = (DoubleTensor) mu.getValue();
        DoubleTensor sigmaValue = (DoubleTensor) sigma.getValue();
        return (DoubleTensor) LogNormal.withParameters(muValue, sigmaValue).sample(shape, random);
    }
}

