/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DivisionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogVertex;

/**
 * Logistic distribution with location {@code mu} and scale {@code s}.
 *
 * <p>With {@code z = (x - mu) / s}, the log density evaluated by {@link #logProb} is
 * {@code z - ln(s) - 2 * ln(1 + e^z)}, i.e. {@code p(x) = e^z / (s * (1 + e^z)^2)}.
 */
public final class Logistic implements ContinuousDistribution {
    /** Location parameter. */
    private final DoubleTensor mu;
    /** Scale parameter; presumably expected to be strictly positive (not validated here). */
    private final DoubleTensor s;

    /**
     * Creates a logistic distribution.
     *
     * @param mu location parameter
     * @param s  scale parameter (NOTE(review): positivity is not checked here — confirm callers enforce it)
     * @return a distribution parameterised by {@code mu} and {@code s}
     */
    public static ContinuousDistribution withParameters(DoubleTensor mu, DoubleTensor s) {
        return new Logistic(mu, s);
    }

    private Logistic(DoubleTensor mu, DoubleTensor s) {
        this.mu = mu;
        this.s = s;
    }

    /**
     * Draws samples by inverse-transform sampling: for {@code u ~ Uniform}, returns
     * {@code mu - s * ln(1/u - 1)}, which equals {@code mu + s * ln(u / (1 - u))}.
     *
     * @param shape  shape of the sample tensor to draw
     * @param random source of uniform random numbers
     * @return a tensor of logistic variates
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        // Scale by s and shift by mu as separate steps; multiplying by (mu - s) is only
        // correct for the standard logistic (mu = 0, s = 1).
        return random.nextDouble(shape)
            .reciprocalInPlace()
            .minusInPlace(1.0)
            .logInPlace()
            .timesInPlace(this.s)
            .unaryMinusInPlace()
            .plusInPlace(this.mu);
    }

    /**
     * Evaluates the log density {@code z + ln(1/s) - 2 * ln(1 + e^z)} with {@code z = (x - mu) / s}.
     *
     * <p>NOTE(review): {@code e^z} overflows for very large positive {@code z} (roughly above 709),
     * which yields {@code -Infinity} rather than the finite asymptote {@code -z - ln(s)}.
     *
     * @param x points at which to evaluate the log density
     * @return the log density at {@code x}
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        DoubleTensor z = x.minus(this.mu).divInPlace(this.s);
        DoubleTensor logOneOverS = this.s.reciprocal().logInPlace();
        // Form z + ln(1/s) into a fresh tensor BEFORE z is overwritten by the in-place exp below.
        DoubleTensor result = z.plus(logOneOverS);
        DoubleTensor twoLogOnePlusExpZ = z.expInPlace().plusInPlace(1.0).logInPlace().timesInPlace(2.0);
        return result.minusInPlace(twoLogOnePlusExpZ);
    }

    /**
     * Builds a vertex graph computing the same log density as {@link #logProb}, in terms of
     * placeholder vertices for {@code x}, {@code mu} and {@code s}.
     *
     * @param x  placeholder for the evaluation point
     * @param mu placeholder for the location parameter
     * @param s  placeholder for the scale parameter
     * @return a vertex whose value is {@code z + ln(1/s) - 2 * ln(1 + e^z)}
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex mu, DoublePlaceholderVertex s) {
        DivisionVertex xMinusMuOverS = x.minus(mu).div(s);
        LogVertex logOneOverS = s.reverseDiv(1.0).log();
        return xMinusMuOverS.plus(logOneOverS).minus(xMinusMuOverS.exp().plus(1.0).log().times(2.0));
    }

    /**
     * Computes the partial derivatives of the log density with respect to {@code mu}, {@code s}
     * and {@code x}.
     *
     * <p>With {@code z = (x - mu) / s} and {@code r = (1 - e^z) / (1 + e^z)}:
     * <ul>
     *   <li>{@code dlogp/dx  =  r / s}</li>
     *   <li>{@code dlogp/dmu = -r / s}</li>
     *   <li>{@code dlogp/ds  = -(z * r + 1) / s}</li>
     * </ul>
     * Working in {@code z} avoids forming {@code e^(mu/s)} and {@code e^(x/s)} separately.
     * Those overflow (giving {@code inf - inf = NaN}) whenever {@code |mu/s|} or {@code |x/s|}
     * is large, even when {@code z} itself is small.
     *
     * @param x points at which to evaluate the gradient
     * @return the gradients keyed by {@link Diffs#MU}, {@link Diffs#S} and {@link Diffs#X}
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        DoubleTensor z = x.minus(this.mu).divInPlace(this.s);
        DoubleTensor expZ = x.minus(this.mu).divInPlace(this.s).expInPlace();

        // Compute (1 - e^z) into a new tensor before expZ is mutated into (1 + e^z).
        DoubleTensor oneMinusExpZ = expZ.unaryMinus().plusInPlace(1.0);
        DoubleTensor onePlusExpZ = expZ.plusInPlace(1.0);
        DoubleTensor r = oneMinusExpZ.divInPlace(onePlusExpZ);

        DoubleTensor dLogPdx = r.div(this.s);
        DoubleTensor dLogPdmu = dLogPdx.unaryMinus();
        DoubleTensor dLogPds = z.timesInPlace(r).plusInPlace(1.0).divInPlace(this.s).unaryMinusInPlace();

        return new Diffs().put(Diffs.MU, dLogPdmu).put(Diffs.S, dLogPds).put(Diffs.X, dLogPdx);
    }
}

