/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.discrete;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.DiscreteDistribution;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.MultiplicationVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogGammaVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogVertex;
import io.improbable.keanu.vertices.intgr.IntegerPlaceholderVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;

/**
 * Binomial distribution parameterised by a success probability tensor {@code p}
 * and a trial-count tensor {@code n}.
 *
 * <p>Provides sampling, a tensor-valued log-probability and a vertex-graph
 * version of the same log-probability expression:
 * {@code log C(n, k) + k * log(p) + (n - k) * log(1 - p)}.
 *
 * <p>NOTE(review): neither the parameters nor {@code k} are validated here.
 * Nothing in this class special-cases {@code p == 0}, {@code p == 1} or
 * {@code k} outside {@code [0, n]}; a {@code 0 * log(0)} term would evaluate to
 * NaN under IEEE-754 arithmetic. Confirm whether callers guard these cases.
 */
public class Binomial implements DiscreteDistribution {

    /** Per-trial success probability. */
    private final DoubleTensor p;

    /** Number of trials. */
    private final IntegerTensor n;

    /**
     * Creates a binomial distribution.
     *
     * @param p success probability of a single trial
     * @param n number of trials
     * @return a distribution backed by the given tensors (they are stored, not copied)
     */
    public static DiscreteDistribution withParameters(DoubleTensor p, IntegerTensor n) {
        return new Binomial(p, n);
    }

    private Binomial(DoubleTensor p, IntegerTensor n) {
        this.p = p;
        this.n = n;
    }

    /**
     * Draws one binomial sample per element of the requested shape.
     *
     * <p>Element {@code i} of the result is drawn using
     * {@code p.getFlattenedView().getOrScalar(i)} and
     * {@code n.getFlattenedView().getOrScalar(i)}; going by the method name this
     * presumably lets a scalar parameter be reused for every element - confirm
     * against the tensor API.
     *
     * @param shape  shape of the tensor of samples to produce
     * @param random source of randomness
     * @return an integer tensor of the given shape holding the samples
     */
    @Override
    public IntegerTensor sample(long[] shape, KeanuRandom random) {
        Tensor.FlattenedView<Double> pFlat = p.getFlattenedView();
        Tensor.FlattenedView<Integer> nFlat = n.getFlattenedView();

        int length = TensorShape.getLengthAsInt(shape);
        int[] samples = new int[length];
        for (int i = 0; i < length; i++) {
            samples[i] = sample(pFlat.getOrScalar(i), nFlat.getOrScalar(i), random);
        }
        return IntegerTensor.create(samples, shape);
    }

    /**
     * Draws a single sample by running {@code n} trials and counting how many
     * of the drawn doubles fall below {@code p}.
     *
     * <p>Cost is linear in {@code n}. A non-positive {@code n} runs no trials
     * and yields 0.
     *
     * @param p      success threshold compared against each drawn double
     * @param n      number of trials
     * @param random source of randomness (one {@code nextDouble()} call per trial)
     * @return the number of successful trials
     */
    private static int sample(double p, int n, KeanuRandom random) {
        int successes = 0;
        for (int trial = 0; trial < n; trial++) {
            if (random.nextDouble() < p) {
                successes++;
            }
        }
        return successes;
    }

    /**
     * Computes {@code log C(n, k) + k * log(p) + (n - k) * log(1 - p)}.
     *
     * @param k number of successes to evaluate the log-probability at
     * @return the log-probability as a tensor
     */
    @Override
    public DoubleTensor logProb(IntegerTensor k) {
        DoubleTensor logBinomialCoefficient = getLogBinomialCoefficient(k, n);

        DoubleTensor kDouble = k.toDouble();
        // Converted copy of n; it is overwritten in place with (n - k) below.
        DoubleTensor nDouble = n.toDouble();

        DoubleTensor kLogP = kDouble.times(p.log());
        DoubleTensor logOneMinusP = p.unaryMinus().plusInPlace(1.0).logInPlace();
        DoubleTensor nMinusKLogOneMinusP = nDouble.minusInPlace(kDouble).timesInPlace(logOneMinusP);

        return logBinomialCoefficient.plusInPlace(kLogP).plusInPlace(nMinusKLogOneMinusP);
    }

    /**
     * Builds the vertex expression
     * {@code log C(n, k) + k * log(p) + (n - k) * log(1 - p)},
     * mirroring {@link #logProb(IntegerTensor)} on placeholder vertices.
     *
     * @param k placeholder for the number of successes
     * @param p placeholder for the success probability
     * @param n placeholder for the number of trials
     * @return the vertex that outputs the log-probability
     */
    public static DoubleVertex logProbOutput(IntegerPlaceholderVertex k, DoublePlaceholderVertex p, IntegerPlaceholderVertex n) {
        DoubleVertex logBinomialCoefficient = getLogBinomialCoefficient(k, n);

        DoubleVertex kDouble = k.toDouble();
        DoubleVertex nDouble = n.toDouble();

        MultiplicationVertex kLogP = kDouble.times(p.log());
        LogVertex logOneMinusP = p.unaryMinus().plus(1.0).log();
        MultiplicationVertex nMinusKLogOneMinusP = nDouble.minus(kDouble).times(logOneMinusP);

        return logBinomialCoefficient.plus(kLogP).plus(nMinusKLogOneMinusP);
    }

    /**
     * Computes {@code logGamma(n + 1) - logGamma(k + 1) - logGamma(n - k + 1)},
     * i.e. the log of the binomial coefficient expressed through log-gamma.
     *
     * @param k number of successes
     * @param n number of trials
     * @return the log binomial coefficient as a tensor
     */
    private static DoubleTensor getLogBinomialCoefficient(IntegerTensor k, IntegerTensor n) {
        DoubleTensor nDouble = n.toDouble();
        DoubleTensor kDouble = k.toDouble();

        // Both of these must be computed before nDouble is overwritten in place below.
        DoubleTensor logNFactorial = nDouble.plus(1.0).logGammaInPlace();
        DoubleTensor logKFactorial = kDouble.plus(1.0).logGammaInPlace();

        // Reuses nDouble as scratch space: it becomes (n - k + 1) and then its log-gamma.
        DoubleTensor logNMinusKFactorial = nDouble.minusInPlace(kDouble).plusInPlace(1.0).logGammaInPlace();

        return logNFactorial.minusInPlace(logKFactorial).minusInPlace(logNMinusKFactorial);
    }

    /**
     * Vertex-graph counterpart of the tensor log binomial coefficient:
     * {@code logGamma(n + 1) - logGamma(k + 1) - logGamma(n - k + 1)}.
     *
     * @param k vertex for the number of successes
     * @param n vertex for the number of trials
     * @return the vertex that outputs the log binomial coefficient
     */
    private static DoubleVertex getLogBinomialCoefficient(IntegerVertex k, IntegerVertex n) {
        DoubleVertex nDouble = n.toDouble();
        DoubleVertex kDouble = k.toDouble();

        LogGammaVertex logNFactorial = nDouble.plus(1.0).logGamma();
        LogGammaVertex logKFactorial = kDouble.plus(1.0).logGamma();
        LogGammaVertex logNMinusKFactorial = nDouble.minus(kDouble).plus(1.0).logGamma();

        return logNFactorial.minus(logKFactorial).minus(logNMinusKFactorial);
    }
}

