package io.improbable.keanu.distributions.discrete;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.DiscreteDistribution;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import java.util.Arrays;

/* loaded from: input_file:io/improbable/keanu/distributions/discrete/Multinomial.class */
/**
 * Multinomial distribution: the counts observed in each of {@code k} categories after {@code n}
 * independent categorical trials with category probabilities {@code p}.
 *
 * <p>{@code k} is the size of the last dimension of {@code p}. {@code n} is paired with the
 * leading dimensions of {@code p} by broadcasting {@code n}'s shape extended with a trailing 1
 * against {@code p}'s shape.
 */
public class Multinomial implements DiscreteDistribution {

    /** Default absolute tolerance used when checking that each probability vector sums to 1. */
    public static final double DEFAULT_ALLOWED_PROBABILITY_ERROR = 0.001d;

    // Shared, mutable tolerance read by validateProbabilities (i.e. in the constructor when
    // validation is enabled). Changing it affects all later validated constructions.
    // It is a plain static field: no synchronization or volatile.
    private static double allowedProbabilityError = DEFAULT_ALLOWED_PROBABILITY_ERROR;

    private final IntegerTensor n;
    private final DoubleTensor p;
    private final int k;
    private final boolean validationEnabled;

    /**
     * Sets the tolerance used when checking that each probability vector sums to 1.
     *
     * @param allowedError absolute tolerance passed to {@code equalsWithinEpsilon}
     */
    public static void setAllowedProbabilityError(double allowedError) {
        allowedProbabilityError = allowedError;
    }

    /**
     * Creates a multinomial distribution without parameter validation.
     *
     * @param n number of trials
     * @param p category probabilities; the last dimension indexes categories
     * @return the distribution
     */
    public static Multinomial withParameters(IntegerTensor n, DoubleTensor p) {
        return new Multinomial(n, p, false);
    }

    /**
     * Creates a multinomial distribution.
     *
     * @param n number of trials
     * @param p category probabilities; the last dimension indexes categories
     * @param validationEnabled when true, validates the parameters now and validates shapes and
     *     values on every {@link #sample} / {@link #logProb} call
     * @return the distribution
     * @throws IllegalArgumentException if validation is enabled and the parameters are invalid
     */
    public static Multinomial withParameters(IntegerTensor n, DoubleTensor p, boolean validationEnabled) {
        return new Multinomial(n, p, validationEnabled);
    }

    private Multinomial(IntegerTensor n, DoubleTensor p, boolean validationEnabled) {
        if (validationEnabled) {
            // Validate before indexing p's last dimension. For a rank-0 p, getRank() - 1 is -1,
            // and the index below would fail before the intended error could be raised.
            validateProbabilities(p);
            validateN(n);
        }
        this.n = n;
        this.p = p;
        this.k = Ints.checkedCast(p.getShape()[p.getRank() - 1]);
        this.validationEnabled = validationEnabled;
    }

    /**
     * Draws category counts.
     *
     * @param shape result shape; its last dimension must equal {@code k}
     * @param random source of randomness
     * @return counts of shape {@code shape}
     */
    @Override
    public IntegerTensor sample(long[] shape, KeanuRandom random) {
        if (validationEnabled) {
            validateBroadcastShapes(shape, n.getShape(), p.getShape());
            validateXShape(shape, p.getShape());
        }
        // Drop the trailing category dimension to get the batch shape that n is broadcast to
        // (assumes selectDimensions' end index is exclusive — TODO confirm).
        final long[] batchShape = TensorShape.selectDimensions(0, shape.length - 1, shape);
        final IntegerTensor trialsTensor = (IntegerTensor) n.broadcast(batchShape);

        final long[] probabilityShape = TensorShape.getBroadcastResultShape(
            TensorShape.concat(trialsTensor.getShape(), new long[]{1}), p.getShape());
        final double[] probabilities = ((DoubleTensor) p.broadcast(probabilityShape)).asFlatDoubleArray();

        final int[] trials = trialsTensor.asFlatIntegerArray();
        final int[] counts = new int[k * trials.length];
        for (int batch = 0; batch < trials.length; batch++) {
            // Counts and probabilities share the same flat layout: k consecutive entries per batch.
            final int offset = batch * k;
            drawNTimes(trials[batch], random, counts, offset, probabilities, offset, k);
        }
        return IntegerTensor.create(counts, shape);
    }

    /**
     * Performs {@code trials} categorical draws and increments the matching entries of
     * {@code counts[countOffset .. countOffset + k)}.
     */
    private static void drawNTimes(int trials, KeanuRandom random, int[] counts, int countOffset,
                                   double[] probabilities, int probabilityOffset, int k) {
        for (int trial = 0; trial < trials; trial++) {
            counts[countOffset + draw(random, probabilities, probabilityOffset, k)]++;
        }
    }

    /**
     * Draws one category index in {@code [0, k)} by inverse-CDF lookup over
     * {@code probabilities[offset .. offset + k)}.
     *
     * <p>Categories with zero probability are skipped. If rounding leaves the cumulative sum
     * below the uniform draw, the last category with non-zero probability is returned. When no
     * category has non-zero probability, {@code k - 1} is returned.
     */
    private static int draw(KeanuRandom random, double[] probabilities, int offset, int k) {
        final double target = random.nextDouble();
        double cumulative = 0.0d;
        int lastNonZero = k - 1;
        for (int category = 0; category < k; category++) {
            final double probability = probabilities[offset + category];
            if (probability == 0.0d) {
                continue;
            }
            lastNonZero = category;
            cumulative += probability;
            if (cumulative >= target) {
                return category;
            }
        }
        return lastNonZero;
    }

