/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DivisionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.MultiplicationVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogGammaVertex;

/**
 * Gamma distribution with scale {@code theta} and shape {@code k}.
 *
 * <p>The log density evaluated by {@link #logProb(DoubleTensor)} is
 * {@code (k - 1) ln x - x / theta - ln Gamma(k) - k ln theta}.
 *
 * <p>Sampling dispatches on the shape:
 * <ul>
 *   <li>{@code k < 1}: Ahrens-Dieter algorithm GS,</li>
 *   <li>{@code k == 1}: exponential with mean {@code theta},</li>
 *   <li>{@code k > 1}: Cheng's algorithm GB.</li>
 * </ul>
 */
public class Gamma implements ContinuousDistribution {

    /** Squeeze constant of Cheng's algorithm GB (called theta in the paper; renamed to avoid clashing with the scale). */
    private static final double CHENG_T = 4.5;
    /** {@code 1 + ln(CHENG_T)}, the "d" constant of Cheng's algorithm GB. */
    private static final double CHENG_D = 1.0 + Math.log(CHENG_T);
    private static final double LN_4 = Math.log(4.0);

    private final DoubleTensor theta;
    private final DoubleTensor k;

    /**
     * Creates a Gamma distribution.
     *
     * @param theta scale parameter
     * @param k     shape parameter
     * @return a distribution backed by the given parameter tensors
     */
    public static ContinuousDistribution withParameters(DoubleTensor theta, DoubleTensor k) {
        return new Gamma(theta, k);
    }

    private Gamma(DoubleTensor theta, DoubleTensor k) {
        this.theta = theta;
        this.k = k;
    }

    /**
     * Draws one independent Gamma sample per element of {@code shape}.
     *
     * <p>Element {@code i} uses the parameter values returned by {@code getOrScalar(i)} on the flattened
     * {@code theta} and {@code k} tensors (presumably the single value when a parameter is scalar — confirm
     * against {@code Tensor.FlattenedView}).
     *
     * @param shape  shape of the returned tensor
     * @param random source of uniform random numbers
     * @return a tensor of the requested shape filled with samples
     * @throws IllegalArgumentException if any used theta is not positive, or any used k is not positive and finite
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        Tensor.FlattenedView thetaWrapped = this.theta.getFlattenedView();
        Tensor.FlattenedView kWrapped = this.k.getFlattenedView();
        int length = TensorShape.getLengthAsInt(shape);
        double[] samples = new double[length];
        for (int i = 0; i < length; ++i) {
            samples[i] = Gamma.sample((Double) thetaWrapped.getOrScalar(i), (Double) kWrapped.getOrScalar(i), random);
        }
        return DoubleTensor.create(samples, shape);
    }

    /**
     * Draws a single Gamma(theta, k) sample.
     *
     * @throws IllegalArgumentException if theta is not positive, or k is not positive and finite
     */
    private static double sample(double theta, double k, KeanuRandom random) {
        // Negated "> 0" tests so NaN parameters are rejected too: a NaN or infinite k would otherwise
        // make every comparison in the rejection loop below false and spin forever.
        if (!(theta > 0.0) || !(k > 0.0) || Double.isInfinite(k)) {
            throw new IllegalArgumentException("Invalid value for theta or k. Theta: " + theta + ". k: " + k);
        }
        if (k < 1.0) {
            // Algorithm GS mixes the two envelope pieces with bound 1 + k/e, where e is Euler's number
            // (2.718...), NOT the Euler-Mascheroni constant (0.5772...).
            return Gamma.sampleWhileKLessThanOne(1.0 + k / Math.E, k, theta, random);
        }
        if (k == 1.0) {
            return Gamma.exponentialSample(theta, random);
        }

        // Cheng (1977) algorithm GB for k > 1; these constants are only meaningful (finite) here.
        final double a = 1.0 / Math.sqrt(2.0 * k - 1.0);
        final double b = k - LN_4;
        final double q = k + 1.0 / a;
        while (true) {
            final double p1 = random.nextDouble();
            final double p2 = random.nextDouble();
            final double v = a * Math.log(p1 / (1.0 - p1));
            final double y = k * Math.exp(v);
            final double z = p1 * p1 * p2;
            final double w = b + q * v - y;
            // Cheap squeeze test first; the log comparison is only evaluated when the squeeze fails.
            if (w + CHENG_D - CHENG_T * z >= 0.0 || w >= Math.log(z)) {
                return theta * y;
            }
        }
    }

