package io.improbable.keanu.distributions.discrete;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.DiscreteDistribution;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.intgr.IntegerPlaceholderVertex;

/**
 * Discrete uniform distribution over the integers in the half-open range {@code [xMin, xMax)}.
 *
 * <p>{@link #logProb} gives every value {@code x} with {@code xMin <= x < xMax} the same log
 * probability, {@code -log(xMax - xMin)}, and every other value {@code -infinity}.
 * {@link #sample} draws from that same support.
 */
public class UniformInt implements DiscreteDistribution {
    /** Inclusive lower bound of the support. */
    private final IntegerTensor xMin;
    /** Exclusive upper bound of the support. */
    private final IntegerTensor xMax;

    /**
     * Creates a uniform integer distribution over {@code [xMin, xMax)}.
     *
     * @param xMin inclusive lower bound
     * @param xMax exclusive upper bound
     * @return the distribution
     */
    public static DiscreteDistribution withParameters(IntegerTensor xMin, IntegerTensor xMax) {
        return new UniformInt(xMin, xMax);
    }

    private UniformInt(IntegerTensor xMin, IntegerTensor xMax) {
        this.xMin = xMin;
        this.xMax = xMax;
    }

    /**
     * Draws integer samples from {@code [xMin, xMax)}.
     *
     * <p>A uniform double {@code u} is mapped to {@code xMin + u * (xMax - xMin)}. The result
     * is then floored, so every integer in the support is covered by an interval of equal width.
     *
     * @param shape  shape of the sample tensor to produce
     * @param random source of uniform doubles
     * @return integer samples in {@code [xMin, xMax)}
     */
    @Override
    public IntegerTensor sample(long[] shape, KeanuRandom random) {
        DoubleTensor min = this.xMin.toDouble();
        DoubleTensor range = (DoubleTensor) this.xMax.toDouble().minus(min);
        DoubleTensor continuous =
            (DoubleTensor) ((DoubleTensor) range.timesInPlace(random.nextDouble(shape))).plusInPlace(min);
        // Floor before converting. Converting directly (e.g. a truncating, (int)-style conversion)
        // would round negative values toward zero. That biases negative ranges and can even yield
        // xMax, which logProb treats as impossible. A floored value is an exact integer, so the
        // conversion below no longer depends on toInteger()'s rounding mode.
        return ((DoubleTensor) continuous.floorInPlace()).toInteger();
    }

    /**
     * Computes the log probability of each element of {@code x}.
     *
     * @param x values to evaluate
     * @return {@code -log(xMax - xMin)} where {@code xMin <= x < xMax}, else {@code -infinity}
     */
    @Override
    public DoubleTensor logProb(IntegerTensor x) {
        DoubleTensor max = this.xMax.toDouble();
        DoubleTensor min = this.xMin.toDouble();
        DoubleTensor xDouble = x.toDouble();
        // Every in-range value is equally likely: log(1 / (xMax - xMin)).
        DoubleTensor logProbability =
            (DoubleTensor) ((DoubleTensor) max.minus(min)).logInPlace().unaryMinusInPlace();
        // Values at or above the exclusive upper bound are impossible.
        logProbability = (DoubleTensor) logProbability.setWithMaskInPlace(
            xDouble.greaterThanOrEqualToMask(max), Double.valueOf(Double.NEGATIVE_INFINITY));
        // Values below the inclusive lower bound are impossible.
        logProbability = (DoubleTensor) logProbability.setWithMaskInPlace(
            xDouble.lessThanMask(min), Double.valueOf(Double.NEGATIVE_INFINITY));
        return logProbability;
    }

    /**
     * Builds a graph of vertices computing the same log probability as {@link #logProb}, in
     * terms of placeholder inputs.
     *
     * @param x   placeholder for the evaluated value
     * @param min placeholder for the inclusive lower bound
     * @param max placeholder for the exclusive upper bound
     * @return a vertex yielding {@code -log(max - min)} inside the support and
     *         {@code -infinity} outside it
     */
    public static DoubleVertex logProbOutput(IntegerPlaceholderVertex x, IntegerPlaceholderVertex min, IntegerPlaceholderVertex max) {
        DoubleVertex maxBound = max.toDouble();
        DoubleVertex minBound = min.toDouble();
        DoubleVertex xDouble = x.toDouble();
        // NOTE(review): logProb() uses logInPlace(). Confirm that log2() here is the same logarithm
        // (it may be a decompiler rename of log()). A base-2 log would make the two methods disagree.
        DoubleVertex logOfWithinBounds = maxBound.minus(minBound).log2().unaryMinus();
        return logOfWithinBounds
            .setWithMask(xDouble.toGreaterThanOrEqualToMask(maxBound), Double.NEGATIVE_INFINITY)
            .setWithMask(xDouble.toLessThanMask(minBound), Double.NEGATIVE_INFINITY);
    }
}
