package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DivisionVertex;

/* loaded from: input_file:io/improbable/keanu/distributions/continuous/Gamma.class */
/**
 * Gamma distribution with scale {@code theta} and shape {@code k}.
 *
 * <p>The log density implemented by {@link #logProb} is
 * {@code (k - 1) ln x - ln Gamma(k) - x / theta - k ln theta}.
 *
 * <p>Parameters are stored as tensors and read element-wise during sampling via
 * {@code getOrScalar} (presumably so that a scalar parameter is reused for every output element —
 * NOTE(review): confirm against {@code Tensor.FlattenedView}).
 */
public class Gamma implements ContinuousDistribution {

    /** Squeeze constant {@code t} of Cheng's GB algorithm (used for shape {@code k > 1}). */
    private static final double CHENG_SQUEEZE_T = 4.5d;

    private final DoubleTensor theta;
    private final DoubleTensor k;

    /**
     * Creates a Gamma distribution.
     *
     * @param theta scale parameter(s); must be strictly positive when sampling
     * @param k     shape parameter(s); must be strictly positive when sampling
     * @return the distribution
     */
    public static ContinuousDistribution withParameters(DoubleTensor theta, DoubleTensor k) {
        return new Gamma(theta, k);
    }

    private Gamma(DoubleTensor theta, DoubleTensor k) {
        this.theta = theta;
        this.k = k;
    }

    /**
     * Draws a tensor of independent Gamma samples.
     *
     * @param shape  shape of the returned tensor
     * @param random source of uniform random numbers
     * @return a tensor of the requested shape; element {@code i} is drawn with
     *         {@code theta.getOrScalar(i)} and {@code k.getOrScalar(i)}
     * @throws IllegalArgumentException if any theta or k used is {@code <= 0}
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        Tensor.FlattenedView<Double> thetaView = this.theta.getFlattenedView();
        Tensor.FlattenedView<Double> kView = this.k.getFlattenedView();
        int length = TensorShape.getLengthAsInt(shape);
        double[] samples = new double[length];
        for (int i = 0; i < length; i++) {
            samples[i] = sample(thetaView.getOrScalar(i).doubleValue(), kView.getOrScalar(i).doubleValue(), random);
        }
        return DoubleTensor.create(samples, shape);
    }

    /**
     * Draws a single Gamma(k, theta) sample, dispatching on the shape parameter.
     *
     * @throws IllegalArgumentException if {@code theta <= 0} or {@code k <= 0}
     */
    private static double sample(double theta, double k, KeanuRandom random) {
        if (theta <= 0.0d || k <= 0.0d) {
            throw new IllegalArgumentException("Invalid value for theta or k. Theta: " + theta + ". k: " + k);
        }
        if (k < 1.0d) {
            // Ahrens-Dieter GS: the envelope constant is b = (e + k) / e = 1 + k / e, with e being
            // Euler's number. (Previously the Euler-Mascheroni constant 0.5772... was used here,
            // which produced proposals below 1 from the tail branch and biased the samples.)
            return sampleWhileKLessThanOne(1.0d + k / Math.E, k, theta, random);
        }
        if (k == 1.0d) {
            // Gamma(1, theta) is exactly an exponential with scale theta.
            return exponentialSample(theta, random);
        }
        return sampleWhileKGreaterThanOne(k, theta, random);
    }

    /**
     * Cheng's GB rejection sampler for shape {@code k > 1}.
     *
     * @param k     shape, expected {@code > 1}
     * @param theta scale applied to the standard Gamma(k, 1) variate
     */
    private static double sampleWhileKGreaterThanOne(double k, double theta, KeanuRandom random) {
        final double a = 1.0d / Math.sqrt((2.0d * k) - 1.0d);
        final double b = k - Math.log(4.0d);
        final double q = k + (1.0d / a);
        final double d = 1.0d + Math.log(CHENG_SQUEEZE_T);
        while (true) {
            double u1 = random.nextDouble();
            double u2 = random.nextDouble();
            double v = a * Math.log(u1 / (1.0d - u1));
            double y = k * Math.exp(v);
            double z = u1 * u1 * u2;
            double w = (b + (q * v)) - y;
            // Try the cheap squeeze test first and fall back to the logarithmic test only when it fails.
            // "!(w < log z)" rather than "w >= log z" so that NaN inputs terminate the loop
            // (yielding a NaN sample) exactly as the original loop did, instead of spinning forever.
            if ((w + d) - (CHENG_SQUEEZE_T * z) >= 0.0d || !(w < Math.log(z))) {
                return theta * y;
            }
        }
    }