    /**
     * Ahrens-Dieter (1974) algorithm GS for shape {@code 0 < k < 1}.
     *
     * @param c     the envelope bound {@code 1 + k / e}
     * @param k     shape, expected in (0, 1)
     * @param theta scale applied to the accepted variate
     * @param random source of uniform random numbers
     * @return a Gamma(theta, k) sample
     */
    private static double sampleWhileKLessThanOne(double c, double k, double theta, KeanuRandom random) {
        while (true) {
            final double p = c * random.nextDouble();
            if (p <= 1.0) {
                // Head piece: y has density proportional to y^(k-1) on (0, 1]; accept with probability e^-y.
                final double y = Math.pow(p, 1.0 / k);
                if (random.nextDouble() <= Math.exp(-y)) {
                    return theta * y;
                }
            } else {
                // Tail piece: y >= 1 with an exponential tail; accept with probability y^(k-1).
                final double y = -Math.log((c - p) / k);
                if (random.nextDouble() <= Math.pow(y, k - 1.0)) {
                    return theta * y;
                }
            }
        }
    }

    /**
     * Draws from an exponential distribution whose mean is {@code scale}.
     *
     * @throws IllegalArgumentException if {@code scale} is not positive
     */
    private static double exponentialSample(double scale, KeanuRandom random) {
        if (scale <= 0.0) {
            throw new IllegalArgumentException("Invalid value for b");
        }
        return -scale * Math.log(random.nextDouble());
    }

    /**
     * Element-wise log density {@code (k - 1) ln x - ln Gamma(k) - x / theta - k ln theta}.
     *
     * @param x points at which to evaluate the density
     * @return the log density at each point
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        DoubleTensor xOverTheta = x.div(this.theta);
        DoubleTensor kLnTheta = (DoubleTensor) this.k.times(this.theta.log());
        DoubleTensor kMinus1LogX = (DoubleTensor) this.k.minus(1.0).timesInPlace(x.log());
        DoubleTensor lgammaK = (DoubleTensor) this.k.logGamma();
        return kMinus1LogX.minusInPlace(lgammaK).minusInPlace(xOverTheta).minusInPlace(kLnTheta);
    }

    /**
     * Builds the same log density as {@link #logProb(DoubleTensor)} as a vertex graph over placeholders.
     *
     * @param x     placeholder for the evaluation point
     * @param theta placeholder for the scale
     * @param k     placeholder for the shape
     * @return a vertex computing {@code (k - 1) ln x - ln Gamma(k) - x / theta - k ln theta}
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex theta, DoublePlaceholderVertex k) {
        DivisionVertex xOverTheta = x.div(theta);
        MultiplicationVertex kLnTheta = k.times(theta.log());
        MultiplicationVertex kMinus1LogX = k.minus(1.0).times(x.log());
        LogGammaVertex lgammaK = k.logGamma();
        return kMinus1LogX.minus(lgammaK).minus(xOverTheta).minus(kLnTheta);
    }

    /**
     * Partial derivatives of the log density.
     *
     * <ul>
     *   <li>{@code d/dx     = (k - 1) / x - 1 / theta}</li>
     *   <li>{@code d/dtheta = -(theta k - x) / theta^2}</li>
     *   <li>{@code d/dk     = ln x - ln theta - digamma(k)}</li>
     * </ul>
     *
     * @param x points at which to evaluate the derivatives
     * @return the derivatives keyed by {@code Diffs.THETA}, {@code Diffs.K} and {@code Diffs.X}
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        DoubleTensor dLogPdx = (DoubleTensor) this.k.minus(1.0).divInPlace(x).minusInPlace(this.theta.reciprocal());

        DoubleTensor kThetaMinusX = (DoubleTensor) this.theta.times(this.k).plusInPlace(x.unaryMinus());
        DoubleTensor dLogPdtheta = (DoubleTensor) kThetaMinusX.divInPlace(this.theta.pow(2.0)).unaryMinusInPlace();

        DoubleTensor logXMinusLogTheta = (DoubleTensor) ((DoubleTensor) x.log()).minusInPlace(this.theta.log());
        DoubleTensor dLogPdk = (DoubleTensor) logXMinusLogTheta.minusInPlace(this.k.digamma());

        return new Diffs().put(Diffs.THETA, dLogPdtheta).put(Diffs.K, dLogPdk).put(Diffs.X, dLogPdx);
    }
}

