/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.continuous.InverseGamma;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import io.improbable.keanu.vertices.dbl.probabilistic.ProbabilisticDouble;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
 * Probabilistic double vertex whose value is distributed according to {@link InverseGamma},
 * parameterised by two parent vertices, {@code alpha} and {@code beta}.
 *
 * <p>The parameter names {@value #ALPHA_NAME} and {@value #BETA_NAME} are the keys used by the
 * {@link LoadVertexParam} / {@link SaveVertexParam} annotations for (de)serialisation.
 */
public class InverseGammaVertex extends DoubleVertex
    implements Differentiable, ProbabilisticDouble, SamplableWithManyScalars<DoubleTensor>, LogProbGraphSupplier {

    private static final String ALPHA_NAME = "alpha";
    private static final String BETA_NAME = "beta";

    private final DoubleVertex alpha;
    private final DoubleVertex beta;

    /**
     * Creates an inverse-gamma vertex with an explicit value shape.
     *
     * @param tensorShape shape of this vertex's value
     * @param alpha       vertex supplying the alpha parameter
     * @param beta        vertex supplying the beta parameter
     */
    public InverseGammaVertex(@LoadShape long[] tensorShape,
                              @LoadVertexParam(ALPHA_NAME) DoubleVertex alpha,
                              @LoadVertexParam(BETA_NAME) DoubleVertex beta) {
        super(tensorShape);
        // Shape compatibility between this vertex and its parameters is delegated to TensorShapeValidation.
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(tensorShape, alpha.getShape(), beta.getShape());
        this.alpha = alpha;
        this.beta = beta;
        setParents(alpha, beta);
    }

    /**
     * Creates an inverse-gamma vertex whose value shape is the one returned by
     * {@code TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne} for the parameter shapes.
     *
     * @param alpha vertex supplying the alpha parameter
     * @param beta  vertex supplying the beta parameter
     */
    @ExportVertexToPythonBindings
    public InverseGammaVertex(DoubleVertex alpha, DoubleVertex beta) {
        this(TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne(alpha.getShape(), beta.getShape()), alpha, beta);
    }

    /** Convenience constructor: {@code beta} is wrapped in a {@link ConstantDoubleVertex}. */
    public InverseGammaVertex(DoubleVertex alpha, double beta) {
        this(alpha, new ConstantDoubleVertex(beta));
    }

    /** Convenience constructor: {@code alpha} is wrapped in a {@link ConstantDoubleVertex}. */
    public InverseGammaVertex(double alpha, DoubleVertex beta) {
        this(new ConstantDoubleVertex(alpha), beta);
    }

    /** Convenience constructor: both parameters are wrapped in {@link ConstantDoubleVertex} instances. */
    public InverseGammaVertex(double alpha, double beta) {
        // alpha's constant vertex is created before beta's, as in every other overload.
        this(new ConstantDoubleVertex(alpha), new ConstantDoubleVertex(beta));
    }

    /** Shaped convenience constructor: {@code beta} is wrapped in a {@link ConstantDoubleVertex}. */
    public InverseGammaVertex(long[] tensorShape, DoubleVertex alpha, double beta) {
        this(tensorShape, alpha, new ConstantDoubleVertex(beta));
    }

    /** Shaped convenience constructor: {@code alpha} is wrapped in a {@link ConstantDoubleVertex}. */
    public InverseGammaVertex(long[] tensorShape, double alpha, DoubleVertex beta) {
        this(tensorShape, new ConstantDoubleVertex(alpha), beta);
    }

    /** Shaped convenience constructor: both parameters are wrapped in {@link ConstantDoubleVertex} instances. */
    public InverseGammaVertex(long[] tensorShape, double alpha, double beta) {
        this(tensorShape, new ConstantDoubleVertex(alpha), new ConstantDoubleVertex(beta));
    }

    /** @return the parent vertex supplying the alpha parameter */
    @SaveVertexParam(ALPHA_NAME)
    public DoubleVertex getAlpha() {
        return alpha;
    }

    /** @return the parent vertex supplying the beta parameter */
    @SaveVertexParam(BETA_NAME)
    public DoubleVertex getBeta() {
        return beta;
    }

    /**
     * Evaluates the log density of {@code value} under {@link InverseGamma} built from the
     * parents' current values.
     *
     * @param value the tensor to evaluate
     * @return the sum over all elements of the log densities returned by the distribution
     */
    @Override
    public double logProb(DoubleTensor value) {
        return InverseGamma.withParameters(alpha.getValue(), beta.getValue())
            .logProb(value)
            .sum();
    }

    /**
     * Builds a {@link LogProbGraph} whose inputs are placeholders standing in for this vertex's
     * value, alpha and beta (each shaped like the vertex it replaces), and whose output is
     * {@code InverseGamma.logProbOutput} applied to those placeholders.
     */
    @Override
    public LogProbGraph logProbGraph() {
        // Placeholder creation order (value, alpha, beta) mirrors the order inputs are registered below.
        final DoublePlaceholderVertex valueInput = new DoublePlaceholderVertex(getShape());
        final DoublePlaceholderVertex alphaInput = new DoublePlaceholderVertex(alpha.getShape());
        final DoublePlaceholderVertex betaInput = new DoublePlaceholderVertex(beta.getShape());

        return LogProbGraph.builder()
            .input(this, valueInput)
            .input(alpha, alphaInput)
            .input(beta, betaInput)
            .logProbOutput(InverseGamma.logProbOutput(valueInput, alphaInput, betaInput))
            .build();
    }

    /**
     * Computes partial derivatives of the log density at {@code value}.
     *
     * @param value         the point at which to differentiate
     * @param withRespectTo vertices whose partials are requested
     * @return a new map holding an entry for each of alpha, beta and this vertex that appears in
     *     {@code withRespectTo}; unrequested vertices are absent. The tensors are taken directly
     *     from the distribution's {@link Diffs} without further reduction here.
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(DoubleTensor value, Set<? extends Vertex> withRespectTo) {
        final Diffs partials = InverseGamma.withParameters(alpha.getValue(), beta.getValue()).dLogProb(value);
        final Map<Vertex, DoubleTensor> gradients = new HashMap<>();

        if (withRespectTo.contains(alpha)) {
            gradients.put(alpha, partials.get(Diffs.A).getValue());
        }
        if (withRespectTo.contains(beta)) {
            gradients.put(beta, partials.get(Diffs.B).getValue());
        }
        if (withRespectTo.contains(this)) {
            gradients.put(this, partials.get(Diffs.X).getValue());
        }
        return gradients;
    }

    /**
     * Draws a tensor of the requested shape from {@link InverseGamma} built from the parents'
     * current values.
     *
     * @param shape  shape of the sample to draw
     * @param random source of randomness
     * @return the sampled tensor
     */
    @Override
    public DoubleTensor sampleWithShape(long[] shape, KeanuRandom random) {
        return InverseGamma.withParameters(alpha.getValue(), beta.getValue()).sample(shape, random);
    }
}

