/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.discrete;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.DiscreteDistribution;
import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import java.util.Arrays;

/**
 * Multinomial distribution: counts of {@code n} trials spread over {@code k} categories with
 * category probabilities {@code p}. The last dimension of {@code p} indexes the categories; the
 * remaining dimensions of {@code p} and all dimensions of {@code n} form a batch that is broadcast
 * against the requested sample / observation shape.
 */
public class Multinomial
implements DiscreteDistribution {
    /** Default tolerance used when checking that each probability vector sums to 1. */
    public static final double DEFAULT_ALLOWED_PROBABILITY_ERROR = 0.001;

    // NOTE(review): mutable global state shared by every instance; it is neither volatile nor
    // synchronized, so concurrent callers of setAllowedProbabilityError may race.
    private static double allowedProbabilityError = DEFAULT_ALLOWED_PROBABILITY_ERROR;

    private final IntegerTensor n;
    private final DoubleTensor p;
    /** Number of categories: the size of the last dimension of {@code p}. */
    private final int k;
    private final boolean validationEnabled;

    /**
     * Sets the tolerance used when validating that probabilities sum to 1. Validation of
     * {@code p} runs at construction time, so this affects instances created afterwards.
     *
     * @param allowedProbabilityError absolute tolerance passed to {@code equalsWithinEpsilon}
     */
    public static void setAllowedProbabilityError(double allowedProbabilityError) {
        Multinomial.allowedProbabilityError = allowedProbabilityError;
    }

    /**
     * Creates a multinomial distribution with validation disabled.
     *
     * @param n number of trials (batch tensor)
     * @param p category probabilities; the last dimension holds the k categories
     * @return the distribution
     */
    public static Multinomial withParameters(IntegerTensor n, DoubleTensor p) {
        return new Multinomial(n, p, false);
    }

    /**
     * Creates a multinomial distribution.
     *
     * @param n                 number of trials (batch tensor)
     * @param p                 category probabilities; the last dimension holds the k categories
     * @param validationEnabled when true, parameters are checked now and shapes / observations are
     *                          checked on every {@link #sample} and {@link #logProb} call
     * @return the distribution
     * @throws IllegalArgumentException if validation is enabled and the parameters are invalid
     */
    public static Multinomial withParameters(IntegerTensor n, DoubleTensor p, boolean validationEnabled) {
        return new Multinomial(n, p, validationEnabled);
    }

    private Multinomial(IntegerTensor n, DoubleTensor p, boolean validationEnabled) {
        // Validate before reading p's trailing dimension: for a scalar p that read would throw
        // ArrayIndexOutOfBoundsException and hide the descriptive validation error.
        if (validationEnabled) {
            validateProbabilities(p);
            validateN(n);
        }
        this.k = Ints.checkedCast(p.getShape()[p.getRank() - 1]);
        this.n = n;
        this.p = p;
        this.validationEnabled = validationEnabled;
    }

    /**
     * Draws category counts.
     *
     * @param shape  output shape; its last dimension must be k (checked only when validation is on)
     * @param random source of randomness
     * @return a tensor of the given shape whose trailing-dimension entries are category counts
     */
    @Override
    public IntegerTensor sample(long[] shape, KeanuRandom random) {
        if (validationEnabled) {
            validateBroadcastShapes(shape, n.getShape(), p.getShape());
            validateXShape(shape, p.getShape());
        }
        // Leading (batch) dimensions of the requested shape; the trailing one holds the k counts.
        long[] sampleBatchShape = TensorShape.selectDimensions(0, shape.length - 1, shape);
        IntegerTensor broadcastedN = (IntegerTensor) n.broadcast(sampleBatchShape);
        long[] broadcastResultShape = TensorShape.getBroadcastResultShape(
            TensorShape.concat(broadcastedN.getShape(), new long[]{1L}), p.getShape());
        DoubleTensor broadcastedP = (DoubleTensor) p.broadcast(broadcastResultShape);

        double[] flatP = broadcastedP.asFlatDoubleArray();
        int[] flatN = broadcastedN.asFlatIntegerArray();
        int sampleCount = flatN.length;
        int[] samples = new int[k * sampleCount];
        for (int i = 0; i < sampleCount; ++i) {
            // Each batch element owns k consecutive slots in both the sample and probability arrays.
            int positionByK = i * k;
            drawNTimes(flatN[i], random, samples, positionByK, flatP, positionByK, k);
        }
        return IntegerTensor.create(samples, shape);
    }

    /** Performs {@code n} independent draws and tallies them into {@code sample[sampleIndex ..]}. */
    private static void drawNTimes(int n, KeanuRandom random, int[] sample, int sampleIndex, double[] p, int pIndex, int pCount) {
        for (int i = 0; i < n; ++i) {
            sample[sampleIndex + draw(random, p, pIndex, pCount)]++;
        }
    }

    /**
     * Draws one category by accumulating {@code p[pIndex .. pIndex + pCount)} until the running
     * sum reaches a value from {@code random.nextDouble()}.
     *
     * @return the drawn category, an index in {@code [0, pCount)} relative to {@code pIndex}
     */
    private static int draw(KeanuRandom random, double[] p, int pIndex, int pCount) {
        double value = random.nextDouble();
        double pCumulative = 0.0;
        int fallback = pCount - 1;
        for (int index = 0; index < pCount; ++index) {
            double currentP = p[pIndex + index];
            if (currentP == 0.0) {
                // A zero-probability category must never be selected.
                continue;
            }
            fallback = index;
            pCumulative += currentP;
            if (pCumulative >= value) {
                return index;
            }
        }
        // Only reached when rounding leaves the cumulative sum below value (or every probability
        // is zero): pick the last non-zero category rather than indexing past the end.
        return fallback;
    }

    /**
     * Log probability mass of observed counts:
     * {@code log Γ(n+1) + Σ x_i·log p_i − Σ log Γ(x_i+1)}, summing over the last (category) dimension.
     *
     * @param x observed counts; the last dimension holds the k categories
     * @return log probability per batch element
     */
    @Override
    public DoubleTensor logProb(IntegerTensor x) {
        if (validationEnabled) {
            validateBroadcastShapes(x.getShape(), n.getShape(), p.getShape());
            validateX(x, n, p);
        }
        DoubleTensor gammaN = (DoubleTensor) n.plus(1).toDouble().logGammaInPlace();
        DoubleTensor xLogP = (DoubleTensor) ((DoubleTensor) p.log()).timesInPlace(x.toDouble()).sum(-1);
        DoubleTensor gammaXs = (DoubleTensor) ((DoubleTensor) x.plus(1).toDouble().logGammaInPlace()).sum(-1);
        return xLogP.plusInPlace(gammaN).minusInPlace(gammaXs);
    }

    /**
     * Checks that p has rank >= 1, every entry lies strictly in (0, 1), and each vector along the
     * last dimension sums to 1 within {@code allowedProbabilityError}.
     */
    private static void validateProbabilities(DoubleTensor p) {
        if (p.isScalar()) {
            throw new IllegalArgumentException("Probabilities must be a vector or a tensor with rank >= 1");
        }
        // NOTE(review): zero probabilities are rejected here although draw() explicitly handles them;
        // confirm whether the lower bound should be inclusive.
        boolean pRangeValidated = p.greaterThan(0.0).allTrue() && p.lessThan(1.0).allTrue();
        if (!pRangeValidated) {
            throw new IllegalArgumentException("Probabilities must be > 0 < 1 but were " + Arrays.toString(p.asFlatDoubleArray()));
        }
        DoubleTensor pSum = (DoubleTensor) p.sum(-1);
        boolean pSumValidated = pSum.equalsWithinEpsilon(DoubleTensor.create(1.0, pSum.getShape()), allowedProbabilityError);
        if (!pSumValidated) {
            throw new IllegalArgumentException("Probabilities must sum to 1 but summed to " + Arrays.toString(pSum.asFlatDoubleArray()));
        }
    }

    /** Checks that every trial count in n is non-negative. */
    private static void validateN(IntegerTensor n) {
        if (!n.greaterThanOrEqual(0).allTrue()) {
            throw new IllegalArgumentException("Number of trials (n) must be non-negative.");
        }
    }

    /**
     * Checks that observed counts are within [0, n], have k categories in their last dimension,
     * and sum to n along that dimension.
     */
    private static void validateX(IntegerTensor x, IntegerTensor n, DoubleTensor p) {
        // Append a unit dimension to n so it broadcasts across the category dimension of x.
        long[] nWithCategoryDim = Longs.concat(n.getShape(), new long[]{1L});
        boolean xRangeValidated = x.greaterThanOrEqual(0).allTrue()
            && x.lessThanOrEqual((NumberTensor) n.reshape(nWithCategoryDim)).allTrue();
        if (!xRangeValidated) {
            throw new IllegalArgumentException("x must be >= 0 and <= n");
        }
        validateXShape(x.getShape(), p.getShape());
        IntegerTensor xSum = (IntegerTensor) x.sum(-1);
        if (!xSum.elementwiseEquals(n).allTrue()) {
            throw new IllegalArgumentException("The sum of x " + Arrays.toString(xSum.asFlatArray()) + " must equal n " + Arrays.toString(n.asFlatDoubleArray()));
        }
    }

    /** Checks that the last dimension of xShape equals k (the last dimension of pShape). */
    private static void validateXShape(long[] xShape, long[] pShape) {
        long kAccordingToP = pShape[pShape.length - 1];
        long kAccordingToX = xShape.length == 0 ? 0L : xShape[xShape.length - 1];
        Preconditions.checkArgument(kAccordingToX == kAccordingToP,
            "x shape must have far right dimension matching number of categories k %s but had %s categories.",
            kAccordingToP, kAccordingToX);
    }

    /**
     * Checks that p and n (with an appended unit dimension) broadcast together, and that the
     * result broadcasts with xShape.
     */
    private static void validateBroadcastShapes(long[] xShape, long[] nShape, long[] pShape) {
        long[] broadcastResultShape;
        try {
            broadcastResultShape = TensorShape.getBroadcastResultShape(pShape, TensorShape.concat(nShape, new long[]{1L}));
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("p shape " + Arrays.toString(pShape) + " incompatible with n shape " + Arrays.toString(nShape), e);
        }
        if (!TensorShapeValidation.isBroadcastable(broadcastResultShape, xShape)) {
            throw new IllegalArgumentException("Shape " + Arrays.toString(xShape) + " is incompatible with n shape " + Arrays.toString(nShape) + " and p shape " + Arrays.toString(pShape) + ". It must be broadcastable with " + Arrays.toString(broadcastResultShape));
        }
    }
}

