/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.intgr.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.discrete.Geometric;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.intgr.IntegerPlaceholderVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import io.improbable.keanu.vertices.intgr.probabilistic.ProbabilisticInteger;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

/**
 * Integer-valued probabilistic vertex whose sampling and log-probability are
 * delegated to {@link Geometric}, parameterised by a single parent vertex {@code p}.
 *
 * <p>NOTE(review): {@code p} is presumably the per-trial success probability of the
 * geometric distribution; its valid range is not enforced in this class - confirm
 * against {@link Geometric}.
 */
public class GeometricVertex extends IntegerVertex
        implements ProbabilisticInteger, SamplableWithManyScalars<IntegerTensor>, LogProbGraphSupplier {

    /** Name under which the {@code p} parameter is saved and loaded. */
    private static final String P_NAME = "p";

    /** Parent vertex supplying the distribution parameter. */
    private final DoubleVertex p;

    /**
     * Creates a vertex with an explicit shape and registers {@code p} as its parent.
     *
     * @param tensorShape shape handed to the superclass; checked against the shape of {@code p}
     *                    via {@code checkTensorsMatchNonLengthOneShapeOrAreLengthOne}
     * @param p           vertex providing the distribution parameter
     */
    public GeometricVertex(@LoadShape long[] tensorShape, @LoadVertexParam(P_NAME) DoubleVertex p) {
        super(tensorShape);
        long[][] parameterShapes = new long[][]{p.getShape()};
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(tensorShape, parameterShapes);
        this.p = p;
        setParents(p);
    }

    /**
     * Creates a vertex with an explicit shape and a constant parameter.
     *
     * @param tensorShape shape of this vertex
     * @param p           parameter value, wrapped with {@code ConstantVertex.of}
     */
    public GeometricVertex(long[] tensorShape, double p) {
        this(tensorShape, ConstantVertex.of(p));
    }

    /**
     * Creates a vertex that takes its shape from {@code p}.
     *
     * @param p vertex providing the distribution parameter
     */
    @ExportVertexToPythonBindings
    public GeometricVertex(DoubleVertex p) {
        this(p.getShape(), p);
    }

    /**
     * Creates a vertex from a constant parameter, wrapped with {@code ConstantVertex.of}.
     *
     * @param p parameter value
     */
    public GeometricVertex(double p) {
        this(ConstantVertex.of(p));
    }

    /**
     * Draws a sample of the requested shape using the current value of {@code p}.
     *
     * @param shape  shape forwarded to the distribution's {@code sample} call
     * @param random source of randomness forwarded to the distribution
     * @return the sampled tensor
     */
    @Override
    public IntegerTensor sampleWithShape(long[] shape, KeanuRandom random) {
        DoubleTensor pValue = (DoubleTensor) p.getValue();
        return (IntegerTensor) Geometric.withParameters(pValue).sample(shape, random);
    }

    /**
     * Evaluates the distribution's log-probability at {@code value} and sums the result.
     *
     * @param value tensor at which to evaluate the log-probability
     * @return the summed log-probability
     */
    @Override
    public double logProb(IntegerTensor value) {
        DoubleTensor pValue = (DoubleTensor) p.getValue();
        Double summedLogProb = (Double) Geometric.withParameters(pValue).logProb(value).sum();
        return summedLogProb;
    }

    /**
     * Builds a {@link LogProbGraph} with one placeholder for this vertex and one for
     * {@code p}, whose output is {@code Geometric.logProbOutput} over those placeholders.
     *
     * @return the assembled log-probability graph
     */
    @Override
    public LogProbGraph logProbGraph() {
        IntegerPlaceholderVertex xInput = new IntegerPlaceholderVertex(getShape());
        DoublePlaceholderVertex pInput = new DoublePlaceholderVertex(p.getShape());

        return LogProbGraph.builder()
                .input(this, xInput)
                .input(p, pInput)
                .logProbOutput(Geometric.logProbOutput(xInput, pInput))
                .build();
    }

    /**
     * Reports no partial derivatives: the result is always an empty map, regardless of
     * the arguments.
     *
     * @param atValue       ignored
     * @param withRespectTo ignored
     * @return an empty, immutable map
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(IntegerTensor atValue, Set<? extends Vertex> withRespectTo) {
        return Collections.emptyMap();
    }

    /**
     * Returns the parent vertex supplying the distribution parameter.
     *
     * @return the {@code p} vertex given at construction
     */
    @SaveVertexParam(P_NAME)
    public DoubleVertex getP() {
        return p;
    }
}

