/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.continuous.Gaussian;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DivisionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.PowerVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogVertex;

public class LogNormal
implements ContinuousDistribution {
    private final DoubleTensor mu;
    private final DoubleTensor sigma;

    public static ContinuousDistribution withParameters(DoubleTensor mu, DoubleTensor sigma) {
        return new LogNormal(mu, sigma);
    }

    private LogNormal(DoubleTensor mu, DoubleTensor sigma) {
        this.mu = mu;
        this.sigma = sigma;
    }

    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        return (DoubleTensor)((DoubleTensor)Gaussian.withParameters(this.mu, this.sigma).sample(shape, random)).expInPlace();
    }

    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        DoubleTensor lnSigmaX = (DoubleTensor)this.sigma.times(x).logInPlace();
        DoubleTensor lnXMinusMuSquared = (DoubleTensor)((DoubleTensor)x.log()).minusInPlace(this.mu).powInPlace(2.0);
        DoubleTensor lnXMinusMuSquaredOver2Variance = (DoubleTensor)((Object)lnXMinusMuSquared.divInPlace(this.sigma.pow(2.0).timesInPlace(2.0)));
        return (DoubleTensor)((DoubleTensor)((Object)lnXMinusMuSquaredOver2Variance.plusInPlace(lnSigmaX).plusInPlace(Gaussian.LN_SQRT_2PI))).unaryMinusInPlace();
    }

    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex mu, DoublePlaceholderVertex sigma) {
        LogVertex lnSigmaX = sigma.times(x).log();
        PowerVertex lnXMinusMuSquared = x.log().minus(mu).pow(2.0);
        DivisionVertex lnXMinusMuSquaredOver2Variance = lnXMinusMuSquared.div(sigma.pow(2.0).times(2.0));
        return lnXMinusMuSquaredOver2Variance.plus(lnSigmaX).plus(Gaussian.LN_SQRT_2PI).unaryMinus();
    }

    @Override
    public Diffs dLogProb(DoubleTensor x) {
        DoubleTensor variance = this.sigma.pow(2.0);
        DoubleTensor lnXMinusMu = ((DoubleTensor)x.log()).minusInPlace(this.mu);
        DoubleTensor dLogPdmu = lnXMinusMu.div(variance);
        DoubleTensor dLogPdx = ((DoubleTensor)dLogPdmu.plus(1.0).unaryMinus()).divInPlace(x);
        DoubleTensor dLogPdsigma = (DoubleTensor)((DoubleTensor)lnXMinusMu.powInPlace(2.0)).divInPlace(variance.timesInPlace(this.sigma)).minusInPlace(this.sigma.reciprocal());
        return new Diffs().put(Diffs.MU, dLogPdmu).put(Diffs.SIGMA, dLogPdsigma).put(Diffs.X, dLogPdx);
    }
}

