/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DivisionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogVertex;

public class Laplace
implements ContinuousDistribution {
    private final DoubleTensor mu;
    private final DoubleTensor beta;

    public static ContinuousDistribution withParameters(DoubleTensor mu, DoubleTensor beta) {
        return new Laplace(mu, beta);
    }

    private Laplace(DoubleTensor mu, DoubleTensor beta) {
        this.mu = mu;
        this.beta = beta;
    }

    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        Tensor.FlattenedView muWrapped = this.mu.getFlattenedView();
        Tensor.FlattenedView betaWrapped = this.beta.getFlattenedView();
        int length = TensorShape.getLengthAsInt(shape);
        double[] samples = new double[length];
        for (int i = 0; i < length; ++i) {
            samples[i] = Laplace.sample((Double)muWrapped.getOrScalar(i), (Double)betaWrapped.getOrScalar(i), random);
        }
        return DoubleTensor.create(samples, shape);
    }

    private static double sample(double mu, double beta, KeanuRandom random) {
        if (beta <= 0.0) {
            throw new IllegalArgumentException("Invalid value for beta: " + beta);
        }
        if (random.nextDouble() > 0.5) {
            return mu + beta * Math.log(random.nextDouble());
        }
        return mu - beta * Math.log(random.nextDouble());
    }

    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        DoubleTensor muMinusXAbsNegDivBeta = ((DoubleTensor)this.mu.minus(x).abs()).divInPlace(this.beta);
        DoubleTensor logTwoBeta = (DoubleTensor)this.beta.times(2.0).logInPlace();
        return (DoubleTensor)muMinusXAbsNegDivBeta.plusInPlace(logTwoBeta).unaryMinus();
    }

    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex mu, DoublePlaceholderVertex beta) {
        DivisionVertex muMinusXAbsNegDivBeta = mu.minus(x).abs().div(beta);
        LogVertex logTwoBeta = beta.times(2.0).log();
        return muMinusXAbsNegDivBeta.plus(logTwoBeta).unaryMinus();
    }

    @Override
    public Diffs dLogProb(DoubleTensor x) {
        DoubleTensor muMinusX = this.mu.minus(x);
        DoubleTensor muMinusXAbs = (DoubleTensor)muMinusX.abs();
        DoubleTensor denominator = muMinusXAbs.times(this.beta);
        DoubleTensor dLogPdx = muMinusX.divInPlace(denominator);
        DoubleTensor dLogPdMu = x.minus(this.mu).divInPlace(denominator);
        DoubleTensor dLogPdBeta = muMinusXAbs.minusInPlace(this.beta).divInPlace(this.beta.pow(2.0));
        return new Diffs().put(Diffs.MU, dLogPdMu).put(Diffs.BETA, dLogPdBeta).put(Diffs.X, dLogPdx);
    }
}

