/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.discrete;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.DiscreteDistribution;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogGammaVertex;
import io.improbable.keanu.vertices.intgr.IntegerPlaceholderVertex;

/**
 * Poisson distribution over non-negative integers, parameterised by the rate {@code mu}.
 *
 * <p>log P(k; mu) = k * log(mu) - mu - log(k!), with log(k!) computed as logGamma(k + 1).
 */
public class Poisson implements DiscreteDistribution {

    /**
     * Largest slice of the rate folded into the running product in a single multiplication while
     * sampling. Applying e^mu incrementally in bounded steps keeps each {@link Math#exp} factor
     * finite (e^500 is about 1.4e217; e^710 would overflow a double).
     */
    private static final double STEP_IN_MU = 500.0;

    /**
     * Rate parameter(s). {@link #sample(long[], KeanuRandom)} reads it element-wise via
     * {@code getOrScalar}, so it may be a scalar or a tensor matching the sample length.
     */
    private final DoubleTensor mu;

    /**
     * Creates a Poisson distribution with the given rate.
     *
     * @param mu rate parameter(s); every value used for sampling must be finite and {@code > 0}
     * @return the distribution
     */
    public static DiscreteDistribution withParameters(DoubleTensor mu) {
        return new Poisson(mu);
    }

    private Poisson(DoubleTensor mu) {
        this.mu = mu;
    }

    /**
     * Draws independent Poisson samples.
     *
     * @param shape  shape of the returned tensor
     * @param random source of randomness
     * @return an integer tensor of {@code shape}; element {@code i} is drawn using the rate
     *         {@code mu.getFlattenedView().getOrScalar(i)}
     * @throws IllegalArgumentException if a rate used for sampling is not finite and positive
     */
    @Override
    public IntegerTensor sample(long[] shape, KeanuRandom random) {
        Tensor.FlattenedView muWrapped = this.mu.getFlattenedView();
        int length = TensorShape.getLengthAsInt(shape);
        int[] samples = new int[length];
        for (int i = 0; i < length; ++i) {
            samples[i] = sample((Double) muWrapped.getOrScalar(i), random);
        }
        return IntegerTensor.create(samples, shape);
    }

    /**
     * Draws one Poisson sample with rate {@code mu}.
     *
     * <p>This is a multiply-uniforms (Knuth-style) sampler. Instead of comparing against e^(-mu),
     * it multiplies the running product by e^mu in chunks of at most {@link #STEP_IN_MU}, and does
     * so only while the product is below 1. This avoids the underflow of e^(-mu) for large mu.
     * Cost grows linearly with {@code mu}.
     *
     * @param mu     rate; must be finite and strictly positive
     * @param random source of non-zero uniform doubles
     * @return the sampled count (>= 0)
     * @throws IllegalArgumentException if {@code mu} is NaN, infinite, or {@code <= 0}
     */
    private static int sample(double mu, KeanuRandom random) {
        // `!(mu > 0.0)` also rejects NaN, which would otherwise slip past a `<= 0.0` check and
        // silently yield 0. +Infinity must be rejected too: Infinity - STEP_IN_MU is still
        // Infinity, so the stepping loop below would never terminate.
        if (!(mu > 0.0) || Double.isInfinite(mu)) {
            throw new IllegalArgumentException("Invalid value for mu: " + mu);
        }
        double muLeft = mu;
        int k = 0;
        double p = 1.0;
        do {
            ++k;
            p *= random.nextDoubleNonZero();
            // Fold the remaining rate into p only once p has dropped below 1, in bounded steps.
            while (p < 1.0 && muLeft > 0.0) {
                if (muLeft > STEP_IN_MU) {
                    p *= Math.exp(STEP_IN_MU);
                    muLeft -= STEP_IN_MU;
                } else {
                    p *= Math.exp(muLeft);
                    muLeft = 0.0;
                }
            }
        } while (p > 1.0);
        return k - 1;
    }

    /**
     * Computes the element-wise log probability {@code k * log(mu) - mu - logGamma(k + 1)}.
     *
     * @param k observed counts
     * @return log probabilities
     */
    @Override
    public DoubleTensor logProb(IntegerTensor k) {
        // NOTE(review): the in-place ops below assume k.toDouble() returns a fresh tensor, so k
        // itself is not mutated — TODO confirm.
        DoubleTensor kDouble = k.toDouble();
        DoubleTensor logFactorialK = (DoubleTensor) kDouble.plus(1.0).logGammaInPlace();
        return ((DoubleTensor) kDouble.timesInPlace(this.mu.log()))
            .minusInPlace(this.mu)
            .minusInPlace(logFactorialK);
    }

    /**
     * Builds the vertex graph for {@code k * log(mu) - mu - logGamma(k + 1)}, the symbolic
     * counterpart of {@link #logProb(IntegerTensor)}.
     *
     * @param k  placeholder for the observed counts
     * @param mu placeholder for the rate
     * @return a vertex computing the log probability
     */
    public static DoubleVertex logProbOutput(IntegerPlaceholderVertex k, DoublePlaceholderVertex mu) {
        DoubleVertex kDouble = k.toDouble();
        LogGammaVertex logFactorialK = kDouble.plus(1.0).logGamma();
        return kDouble.times(mu.log()).minus(mu).minus(logFactorialK);
    }
}

