/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.intgr.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.discrete.Binomial;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.intgr.IntegerPlaceholderVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import io.improbable.keanu.vertices.intgr.probabilistic.ProbabilisticInteger;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

/**
 * Integer-valued probabilistic vertex backed by the {@link Binomial} distribution.
 *
 * <p>The distribution is parameterised by a probability vertex {@code p} and an integer
 * vertex {@code n}; both are registered as this vertex's parents.
 */
public class BinomialVertex extends IntegerVertex
        implements ProbabilisticInteger, SamplableWithManyScalars<IntegerTensor>, LogProbGraphSupplier {

    // Keys used when saving/loading the parameters of this vertex.
    private static final String P_NAME = "p";
    private static final String N_NAME = "n";

    private final DoubleVertex p;
    private final IntegerVertex n;

    /**
     * Creates a Binomial vertex with an explicit shape.
     *
     * @param tensorShape the shape of this vertex's value; the shapes of {@code p} and {@code n}
     *                    are validated against it by
     *                    {@code TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne}
     * @param p           the probability parameter vertex
     * @param n           the count parameter vertex
     */
    public BinomialVertex(@LoadShape long[] tensorShape,
                          @LoadVertexParam(P_NAME) DoubleVertex p,
                          @LoadVertexParam(N_NAME) IntegerVertex n) {
        super(tensorShape);
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(tensorShape, p.getShape(), n.getShape());
        this.p = p;
        this.n = n;
        setParents(p, n);
    }

    /** Shaped constructor taking a constant probability, wrapped in a constant vertex. */
    public BinomialVertex(long[] tensorShape, double p, IntegerVertex n) {
        this(tensorShape, (DoubleVertex) ConstantVertex.of(p), n);
    }

    /** Shaped constructor taking a constant count, wrapped in a constant vertex. */
    public BinomialVertex(long[] tensorShape, DoubleVertex p, int n) {
        this(tensorShape, p, (IntegerVertex) ConstantVertex.of(n));
    }

    /** Shaped constructor taking a constant probability and a constant count. */
    public BinomialVertex(long[] tensorShape, double p, int n) {
        // Wrap p first, then delegate so that n is wrapped by the (DoubleVertex, int) overload.
        this(tensorShape, (DoubleVertex) ConstantVertex.of(p), n);
    }

    /**
     * Creates a Binomial vertex whose shape is derived from its parameters via
     * {@code TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne}.
     *
     * @param p the probability parameter vertex
     * @param n the count parameter vertex
     */
    @ExportVertexToPythonBindings
    public BinomialVertex(DoubleVertex p, IntegerVertex n) {
        this(TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne(p.getShape(), n.getShape()), p, n);
    }

    /** Shape-inferring constructor taking a constant probability. */
    public BinomialVertex(double p, IntegerVertex n) {
        this((DoubleVertex) ConstantVertex.of(p), n);
    }

    /** Shape-inferring constructor taking a constant count. */
    public BinomialVertex(DoubleVertex p, int n) {
        this(p, (IntegerVertex) ConstantVertex.of(n));
    }

    /** Shape-inferring constructor taking a constant probability and a constant count. */
    public BinomialVertex(double p, int n) {
        // Wrap p first, then delegate so that n is wrapped by the (DoubleVertex, int) overload.
        this((DoubleVertex) ConstantVertex.of(p), n);
    }

    /**
     * Evaluates the log probability of {@code k} under the current parameter values.
     *
     * @param k the observed integer tensor
     * @return the element-wise Binomial log probabilities of {@code k}, summed to a scalar
     */
    @Override
    public double logProb(IntegerTensor k) {
        DoubleTensor probability = (DoubleTensor) p.getValue();
        IntegerTensor trials = (IntegerTensor) n.getValue();
        return (Double) Binomial.withParameters(probability, trials).logProb(k).sum();
    }

    /**
     * Builds a log-probability graph in which this vertex and both parameters are replaced by
     * placeholders of matching shape, with the output produced by {@code Binomial.logProbOutput}.
     */
    @Override
    public LogProbGraph logProbGraph() {
        IntegerPlaceholderVertex valueInput = new IntegerPlaceholderVertex(getShape());
        DoublePlaceholderVertex probabilityInput = new DoublePlaceholderVertex(p.getShape());
        IntegerPlaceholderVertex trialsInput = new IntegerPlaceholderVertex(n.getShape());

        return LogProbGraph.builder()
                .input(this, valueInput)
                .input(p, probabilityInput)
                .input(n, trialsInput)
                .logProbOutput(Binomial.logProbOutput(valueInput, probabilityInput, trialsInput))
                .build();
    }

    /**
     * Always returns an empty map: this integer-valued vertex supplies no log-probability
     * gradients, regardless of {@code value} or {@code withRespectTo}.
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(IntegerTensor value, Set<? extends Vertex> withRespectTo) {
        return Collections.emptyMap();
    }

    /**
     * Draws a sample of the requested shape from the Binomial distribution defined by the
     * current values of {@code p} and {@code n}.
     *
     * @param shape  the shape of the sample to draw
     * @param random the random source to sample with
     * @return the sampled integer tensor
     */
    @Override
    public IntegerTensor sampleWithShape(long[] shape, KeanuRandom random) {
        DoubleTensor probability = (DoubleTensor) p.getValue();
        IntegerTensor trials = (IntegerTensor) n.getValue();
        return (IntegerTensor) Binomial.withParameters(probability, trials).sample(shape, random);
    }

    /** @return the probability parameter vertex, saved under the key {@value #P_NAME} */
    @SaveVertexParam(P_NAME)
    public DoubleVertex getP() {
        return p;
    }

    /** @return the count parameter vertex, saved under the key {@value #N_NAME} */
    @SaveVertexParam(N_NAME)
    public IntegerVertex getN() {
        return n;
    }
}

