/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DifferenceVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DivisionVertex;

/**
 * Exponential distribution parameterised by a scale tensor {@code lambda}.
 *
 * <p>For {@code x >= 0} the log density is {@code -x / lambda - log(lambda)}, i.e.
 * {@code p(x) = exp(-x / lambda) / lambda}. For {@code x < 0} the log density is negative infinity.
 * {@code lambda} therefore multiplies samples: it acts as a scale, not as a rate.
 */
public class Exponential implements ContinuousDistribution {

    /** Scale parameter, held by reference exactly as passed to {@link #withParameters}. */
    private final DoubleTensor lambda;

    /**
     * Creates an exponential distribution with the given scale.
     *
     * @param lambda scale tensor; stored without copying
     * @return a new exponential distribution
     */
    public static ContinuousDistribution withParameters(DoubleTensor lambda) {
        return new Exponential(lambda);
    }

    private Exponential(DoubleTensor lambda) {
        this.lambda = lambda;
    }

    /**
     * Draws samples by inverse-transform sampling as {@code -lambda * log(u)}.
     *
     * <p>{@code u} comes from {@code random.nextDouble(shape)}. It is presumably uniform on
     * [0, 1); confirm against {@code KeanuRandom}.
     *
     * @param shape  shape passed to the uniform draw
     * @param random source of randomness
     * @return the sampled tensor
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        // The uniform draw is presumably a fresh tensor, so the remaining steps mutate it in place.
        DoubleTensor logUniform = (DoubleTensor) random.nextDouble(shape).logInPlace();
        return (DoubleTensor) logUniform.timesInPlace(this.lambda).unaryMinusInPlace();
    }

    /**
     * Evaluates the element-wise log density {@code -x / lambda - log(lambda)}.
     *
     * @param x points at which to evaluate the log density
     * @return the log density, set to negative infinity wherever {@code x < 0}
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        // unaryMinus() is the non-in-place variant, so x itself is presumably left untouched;
        // the in-place operations that follow only mutate this negated copy.
        DoubleTensor negatedX = (DoubleTensor) x.unaryMinus();
        DoubleTensor unmaskedLogPdf =
            (DoubleTensor) negatedX.divInPlace(this.lambda).minusInPlace(this.lambda.log());
        // Outside the support (x < 0) the density is zero, so its log is -infinity.
        return unmaskedLogPdf.setWithMask(x.lessThanMask(DoubleTensor.scalar(0.0)), Double.NEGATIVE_INFINITY);
    }

    /**
     * Builds a vertex graph computing the same log density as {@link #logProb(DoubleTensor)}.
     *
     * <p>The graph computes {@code -x / lambda - log(lambda)}, masked to negative infinity where {@code x < 0}.
     *
     * @param x      placeholder for the evaluation points
     * @param lambda placeholder for the scale parameter
     * @return the output vertex of the log-density graph
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex lambda) {
        DivisionVertex scaledNegatedX = x.unaryMinus().div(lambda);
        DifferenceVertex unmaskedLogPdf = scaledNegatedX.minus(lambda.log());
        return unmaskedLogPdf.setWithMask(x.toLessThanMask(0.0), Double.NEGATIVE_INFINITY);
    }

    /**
     * Computes partial derivatives of the log density.
     *
     * <ul>
     *   <li>With respect to {@code x}: {@code -1 / lambda}. It is computed from a zero tensor of
     *       {@code x}'s shape, so the result is broadcast against that shape.</li>
     *   <li>With respect to {@code lambda}: {@code (x - lambda) / lambda^2}.</li>
     * </ul>
     *
     * <p>No masking is applied here: points with {@code x < 0} receive the same formulas.
     *
     * @param x points at which to evaluate the gradients
     * @return diffs keyed by {@code Diffs.LAMBDA} and {@code Diffs.X}
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        // The computation is 0 - lambda, followed by 1 / (-lambda).
        DoubleTensor gradWrtX =
            (DoubleTensor) DoubleTensor.zeros(x.getShape()).minusInPlace(this.lambda).reciprocalInPlace();
        // x.minus(...) yields a new tensor, so the in-place division does not touch x.
        DoubleTensor gradWrtLambda = x.minus(this.lambda).divInPlace(this.lambda.pow(2.0));
        return new Diffs()
            .put(Diffs.LAMBDA, gradWrtLambda)
            .put(Diffs.X, gradWrtX);
    }
}