    /**
     * Ahrens-Dieter GS rejection sampler for shape {@code 0 < k < 1}.
     *
     * @param c     envelope constant {@code 1 + k / e}
     * @param k     shape, expected in (0, 1)
     * @param theta scale applied to the standard Gamma(k, 1) variate
     */
    private static double sampleWhileKLessThanOne(double c, double k, double theta, KeanuRandom random) {
        while (true) {
            double p = c * random.nextDouble();
            if (p > 1.0d) {
                // Exponential-tail part of the envelope: y > 1, accept with probability y^(k-1).
                double y = -Math.log((c - p) / k);
                if (random.nextDouble() <= Math.pow(y, k - 1.0d)) {
                    return theta * y;
                }
            } else {
                // Power-law part of the envelope: y in [0, 1], accept with probability e^(-y).
                double y = Math.pow(p, 1.0d / k);
                if (random.nextDouble() <= Math.exp(-y)) {
                    return theta * y;
                }
            }
        }
    }

    /**
     * Draws an exponential sample with scale {@code theta} by inversion.
     * NOTE(review): if {@code KeanuRandom.nextDouble()} can return exactly 0, this yields +Infinity.
     *
     * @throws IllegalArgumentException if {@code theta <= 0}
     */
    private static double exponentialSample(double theta, KeanuRandom random) {
        if (theta <= 0.0d) {
            throw new IllegalArgumentException("Invalid value for b");
        }
        return (-theta) * Math.log(random.nextDouble());
    }

    /**
     * Element-wise log density {@code (k - 1) ln x - ln Gamma(k) - x / theta - k ln theta}.
     *
     * @param x points at which to evaluate the density
     * @return the log density at each point
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        DoubleTensor xOverTheta = (DoubleTensor) x.div(this.theta);
        DoubleTensor kMinusOneLnX = (DoubleTensor) this.k.minus2(1.0d).timesInPlace(x.log2());
        DoubleTensor withoutLogGamma = (DoubleTensor) kMinusOneLnX.minusInPlace(this.k.logGamma());
        DoubleTensor withoutRate = (DoubleTensor) withoutLogGamma.minusInPlace(xOverTheta);
        return (DoubleTensor) withoutRate.minusInPlace((DoubleTensor) this.k.times(this.theta.log2()));
    }

    /**
     * Builds the log-density graph {@code (k - 1) ln x - ln Gamma(k) - x / theta - k ln theta}
     * over placeholder vertices.
     *
     * @param x     placeholder for the evaluated point
     * @param theta placeholder for the scale
     * @param k     placeholder for the shape
     * @return a vertex computing the log density
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex theta, DoublePlaceholderVertex k) {
        DivisionVertex xOverTheta = x.div((DoubleVertex) theta);
        return k.minus2(1.0d)
            .times((DoubleVertex) x.log2())
            .minus((DoubleVertex) k.logGamma())
            .minus((DoubleVertex) xOverTheta)
            .minus((DoubleVertex) k.times((DoubleVertex) theta.log2()));
    }

    /**
     * Partial derivatives of {@link #logProb} at {@code x}:
     * <ul>
     *   <li>{@code Diffs.THETA}: {@code (x - theta k) / theta^2}</li>
     *   <li>{@code Diffs.K}: {@code ln x - ln theta - digamma(k)}</li>
     *   <li>{@code Diffs.X}: {@code (k - 1) / x - 1 / theta}</li>
     * </ul>
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        DoubleTensor dLogPdx = (DoubleTensor) ((DoubleTensor) this.k.minus2(1.0d).divInPlace(x)).minusInPlace(this.theta.reciprocal());
        // -(theta * k - x) / theta^2
        DoubleTensor dLogPdtheta = (DoubleTensor) ((DoubleTensor) ((DoubleTensor) ((DoubleTensor) this.theta.times(this.k))
            .plusInPlace((DoubleTensor) x.unaryMinus()))
            .divInPlace(this.theta.pow2(2.0d)))
            .unaryMinusInPlace();
        DoubleTensor dLogPdk = (DoubleTensor) ((DoubleTensor) x.log2().minusInPlace(this.theta.log2())).minusInPlace(this.k.digamma());
        return new Diffs()
            .put(Diffs.THETA, dLogPdtheta)
            .put(Diffs.K, dLogPdk)
            .put(Diffs.X, dLogPdx);
    }
}
