/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.bool.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.discrete.Bernoulli;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.BooleanPlaceholderVertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.bool.probabilistic.ProbabilisticBoolean;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

/**
 * A probabilistic {@link BooleanVertex} whose value is Bernoulli-distributed, parameterized by the
 * probability-of-true vertex {@link #getProbTrue()}.
 *
 * <p>Log-probabilities, gradients and samples are all delegated to {@link Bernoulli} built from the
 * current value of {@code probTrue}.
 */
public class BernoulliVertex extends BooleanVertex
        implements ProbabilisticBoolean, SamplableWithManyScalars<BooleanTensor>, LogProbGraphSupplier {

    /** Parameter name used when saving/loading {@code probTrue}; shared by the annotations below. */
    private static final String PROBTRUE_NAME = "probTrue";

    /** Probability of this vertex being true; registered as this vertex's only parent. */
    private final DoubleVertex probTrue;

    /**
     * Creates a Bernoulli vertex of the given shape.
     *
     * @param shape    the shape of this vertex's value
     * @param probTrue the probability of being true; its shape is validated against {@code shape}
     *                 with {@code TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne}
     *                 (presumably it must match {@code shape} or be length one — confirm in that helper)
     */
    public BernoulliVertex(@LoadShape long[] shape, @LoadVertexParam(PROBTRUE_NAME) DoubleVertex probTrue) {
        super(shape);
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(shape, new long[][]{probTrue.getShape()});
        this.probTrue = probTrue;
        this.setParents(probTrue);
    }

    /**
     * Creates a Bernoulli vertex whose shape is taken from {@code probTrue}.
     *
     * @param probTrue the probability of being true
     */
    @ExportVertexToPythonBindings
    public BernoulliVertex(DoubleVertex probTrue) {
        this(probTrue.getShape(), probTrue);
    }

    /**
     * Creates a scalar Bernoulli vertex with a constant probability of being true.
     *
     * @param probTrue the constant probability of being true
     */
    public BernoulliVertex(double probTrue) {
        this(Tensor.SCALAR_SHAPE, new ConstantDoubleVertex(probTrue));
    }

    /**
     * Creates a Bernoulli vertex of the given shape with a constant probability of being true.
     *
     * @param shape    the shape of this vertex's value
     * @param probTrue the constant probability of being true
     */
    public BernoulliVertex(long[] shape, double probTrue) {
        this(shape, new ConstantDoubleVertex(probTrue));
    }

    /**
     * Returns the vertex holding the probability of being true.
     *
     * @return the {@code probTrue} parent vertex
     */
    @SaveVertexParam(PROBTRUE_NAME)
    public DoubleVertex getProbTrue() {
        return this.probTrue;
    }

    /**
     * Computes the total log-probability of {@code value}: the element-wise Bernoulli log-probabilities
     * summed into a single number.
     *
     * @param value the boolean tensor to evaluate
     * @return the summed log-probability
     */
    @Override
    public double logProb(BooleanTensor value) {
        return (Double) Bernoulli.withParameters(probTrueValue()).logProb(value).sum();
    }

    /**
     * Builds a log-probability graph in which this vertex's value and {@code probTrue} are replaced by
     * placeholders of the same shapes.
     *
     * @return the log-probability graph for this vertex
     */
    @Override
    public LogProbGraph logProbGraph() {
        BooleanPlaceholderVertex valuePlaceholder = new BooleanPlaceholderVertex(this.getShape());
        DoublePlaceholderVertex probTruePlaceholder = new DoublePlaceholderVertex(this.probTrue.getShape());
        return LogProbGraph.builder()
            .input(this, valuePlaceholder)
            .input(this.probTrue, probTruePlaceholder)
            .logProbOutput(Bernoulli.logProbGraph(valuePlaceholder, probTruePlaceholder))
            .build();
    }

    /**
     * Computes the gradient of the log-probability of {@code value} with respect to {@code probTrue}.
     *
     * <p>{@code probTrue} is the only parent of this vertex, so it is the only vertex for which a
     * gradient can be returned; any other requested vertex yields no entry.
     *
     * @param value         the boolean tensor at which to evaluate the gradient
     * @param withRespectTo the vertices whose gradients are requested
     * @return a map containing the gradient for {@code probTrue} if it was requested, otherwise empty
     * @throws UnsupportedOperationException if the gradient for {@code probTrue} is requested but
     *                                       {@code probTrue} is not differentiable
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(BooleanTensor value, Set<? extends Vertex> withRespectTo) {
        if (!withRespectTo.contains(this.probTrue)) {
            // Nothing this vertex can differentiate was requested, so differentiability is irrelevant.
            return Collections.emptyMap();
        }
        if (!this.probTrue.isDifferentiable()) {
            throw new UnsupportedOperationException("The probability of the Bernoulli being true must be differentiable");
        }
        DoubleTensor dLogPdp = Bernoulli.withParameters(probTrueValue()).dLogProb(value);
        return Collections.singletonMap(this.probTrue, dLogPdp);
    }

    /**
     * Draws a sample of the given shape from the Bernoulli distribution at the current {@code probTrue}.
     *
     * @param shape  the shape of the sample
     * @param random the random source to draw from
     * @return the sampled boolean tensor
     */
    @Override
    public BooleanTensor sampleWithShape(long[] shape, KeanuRandom random) {
        return Bernoulli.withParameters(probTrueValue()).sample(shape, random);
    }

    /** Current value of {@code probTrue}, used to parameterize the Bernoulli distribution. */
    private DoubleTensor probTrueValue() {
        return (DoubleTensor) this.probTrue.getValue();
    }
}

