package io.improbable.keanu.distributions.discrete;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.Distribution;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.bool.BooleanPlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.MinVertex;
import io.improbable.keanu.vertices.generic.nonprobabilistic.If;

/* loaded from: input_file:io/improbable/keanu/distributions/discrete/Bernoulli.class */
/**
 * Bernoulli distribution over boolean tensors, parameterised by an element-wise probability
 * that each entry is {@code true}.
 */
public class Bernoulli implements Distribution<BooleanTensor> {

    /** Element-wise probability that a sampled entry is {@code true}. Stored as given (not copied). */
    private final DoubleTensor probTrue;

    /**
     * Creates a Bernoulli distribution.
     *
     * @param probTrue element-wise probability of {@code true}; {@link #logProb} clamps it to [0, 1]
     * @return a distribution that keeps a reference to {@code probTrue}
     */
    public static Bernoulli withParameters(DoubleTensor probTrue) {
        return new Bernoulli(probTrue);
    }

    private Bernoulli(DoubleTensor probTrue) {
        this.probTrue = probTrue;
    }

    /**
     * Draws a sample by comparing per-entry random doubles against {@code probTrue}.
     *
     * @param shape  shape of the returned tensor
     * @param random source of randomness; NOTE(review): {@code nextDouble(shape)} is assumed to be
     *               uniform on [0, 1) — confirm against KeanuRandom
     * @return {@code true} wherever the random draw is strictly less than {@code probTrue}
     */
    @Override
    public BooleanTensor sample(long[] shape, KeanuRandom random) {
        return random.nextDouble(shape).lessThan(probTrue);
    }

    /**
     * Element-wise log-probability of {@code x}.
     *
     * <p>{@code probTrue} is first clamped to [0, 1]. Each entry is then {@code log(p)} where
     * {@code x} is true and {@code log(1 - p)} where it is false.
     *
     * @param x observed boolean values
     * @return element-wise log-probabilities
     */
    @Override
    public DoubleTensor logProb(BooleanTensor x) {
        DoubleTensor probTrueClamped = probTrue.clamp(DoubleTensor.scalar(0.0), DoubleTensor.scalar(1.0));
        // 1 - p: negate into a new tensor, then add the scalar in place, so probTrueClamped stays intact
        // for the "true" branch. (The decompiled form cast a Double to DoubleTensor, which cannot compile.)
        DoubleTensor probFalse = probTrueClamped.unaryMinus().plusInPlace(1.0);
        return x.doubleWhere(probTrueClamped, probFalse).logInPlace();
    }

    /**
     * Builds a computation graph equivalent to {@link #logProb}, operating on placeholder vertices.
     *
     * @param x        placeholder for the observed boolean values; its shape sizes the clamp bounds
     * @param probTrue placeholder for the probability of {@code true}
     * @return a vertex computing {@code log(p)} where {@code x} is true, {@code log(1 - p)} otherwise,
     *         with {@code p} clamped to [0, 1]
     */
    public static DoubleVertex logProbGraph(BooleanPlaceholderVertex x, DoublePlaceholderVertex probTrue) {
        ConstantDoubleVertex zeros = ConstantVertex.of(DoubleTensor.zeros(x.getShape()));
        DoubleVertex clampedBelow = DoubleVertex.max(probTrue, zeros);
        DoubleVertex probTrueClamped = DoubleVertex.min(clampedBelow, ConstantVertex.of(DoubleTensor.ones(x.getShape())));
        // Natural log, matching logInPlace() in logProb (the decompiled call read log2()).
        return If.isTrue(x)
            .then(probTrueClamped)
            .orElse(probTrueClamped.unaryMinus().plus(1.0))
            .log();
    }

    /**
     * Derivative of {@link #logProb} with respect to {@code probTrue}.
     *
     * <p>Where {@code x} is true this is {@code 1 / p}. Where it is false it is {@code 1 / (p - 1)},
     * i.e. {@code d/dp log(1 - p)}. Entries with {@code p > 1} or {@code p <= 0} are set to 0 in both
     * branches, since {@code logProb} clamps {@code p} there.
     *
     * <p>NOTE(review): {@code p == 1} is not masked, so a false {@code x} there yields
     * {@code 1 / 0.0 = +Infinity}. Confirm whether the asymmetry with {@code p == 0} is intended.
     *
     * @param x observed boolean values
     * @return element-wise gradient of the log-probability with respect to {@code probTrue}
     */
    public DoubleTensor dLogProb(BooleanTensor x) {
        // The sum of two 0/1 masks is non-zero exactly where p is outside (0, 1].
        DoubleTensor outOfRangeMask = probTrue.greaterThanMask(DoubleTensor.scalar(1.0))
            .plusInPlace(probTrue.lessThanOrEqualToMask(DoubleTensor.scalar(0.0)));
        DoubleTensor dLogPdpForTrue = probTrue.reciprocal().setWithMaskInPlace(outOfRangeMask, 0.0);
        DoubleTensor dLogPdpForFalse = probTrue.minus(1.0).reciprocalInPlace().setWithMaskInPlace(outOfRangeMask, 0.0);
        return x.doubleWhere(dLogPdpForTrue, dLogPdpForFalse);
    }
}
