package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;

/* loaded from: input_file:io/improbable/keanu/distributions/continuous/Laplace.class */
/**
 * Laplace (double-exponential) distribution with location {@code mu} and scale {@code beta}.
 *
 * <p>Log density, evaluated element-wise: {@code log p(x) = -(|mu - x| / beta + log(2 * beta))}.
 */
public class Laplace implements ContinuousDistribution {
    /** Location parameter(s). */
    private final DoubleTensor mu;
    /** Scale parameter(s); {@link #sample(long[], KeanuRandom)} rejects values {@code <= 0}. */
    private final DoubleTensor beta;

    /**
     * Creates a Laplace distribution over the given parameters.
     *
     * @param mu   location parameter(s)
     * @param beta scale parameter(s); must be strictly positive for sampling to succeed
     * @return a new Laplace distribution
     */
    public static ContinuousDistribution withParameters(DoubleTensor mu, DoubleTensor beta) {
        return new Laplace(mu, beta);
    }

    private Laplace(DoubleTensor mu, DoubleTensor beta) {
        this.mu = mu;
        this.beta = beta;
    }

    /**
     * Draws one independent Laplace variate per element of {@code shape}.
     *
     * <p>Element {@code i} uses {@code mu[i]} / {@code beta[i]} via {@code getOrScalar}, so a
     * scalar parameter is shared by every element.
     *
     * @param shape  shape of the returned tensor
     * @param random source of randomness
     * @return a tensor of samples with the requested shape
     * @throws IllegalArgumentException if a {@code beta} value used is not strictly positive
     */
    @Override // io.improbable.keanu.distributions.Distribution
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        // The element type of a DoubleTensor view is Double; the decompiled code used an
        // undeclared type variable here, which does not compile.
        Tensor.FlattenedView<Double> muView = this.mu.getFlattenedView();
        Tensor.FlattenedView<Double> betaView = this.beta.getFlattenedView();
        int length = TensorShape.getLengthAsInt(shape);
        double[] samples = new double[length];
        for (int i = 0; i < length; i++) {
            samples[i] = sample(muView.getOrScalar(i), betaView.getOrScalar(i), random);
        }
        return DoubleTensor.create(samples, shape);
    }

    /**
     * Draws a single Laplace variate: a fair coin picks the side of {@code mu}, and
     * {@code beta * log(U)} supplies the (non-positive) offset magnitude.
     *
     * @throws IllegalArgumentException if {@code beta <= 0}
     */
    private static double sample(double mu, double beta, KeanuRandom random) {
        if (beta <= 0.0d) {
            throw new IllegalArgumentException("Invalid value for beta: " + beta);
        }
        // NOTE(review): if nextDouble() can return exactly 0.0, Math.log yields -Infinity and the
        // sample is infinite; using (1 - nextDouble()) would avoid that but changes seeded output.
        if (random.nextDouble() > 0.5d) {
            return mu + (beta * Math.log(random.nextDouble()));
        }
        return mu - (beta * Math.log(random.nextDouble()));
    }

    /**
     * Element-wise log density {@code -(|mu - x| / beta + log(2 * beta))}.
     *
     * @param x points at which to evaluate the density
     * @return the log density at each point
     */
    @Override // io.improbable.keanu.distributions.Distribution
    public DoubleTensor logProb(DoubleTensor x) {
        // InPlace calls are applied only to results of non-InPlace operations, never directly
        // to the mu/beta fields.
        DoubleTensor absDiff = (DoubleTensor) ((DoubleTensor) this.mu.minus(x)).abs();
        DoubleTensor scaledAbsDiff = (DoubleTensor) absDiff.divInPlace(this.beta);
        // NOTE(review): times2 is presumed to be the decompiler-renamed times(double).
        DoubleTensor logTwoBeta = (DoubleTensor) this.beta.times2(2.0d).logInPlace();
        DoubleTensor negLogProb = (DoubleTensor) scaledAbsDiff.plusInPlace(logTwoBeta);
        return (DoubleTensor) negLogProb.unaryMinus();
    }

    /**
     * Builds the graph equivalent of {@link #logProb(DoubleTensor)} over placeholder vertices.
     *
     * @param x    placeholder for the evaluation point
     * @param mu   placeholder for the location
     * @param beta placeholder for the scale
     * @return a vertex computing {@code -(|mu - x| / beta + log(2 * beta))}
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex mu,
                                             DoublePlaceholderVertex beta) {
        DoubleVertex scaledAbsDiff = (DoubleVertex) mu.minus((DoubleVertex) x).abs().div((DoubleVertex) beta);
        // NOTE(review): log2() is presumed to be the decompiler-renamed natural log(), matching
        // logInPlace() in logProb; confirm it is not a base-2 logarithm.
        DoubleVertex logTwoBeta = (DoubleVertex) beta.times2(2.0d).log2();
        return scaledAbsDiff.plus(logTwoBeta).unaryMinus();
    }

    /**
     * Partial derivatives of the log density.
     * <ul>
     *   <li>d/dx    = (mu - x) / (|mu - x| * beta)</li>
     *   <li>d/dmu   = (x - mu) / (|mu - x| * beta)</li>
     *   <li>d/dbeta = (|mu - x| - beta) / beta^2</li>
     * </ul>
     *
     * <p>NOTE(review): where {@code x == mu} the shared denominator is zero, so d/dx and d/dmu
     * are 0/0 (NaN under IEEE element-wise division); a sign() with sign(0) = 0 would avoid it.
     *
     * @param x points at which to evaluate the gradient
     * @return gradients keyed by {@code Diffs.MU}, {@code Diffs.BETA} and {@code Diffs.X}
     */
    @Override // io.improbable.keanu.distributions.ContinuousDistribution
    public Diffs dLogProb(DoubleTensor x) {
        DoubleTensor muMinusX = (DoubleTensor) this.mu.minus(x);
        DoubleTensor absMuMinusX = (DoubleTensor) muMinusX.abs();
        DoubleTensor denominator = (DoubleTensor) absMuMinusX.times(this.beta);
        // muMinusX and absMuMinusX are mutated in place below; neither is read afterwards.
        DoubleTensor dLogPdx = (DoubleTensor) muMinusX.divInPlace(denominator);
        DoubleTensor dLogPdmu = (DoubleTensor) ((DoubleTensor) x.minus(this.mu)).divInPlace(denominator);
        DoubleTensor dLogPdbeta = (DoubleTensor) ((DoubleTensor) absMuMinusX.minusInPlace(this.beta))
            .divInPlace(this.beta.pow2(2.0d));
        return new Diffs()
            .put(Diffs.MU, dLogPdmu)
            .put(Diffs.BETA, dLogPdbeta)
            .put(Diffs.X, dLogPdx);
    }
}
