package io.improbable.keanu.vertices.bool.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.discrete.Bernoulli;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.BooleanPlaceholderVertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

/**
 * A probabilistic boolean vertex whose elements follow a Bernoulli distribution parameterised by
 * the {@code probTrue} parent vertex.
 */
public class BernoulliVertex extends BooleanVertex implements ProbabilisticBoolean, SamplableWithManyScalars<BooleanTensor>, LogProbGraphSupplier {
    private static final String PROBTRUE_NAME = "probTrue";

    private final DoubleVertex probTrue;

    /**
     * Creates a Bernoulli vertex of the given shape.
     *
     * @param tensorShape the shape of this vertex
     * @param probTrue    the vertex supplying the probability of each element being true; its shape
     *                    is validated against {@code tensorShape} (presumably it must match or be
     *                    length one -- see TensorShapeValidation)
     */
    public BernoulliVertex(@LoadShape long[] tensorShape, @LoadVertexParam(PROBTRUE_NAME) DoubleVertex probTrue) {
        super(tensorShape);
        // The validator takes an array of shapes (long[][]), one per parameter; the decompiled
        // `new long[]{shape}` was a type error.
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(tensorShape, new long[][]{probTrue.getShape()});
        this.probTrue = probTrue;
        setParents(probTrue);
    }

    /**
     * Creates a Bernoulli vertex whose shape is taken from {@code probTrue}.
     *
     * @param probTrue the vertex supplying the probability of each element being true
     */
    @ExportVertexToPythonBindings
    public BernoulliVertex(DoubleVertex probTrue) {
        this(probTrue.getShape(), probTrue);
    }

    /**
     * Creates a scalar Bernoulli vertex with a constant probability of being true.
     *
     * @param probTrue the constant probability of being true
     */
    public BernoulliVertex(double probTrue) {
        this(Tensor.SCALAR_SHAPE, new ConstantDoubleVertex(probTrue));
    }

    /**
     * Creates a Bernoulli vertex of the given shape with a constant probability of being true.
     *
     * @param tensorShape the shape of this vertex
     * @param probTrue    the constant probability of being true
     */
    public BernoulliVertex(long[] tensorShape, double probTrue) {
        this(tensorShape, new ConstantDoubleVertex(probTrue));
    }

    /**
     * Returns the parent vertex holding the probability of being true.
     *
     * @return the {@code probTrue} parameter vertex
     */
    @SaveVertexParam(PROBTRUE_NAME)
    public DoubleVertex getProbTrue() {
        return this.probTrue;
    }

    /**
     * Computes the summed log probability of {@code value} under a Bernoulli distribution built from
     * the current value of {@code probTrue}.
     *
     * @param value the boolean tensor to evaluate
     * @return the sum of the element-wise log probabilities
     */
    @Override // io.improbable.keanu.vertices.Probabilistic
    public double logProb(BooleanTensor value) {
        return ((Double) Bernoulli.withParameters(this.probTrue.getValue()).logProb(value).sum()).doubleValue();
    }

    /**
     * Builds a log-probability graph that uses placeholder vertices standing in for this vertex and
     * for {@code probTrue}.
     *
     * @return the log-probability graph for this vertex
     */
    @Override // io.improbable.keanu.vertices.LogProbGraphSupplier
    public LogProbGraph logProbGraph() {
        BooleanPlaceholderVertex valuePlaceholder = new BooleanPlaceholderVertex(getShape());
        DoublePlaceholderVertex probTruePlaceholder = new DoublePlaceholderVertex(this.probTrue.getShape());
        return LogProbGraph.builder()
            .input(this, valuePlaceholder)
            .input(this.probTrue, probTruePlaceholder)
            .logProbOutput(Bernoulli.logProbGraph(valuePlaceholder, probTruePlaceholder))
            .build();
    }

    /**
     * Computes the gradient of the log probability of {@code value} with respect to {@code probTrue}.
     *
     * @param value         the boolean tensor at which to evaluate the gradient
     * @param withRespectTo the vertices for which gradients are requested
     * @return a singleton map {@code probTrue -> gradient} if {@code probTrue} is requested,
     *         otherwise an empty map
     * @throws UnsupportedOperationException if {@code probTrue} is not differentiable (checked
     *                                       before looking at {@code withRespectTo})
     */
    @Override // io.improbable.keanu.vertices.Probabilistic
    public Map<Vertex, DoubleTensor> dLogProb(BooleanTensor value, Set<? extends Vertex> withRespectTo) {
        if (!this.probTrue.isDifferentiable()) {
            throw new UnsupportedOperationException("The probability of the Bernoulli being true must be differentiable");
        }
        if (!withRespectTo.contains(this.probTrue)) {
            return Collections.emptyMap();
        }
        DoubleTensor dLogPdProbTrue = Bernoulli.withParameters(this.probTrue.getValue()).dLogProb(value);
        return Collections.singletonMap(this.probTrue, dLogPdProbTrue);
    }

    /**
     * Alias of {@link #dLogProb(BooleanTensor, Set)} kept only because the decompiled version of
     * this class exposed it publicly.
     *
     * @deprecated use {@link #dLogProb(BooleanTensor, Set)} instead.
     */
    @Deprecated
    public Map<Vertex, DoubleTensor> dLogProb2(BooleanTensor value, Set<? extends Vertex> withRespectTo) {
        return dLogProb(value, withRespectTo);
    }

    /**
     * Draws a sample of the given shape from a Bernoulli distribution built from the current value
     * of {@code probTrue}.
     *
     * @param shape  the shape of the sample to draw
     * @param random the random source
     * @return the sampled boolean tensor
     */
    @Override // io.improbable.keanu.vertices.SamplableWithShape
    public BooleanTensor sampleWithShape(long[] shape, KeanuRandom random) {
        return Bernoulli.withParameters(this.probTrue.getValue()).sample(shape, random);
    }
}
