package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;

/* loaded from: input_file:io/improbable/keanu/distributions/continuous/Exponential.class */
/**
 * Exponential distribution whose single parameter {@code lambda} is used as a <em>scale</em>
 * (mean-like) parameter rather than a rate. The implementation uses:
 *
 * <ul>
 *   <li>sampling as {@code -lambda * log(U)}, with {@code U} drawn from {@link KeanuRandom};
 *   <li>log density {@code -x / lambda - log(lambda)} for {@code x >= 0}, and negative infinity
 *       for {@code x < 0}.
 * </ul>
 *
 * <p>The constructor is private; obtain instances through {@link #withParameters(DoubleTensor)}.
 */
public class Exponential implements ContinuousDistribution {

    /** Scale parameter, applied element-wise in sampling, log density and gradients. */
    private final DoubleTensor lambda;

    /**
     * Builds an exponential distribution with the given scale tensor.
     *
     * @param lambda scale parameter of the distribution
     * @return a distribution backed by {@code lambda}
     */
    public static ContinuousDistribution withParameters(DoubleTensor lambda) {
        return new Exponential(lambda);
    }

    private Exponential(DoubleTensor lambda) {
        this.lambda = lambda;
    }

    /**
     * Draws samples by inverse-transform sampling: {@code -lambda * log(U)}, where {@code U} comes
     * from {@code random.nextDouble(shape)}.
     *
     * <p>NOTE(review): assumes {@code nextDouble} is uniform on [0, 1) — TODO confirm. A draw of
     * exactly 0 would yield +infinity.
     *
     * @param shape  shape requested from the random source
     * @param random source of uniform draws
     * @return the sampled tensor
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        // For U in (0, 1) and positive lambda, lambda * log(U) is negative; negating it
        // produces the non-negative exponential sample.
        DoubleTensor scaledLogUniform =
            (DoubleTensor) random.nextDouble(shape).logInPlace().timesInPlace(this.lambda);
        return (DoubleTensor) scaledLogUniform.unaryMinusInPlace();
    }

    /**
     * Computes the element-wise log density {@code -x / lambda - log(lambda)}. Elements with
     * {@code x < 0}, which lie outside the support, are set to negative infinity.
     *
     * <p>NOTE(review): {@code log2()} appears to be a decompiler-renamed natural logarithm. The
     * lambda-gradient in {@link #dLogProb(DoubleTensor)}, {@code (x - lambda) / lambda^2}, is only
     * consistent with ln. Confirm against the DoubleTensor API.
     *
     * @param x values at which to evaluate the log density
     * @return the log density for each element
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        DoubleTensor negatedX = (DoubleTensor) x.unaryMinus();
        DoubleTensor negatedXOverLambda = (DoubleTensor) negatedX.divInPlace(this.lambda);
        DoubleTensor unmaskedLogDensity =
            (DoubleTensor) negatedXOverLambda.minusInPlace(this.lambda.log2());
        // Mask out everything below zero: log(0) = -infinity outside the support.
        return (DoubleTensor) unmaskedLogDensity.setWithMask(
            x.lessThanMask(DoubleTensor.scalar(0.0d)),
            Double.valueOf(Double.NEGATIVE_INFINITY));
    }

    /**
     * Builds the vertex graph for the same log density as {@link #logProb(DoubleTensor)}, using
     * placeholder vertices for the observed value and the scale.
     *
     * @param x      placeholder for the values being scored
     * @param lambda placeholder for the scale parameter
     * @return a vertex computing {@code -x / lambda - log(lambda)}, with -infinity where
     *     {@code x < 0}
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex lambda) {
        return x.unaryMinus()
            .div((DoubleVertex) lambda)                // -x / lambda
            .minus((DoubleVertex) lambda.log2())       // ... - log(lambda)
            .setWithMask(x.toLessThanMask(0.0d), Double.NEGATIVE_INFINITY); // x < 0 -> -infinity
    }

    /**
     * Computes the partial derivatives of the log density with respect to {@code lambda} and
     * {@code x}.
     *
     * <p>Unlike {@link #logProb(DoubleTensor)}, these gradients are not masked for {@code x < 0}.
     *
     * @param x values at which to evaluate the gradients
     * @return a {@link Diffs} containing {@code LAMBDA -> (x - lambda) / lambda^2} and
     *     {@code X -> -1 / lambda}
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        // d/dlambda [-x/lambda - log(lambda)] = x/lambda^2 - 1/lambda = (x - lambda) / lambda^2
        DoubleTensor xMinusLambda = (DoubleTensor) x.minus(this.lambda);
        DoubleTensor dLogPdLambda = (DoubleTensor) xMinusLambda.divInPlace(this.lambda.pow2(2.0d));

        // d/dx [-x/lambda - log(lambda)] = -1/lambda, computed from zeros shaped like x
        DoubleTensor negatedLambda =
            (DoubleTensor) DoubleTensor.zeros(x.getShape()).minusInPlace(this.lambda);
        DoubleTensor dLogPdX = (DoubleTensor) negatedLambda.reciprocalInPlace();

        return new Diffs()
            .put(Diffs.LAMBDA, dLogPdLambda)
            .put(Diffs.X, dLogPdX);
    }
}