    /**
     * Log probability mass of observed counts {@code x}:
     * {@code sum(x_i * ln p_i) + lnGamma(n + 1) - sum(lnGamma(x_i + 1))}, summed over the last
     * (category) dimension.
     *
     * @param x observed counts; the last dimension indexes categories
     * @return log probability with the category dimension reduced
     */
    @Override
    public DoubleTensor logProb(IntegerTensor x) {
        if (validationEnabled) {
            validateBroadcastShapes(x.getShape(), n.getShape(), p.getShape());
            validateX(x, n, p);
        }
        final DoubleTensor logNFactorial = n.plus(1).toDouble().logGammaInPlace();
        final DoubleTensor sumLogXFactorial = (DoubleTensor) x.plus(1).toDouble().logGammaInPlace().sum(-1);
        // Natural log, to match the natural-log gamma terms (log2 here would mix bases).
        final DoubleTensor sumXLogP = (DoubleTensor) ((DoubleTensor) p.log().timesInPlace(x.toDouble())).sum(-1);
        return (DoubleTensor) ((DoubleTensor) sumXLogP.plusInPlace(logNFactorial)).minusInPlace(sumLogXFactorial);
    }

    /**
     * @throws IllegalArgumentException if {@code p} is scalar, has an entry outside (0, 1), or
     *     has a category vector whose sum differs from 1 by more than the allowed error
     */
    private static void validateProbabilities(DoubleTensor p) {
        if (p.isScalar()) {
            throw new IllegalArgumentException("Probabilities must be a vector or a tensor with rank >= 1");
        }
        final boolean inOpenUnitInterval = p.greaterThan(0.0d).allTrue() && p.lessThan(1.0d).allTrue();
        if (!inOpenUnitInterval) {
            throw new IllegalArgumentException("Probabilities must be > 0 < 1 but were " + Arrays.toString(p.asFlatDoubleArray()));
        }
        final DoubleTensor sums = (DoubleTensor) p.sum(-1);
        if (!sums.equalsWithinEpsilon(DoubleTensor.create(1.0d, sums.getShape()), allowedProbabilityError)) {
            throw new IllegalArgumentException("Probabilities must sum to 1 but summed to " + Arrays.toString(sums.asFlatDoubleArray()));
        }
    }

    /** @throws IllegalArgumentException if any entry of {@code n} is negative */
    private static void validateN(IntegerTensor n) {
        if (!n.greaterThanOrEqual(0).allTrue()) {
            throw new IllegalArgumentException("Number of trials (n) must be non-negative.");
        }
    }

    /**
     * @throws IllegalArgumentException if any count is negative or exceeds its {@code n}, if the
     *     category dimension does not match {@code p}, or if counts do not sum to {@code n}
     */
    private static void validateX(IntegerTensor x, IntegerTensor n, DoubleTensor p) {
        // n gets a trailing unit dimension so it broadcasts across x's category axis.
        final boolean inRange = x.greaterThanOrEqual(0).allTrue()
            && x.lessThanOrEqual((IntegerTensor) n.reshape(Longs.concat(n.getShape(), new long[]{1}))).allTrue();
        if (!inRange) {
            throw new IllegalArgumentException("x must be >= 0 and <= n");
        }
        validateXShape(x.getShape(), p.getShape());
        final IntegerTensor xSum = (IntegerTensor) x.sum(-1);
        if (!xSum.elementwiseEquals(n).allTrue()) {
            throw new IllegalArgumentException("The sum of x " + Arrays.toString(xSum.asFlatArray()) + " must equal n " + Arrays.toString(n.asFlatIntegerArray()));
        }
    }

    /** Checks that the last dimension of {@code xShape} equals the category count of {@code pShape}. */
    private static void validateXShape(long[] xShape, long[] pShape) {
        final long categories = pShape[pShape.length - 1];
        final long xCategories = xShape.length == 0 ? 0L : xShape[xShape.length - 1];
        Preconditions.checkArgument(xCategories == categories, "x shape must have far right dimension matching number of categories k " + categories + " but had " + xCategories + " categories.");
    }

    /**
     * Checks that {@code p} and {@code n} (with a trailing unit dimension) broadcast together,
     * and that {@code xShape} is broadcastable with the result.
     *
     * @throws IllegalArgumentException describing whichever of the two checks failed
     */
    private static void validateBroadcastShapes(long[] xShape, long[] nShape, long[] pShape) {
        final long[] parameterShape;
        try {
            parameterShape = TensorShape.getBroadcastResultShape(pShape, TensorShape.concat(nShape, new long[]{1}));
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("p shape " + Arrays.toString(pShape) + " incompatible with n shape " + Arrays.toString(nShape), e);
        }
        // Kept outside the try so this message is not re-wrapped as a p/n incompatibility.
        if (!TensorShapeValidation.isBroadcastable(parameterShape, xShape)) {
            throw new IllegalArgumentException("Shape " + Arrays.toString(xShape) + " is incompatible with n shape " + Arrays.toString(nShape) + " and p shape " + Arrays.toString(pShape) + ". It must be broadcastable with " + Arrays.toString(parameterShape));
        }
    }
}
