package io.improbable.keanu.distributions.discrete;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.DiscreteDistribution;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import io.improbable.keanu.vertices.intgr.IntegerPlaceholderVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;

/* loaded from: input_file:io/improbable/keanu/distributions/discrete/Geometric.class */
/**
 * Geometric distribution over {@code k = 1, 2, 3, ...}: the number of independent trials up to
 * and including the first success, where each trial succeeds with probability {@code p}.
 *
 * <p>Log probability mass: {@code log P(k) = (k - 1) * log(1 - p) + log(p)} for {@code k >= 1},
 * and {@code -infinity} for {@code k < 1}. The parameter is treated as valid only when every
 * element of {@code p} lies strictly inside {@code (0, 1)}.
 */
public class Geometric implements DiscreteDistribution {

    /** Per-trial success probability (element-wise). */
    private final DoubleTensor p;

    /**
     * Creates a geometric distribution with the given success probability.
     *
     * @param p per-trial success probability
     * @return the distribution
     */
    public static DiscreteDistribution withParameters(DoubleTensor p) {
        return new Geometric(p);
    }

    private Geometric(DoubleTensor p) {
        this.p = p;
    }

    /**
     * Draws samples by inverse-transform sampling: {@code floor(log(U) / log(1 - p)) + 1}.
     *
     * @param shape  shape of the sample tensor to draw
     * @param random source of the {@code U} values (presumably uniform on [0, 1) — confirm
     *               against {@code KeanuRandom.nextDouble})
     * @return integer samples, each {@code >= 1} for finite draws
     */
    @Override
    public IntegerTensor sample(long[] shape, KeanuRandom random) {
        return random.nextDouble(shape)
            .logInPlace()
            .divInPlace(logOneMinusP())
            .floorInPlace()
            .toInteger()
            .plusInPlace(1);
    }

    /**
     * Computes the element-wise log probability of {@code k}.
     *
     * @param k trial counts to evaluate
     * @return {@code log P(k)}; all {@code -infinity} (shaped like {@code k}) when {@code p} is
     *         not strictly inside (0, 1)
     */
    @Override
    public DoubleTensor logProb(IntegerTensor k) {
        if (!checkParameterIsValid()) {
            return DoubleTensor.create(Double.NEGATIVE_INFINITY, k.getShape());
        }
        return calculateLogProb(k);
    }

    /**
     * Builds a vertex graph computing the log probability symbolically, for use with placeholder
     * inputs.
     *
     * @param kVertex placeholder for the trial counts
     * @param pVertex placeholder for the success probability
     * @return vertex whose value is {@code log P(k)}, forced to {@code -infinity} wherever
     *         {@code p} is not strictly inside (0, 1) or {@code k < 1}
     */
    public static DoubleVertex logProbOutput(IntegerPlaceholderVertex kVertex, DoublePlaceholderVertex pVertex) {
        ConstantDoubleVertex zero = ConstantVertex.of(DoubleTensor.zeros(kVertex.getShape()));
        ConstantDoubleVertex one = ConstantVertex.of(DoubleTensor.ones(kVertex.getShape()));

        // (p > 0) * (p < 1) is 1 where p is valid; "1 - that" flags the invalid positions.
        return calculateLogProb(kVertex, pVertex).setWithMask(
            pVertex.toGreaterThanMask(zero)
                .times(pVertex.toLessThanMask(one))
                .unaryMinus()
                .plus((DoubleVertex) one),
            Double.NEGATIVE_INFINITY
        );
    }

    /**
     * Returns a tensor holding {@code log(1 - p)}.
     *
     * <p>Built from {@code unaryMinus()} (not an in-place variant) before the in-place ops, so
     * presumably {@code p} itself is left untouched — relied upon by both callers.
     */
    private DoubleTensor logOneMinusP() {
        return p.unaryMinus().plusInPlace(1.0).logInPlace();
    }

    /** Tensor form of {@code (k - 1) * log(1 - p) + log(p)}, with {@code k < 1} masked out. */
    private DoubleTensor calculateLogProb(IntegerTensor k) {
        DoubleTensor logPmf = k.toDouble()
            .minusInPlace(1.0)
            .timesInPlace(logOneMinusP())
            .plusInPlace(p.log());
        return setProbToZeroForInvalidK(k, logPmf);
    }

    /** Vertex form of {@code (k - 1) * log(1 - p) + log(p)}, with {@code k < 1} masked out. */
    private static DoubleVertex calculateLogProb(IntegerVertex kVertex, DoubleVertex pVertex) {
        DoubleVertex logOneMinusP = pVertex.unaryMinus().plus(1.0).log();
        DoubleVertex logPmf = kVertex.toDouble().minus(1.0).times(logOneMinusP).plus(pVertex.log());
        return setProbToZeroForInvalidK(kVertex, logPmf);
    }

    /**
     * Overwrites {@code logProb} in place with {@code -infinity} wherever {@code k < 1}, i.e.
     * outside the distribution's support.
     */
    private static DoubleTensor setProbToZeroForInvalidK(IntegerTensor k, DoubleTensor logProb) {
        DoubleTensor kBelowOne = k.lessThanMask(IntegerTensor.create(1, k.getShape())).toDouble();
        return logProb.setWithMaskInPlace(kBelowOne, Double.NEGATIVE_INFINITY);
    }

    /** Vertex form: sets {@code -infinity} wherever {@code k < 1}. */
    private static DoubleVertex setProbToZeroForInvalidK(IntegerVertex kVertex, DoubleVertex logProb) {
        return logProb.setWithMask(kVertex.toDouble().toLessThanMask(1.0), Double.NEGATIVE_INFINITY);
    }

    /** Returns true when every element of {@code p} is strictly between 0 and 1. */
    private boolean checkParameterIsValid() {
        return p.greaterThan(0.0).allTrue() && p.lessThan(1.0).allTrue();
    }
}
