package io.improbable.keanu.distributions.discrete;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.DiscreteDistribution;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.intgr.IntegerPlaceholderVertex;

/* loaded from: input_file:io/improbable/keanu/distributions/discrete/Poisson.class */
/**
 * Poisson distribution over non-negative integers, parameterised by its rate {@code mu}.
 */
public class Poisson implements DiscreteDistribution {

    /**
     * Largest exponent folded into the running product in one multiplication when sampling.
     * {@code Math.exp} overflows a double for arguments above roughly 709, so {@code exp(mu)} is
     * applied in chunks of at most this size.
     */
    private static final double EXP_STEP = 500.0d;

    /** Rate parameter; may be a scalar or a tensor of per-element rates. */
    private final DoubleTensor mu;

    /**
     * Creates a Poisson distribution with the given rate.
     *
     * @param doubleTensor the rate {@code mu}
     * @return the distribution
     */
    public static DiscreteDistribution withParameters(DoubleTensor doubleTensor) {
        return new Poisson(doubleTensor);
    }

    private Poisson(DoubleTensor doubleTensor) {
        this.mu = doubleTensor;
    }

    /**
     * Draws independent Poisson samples.
     *
     * @param jArr        shape of the returned tensor
     * @param keanuRandom source of randomness
     * @return an integer tensor of shape {@code jArr}; element {@code i} is drawn with rate
     *         {@code mu.getFlattenedView().getOrScalar(i)}
     * @throws IllegalArgumentException if a rate used is not finite and strictly positive
     */
    @Override // io.improbable.keanu.distributions.Distribution
    public IntegerTensor sample(long[] jArr, KeanuRandom keanuRandom) {
        Tensor.FlattenedView<Double> muView = this.mu.getFlattenedView();
        int length = TensorShape.getLengthAsInt(jArr);
        int[] samples = new int[length];
        for (int i = 0; i < length; i++) {
            samples[i] = sample(muView.getOrScalar(i).doubleValue(), keanuRandom);
        }
        return IntegerTensor.create(samples, jArr);
    }

    /**
     * Draws a single Poisson variate with rate {@code d}.
     *
     * <p>Multiplies non-zero uniforms together until the product falls below {@code exp(-d)} and
     * returns the number of factors minus one (Knuth's multiplication method). The threshold
     * {@code exp(-d)} underflows to 0 for large {@code d}. To avoid that, the running product is
     * instead scaled up by {@code exp(d)}, in chunks of at most {@link #EXP_STEP}, and compared
     * against 1.
     *
     * @param d           the rate; must be finite and strictly positive
     * @param keanuRandom source of randomness
     * @return a non-negative sample
     * @throws IllegalArgumentException if {@code d} is not finite and strictly positive
     */
    private static int sample(double d, KeanuRandom keanuRandom) {
        // Negated conjunction so NaN is rejected, since every comparison with NaN is false.
        // Without this check, NaN would return 0 and +Infinity would never exhaust remainingMu.
        if (!(d > 0.0d && d < Double.POSITIVE_INFINITY)) {
            throw new IllegalArgumentException("Invalid value for mu: " + d);
        }
        double remainingMu = d;
        int draws = 0;
        double product = 1.0d;
        do {
            draws++;
            product *= keanuRandom.nextDoubleNonZero();
            // Fold in exp(mu) lazily, only while the product has dropped below 1.
            while (product < 1.0d && remainingMu > 0.0d) {
                double step = Math.min(remainingMu, EXP_STEP);
                product *= Math.exp(step);
                remainingMu -= step;
            }
        } while (product > 1.0d);
        return draws - 1;
    }

    /**
     * Log probability mass {@code k * log(mu) - mu - log(k!)}, with {@code log(k!)} computed
     * as {@code logGamma(k + 1)}.
     *
     * @param integerTensor observed counts {@code k}
     * @return the log probability of each observed count
     */
    @Override // io.improbable.keanu.distributions.Distribution
    public DoubleTensor logProb(IntegerTensor integerTensor) {
        DoubleTensor k = integerTensor.toDouble();
        // This must be computed before the *InPlace calls below, because they overwrite k.
        // Java evaluates a call's target before its arguments, so writing this term inline as
        // the argument of the last minusInPlace would read the already-mutated k.
        DoubleTensor logKFactorial = (DoubleTensor) k.plus2(1.0d).logGammaInPlace();
        DoubleTensor result = (DoubleTensor) k.timesInPlace(this.mu.log2());
        result = (DoubleTensor) result.minusInPlace(this.mu);
        return (DoubleTensor) result.minusInPlace(logKFactorial);
    }

    /**
     * Builds the Poisson log probability {@code k * log(mu) - mu - logGamma(k + 1)} as a
     * vertex expression.
     *
     * @param integerPlaceholderVertex placeholder for the observed count {@code k}
     * @param doublePlaceholderVertex  placeholder for the rate {@code mu}
     * @return a vertex computing the log probability
     */
    public static DoubleVertex logProbOutput(IntegerPlaceholderVertex integerPlaceholderVertex, DoublePlaceholderVertex doublePlaceholderVertex) {
        DoubleVertex doubleVertex = integerPlaceholderVertex.toDouble();
        return doubleVertex.times((DoubleVertex) doublePlaceholderVertex.log2()).minus((DoubleVertex) doublePlaceholderVertex).minus((DoubleVertex) doubleVertex.plus2(1.0d).logGamma());
    }
}
