/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.continuous.StudentT;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.probabilistic.ProbabilisticDouble;
import io.improbable.keanu.vertices.intgr.IntegerPlaceholderVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import io.improbable.keanu.vertices.intgr.nonprobabilistic.ConstantIntegerVertex;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
 * A probabilistic double vertex whose value follows a Student's t distribution parameterised by
 * the integer vertex {@code v}.
 *
 * <p>All distribution maths (log probability, its gradient, sampling, and the log-prob graph output)
 * is delegated to {@link StudentT}; this vertex supplies the current value of {@code v} to it.
 */
public class StudentTVertex extends DoubleVertex
    implements Differentiable, ProbabilisticDouble, SamplableWithManyScalars<DoubleTensor>, LogProbGraphSupplier {

    /** Parameter name shared by the {@code @LoadVertexParam} and {@code @SaveVertexParam} annotations for {@code v}. */
    private static final String V_NAME = "v";

    private final IntegerVertex v;

    /**
     * Creates a Student's t vertex with an explicit value shape.
     *
     * @param tensorShape the shape of this vertex's value
     * @param v           the integer parameter handed to {@link StudentT#withParameters} (conventionally
     *                    the degrees of freedom — NOTE(review): confirm against {@code StudentT}); its
     *                    shape is validated against {@code tensorShape} via
     *                    {@code TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne}
     */
    public StudentTVertex(@LoadShape long[] tensorShape, @LoadVertexParam(V_NAME) IntegerVertex v) {
        super(tensorShape);
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(tensorShape, new long[][]{v.getShape()});
        this.v = v;
        this.setParents(v);
    }

    /**
     * Creates a Student's t vertex with an explicit value shape and a constant {@code v}.
     *
     * @param tensorShape the shape of this vertex's value
     * @param v           the constant parameter, wrapped in a {@link ConstantIntegerVertex}
     */
    public StudentTVertex(long[] tensorShape, int v) {
        this(tensorShape, new ConstantIntegerVertex(v));
    }

    /**
     * Creates a Student's t vertex whose shape is taken from {@code v}.
     *
     * @param v the integer parameter vertex; its shape becomes this vertex's shape
     */
    @ExportVertexToPythonBindings
    public StudentTVertex(IntegerVertex v) {
        this(v.getShape(), v);
    }

    /**
     * Creates a scalar-shaped Student's t vertex with a constant {@code v}.
     *
     * @param v the constant parameter, wrapped in a {@link ConstantIntegerVertex}
     */
    public StudentTVertex(int v) {
        this(Tensor.SCALAR_SHAPE, new ConstantIntegerVertex(v));
    }

    /**
     * @return the parent vertex supplying the distribution's {@code v} parameter
     */
    @SaveVertexParam(V_NAME)
    public IntegerVertex getV() {
        return this.v;
    }

    /**
     * Computes the log probability of {@code t} under the distribution, using the current value of
     * {@code v}.
     *
     * @param t the value to evaluate
     * @return the sum of the element-wise log probabilities
     */
    @Override
    public double logProb(DoubleTensor t) {
        return (Double) StudentT.withParameters((IntegerTensor) this.v.getValue()).logProb(t).sum();
    }

    /**
     * Builds a log-probability graph with placeholders standing in for this vertex and for {@code v}.
     *
     * @return a graph whose inputs map this vertex and {@code v} to placeholders and whose output is
     *         {@link StudentT#logProbOutput} applied to those placeholders
     */
    @Override
    public LogProbGraph logProbGraph() {
        DoublePlaceholderVertex xPlaceholder = new DoublePlaceholderVertex(this.getShape());
        IntegerPlaceholderVertex vPlaceholder = new IntegerPlaceholderVertex(this.v.getShape());
        return LogProbGraph.builder()
            .input(this, xPlaceholder)
            .input(this.v, vPlaceholder)
            .logProbOutput(StudentT.logProbOutput(xPlaceholder, vPlaceholder))
            .build();
    }

    /**
     * Computes the gradient of the log probability at {@code t}.
     *
     * <p>Only the derivative with respect to this vertex's own value is produced; no entry is added
     * for {@code v} even if it appears in {@code withRespect}.
     *
     * @param t           the value at which to evaluate the gradient
     * @param withRespect the vertices the caller wants gradients for
     * @return a mutable map containing an entry for {@code this} iff {@code withRespect} contains it;
     *         otherwise an empty map
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(DoubleTensor t, Set<? extends Vertex> withRespect) {
        Map<Vertex, DoubleTensor> m = new HashMap<>();
        if (withRespect.contains(this)) {
            Diffs diff = StudentT.withParameters((IntegerTensor) this.v.getValue()).dLogProb(t);
            m.put(this, diff.get(Diffs.T).getValue());
        }
        return m;
    }

    /**
     * Draws a sample of the requested shape using the current value of {@code v}.
     *
     * @param shape  the shape of the returned tensor
     * @param random the random source to draw from
     * @return the sampled tensor
     */
    @Override
    public DoubleTensor sampleWithShape(long[] shape, KeanuRandom random) {
        return (DoubleTensor) StudentT.withParameters((IntegerTensor) this.v.getValue()).sample(shape, random);
    }
}

