/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.MultiplicationVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogGammaVertex;

/**
 * Inverse-gamma distribution parameterised by the tensors {@code alpha} and {@code beta}.
 *
 * <p>The log-density evaluated by this class is
 * {@code alpha * ln(beta) - (alpha + 1) * ln(x) - lnGamma(alpha) - beta / x}.
 * Instances are obtained through {@link #withParameters(DoubleTensor, DoubleTensor)}.
 */
public class InverseGamma
implements ContinuousDistribution {
    private final DoubleTensor alpha;
    private final DoubleTensor beta;

    private InverseGamma(DoubleTensor alpha, DoubleTensor beta) {
        this.alpha = alpha;
        this.beta = beta;
    }

    /**
     * Creates an inverse-gamma distribution backed by the given parameter tensors.
     *
     * @param alpha the {@code alpha} parameter of the distribution
     * @param beta the {@code beta} parameter of the distribution
     * @return a distribution that holds references to the supplied tensors
     */
    public static ContinuousDistribution withParameters(DoubleTensor alpha, DoubleTensor beta) {
        return new InverseGamma(alpha, beta);
    }

    /**
     * Draws a sample by requesting a gamma draw from {@code random} and returning its reciprocal.
     *
     * @param shape the shape forwarded to {@code random.nextGamma}
     * @param random the source of the gamma draw
     * @return the reciprocal of the gamma draw
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        // nextGamma receives 1/beta followed by alpha.
        // NOTE(review): presumably (scale, shape) of the gamma distribution - confirm against KeanuRandom.
        DoubleTensor inverseBeta = (DoubleTensor)this.beta.reciprocal();
        DoubleTensor gammaDraw = random.nextGamma(shape, inverseBeta, this.alpha);
        return (DoubleTensor)gammaDraw.reciprocal();
    }

    /**
     * Evaluates {@code alpha * ln(beta) - (alpha + 1) * ln(x) - lnGamma(alpha) - beta / x}.
     *
     * @param x the point(s) at which the log-density is evaluated
     * @return the log-density at {@code x}
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        DoubleTensor alphaLnBeta = (DoubleTensor)this.alpha.times(this.beta.log());

        // Coefficient of ln(x): (-alpha) - 1.
        DoubleTensor lnXCoefficient = (DoubleTensor)((DoubleTensor)this.alpha.unaryMinus()).minusInPlace(1.0);
        DoubleTensor weightedLnX = (DoubleTensor)((DoubleTensor)x.log()).timesInPlace(lnXCoefficient);

        DoubleTensor lnGammaAlpha = (DoubleTensor)this.alpha.logGamma();

        // Terms are accumulated left to right: ((a*ln(b) + c*ln(x)) - lnGamma(a)) - b/x.
        DoubleTensor total = (DoubleTensor)alphaLnBeta.plus(weightedLnX);
        total = (DoubleTensor)total.minusInPlace(lnGammaAlpha);
        return (DoubleTensor)total.minusInPlace(this.beta.div(x));
    }

    /**
     * Builds the vertex expression for the same log-density that {@link #logProb(DoubleTensor)}
     * evaluates on tensors.
     *
     * @param x placeholder for the evaluation point
     * @param alpha placeholder for the {@code alpha} parameter
     * @param beta placeholder for the {@code beta} parameter
     * @return a vertex representing
     *     {@code alpha * ln(beta) - (alpha + 1) * ln(x) - lnGamma(alpha) - beta / x}
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex alpha, DoublePlaceholderVertex beta) {
        // Expression structure is kept as-is so that vertices are constructed in the same order.
        MultiplicationVertex alphaLnBeta = alpha.times(beta.log());
        MultiplicationVertex weightedLnX = x.log().times(alpha.unaryMinus().minus(1.0));
        LogGammaVertex lnGammaAlpha = alpha.logGamma();
        return alphaLnBeta.plus(weightedLnX).minus(lnGammaAlpha).minus(beta.div(x));
    }

    /**
     * Computes the partial derivatives of the log-density with respect to {@code alpha},
     * {@code beta} and {@code x}.
     *
     * @param x the point(s) at which the derivatives are evaluated
     * @return a {@link Diffs} holding the derivatives under {@code Diffs.A} (alpha),
     *     {@code Diffs.B} (beta) and {@code Diffs.X} (x)
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        // d/dalpha = -ln(x) - digamma(alpha) + ln(beta)
        DoubleTensor dLogPdAlpha = (DoubleTensor)((DoubleTensor)x.log()).unaryMinusInPlace();
        dLogPdAlpha = (DoubleTensor)dLogPdAlpha.minusInPlace(this.alpha.digamma());
        dLogPdAlpha = (DoubleTensor)dLogPdAlpha.plusInPlace(this.beta.log());

        // d/dbeta = -1/x + alpha/beta
        DoubleTensor dLogPdBeta = (DoubleTensor)((DoubleTensor)x.reciprocal()).unaryMinusInPlace();
        dLogPdBeta = (DoubleTensor)dLogPdBeta.plusInPlace(this.alpha.div(this.beta));

        // d/dx = (1/x^2) * (x * -(alpha + 1) + beta)
        DoubleTensor inverseXSquared = (DoubleTensor)x.pow(2.0).reciprocalInPlace();
        DoubleTensor numerator = (DoubleTensor)x.times(this.alpha.plus(1.0).unaryMinusInPlace());
        numerator = (DoubleTensor)numerator.plusInPlace(this.beta);
        DoubleTensor dLogPdX = (DoubleTensor)inverseXSquared.timesInPlace(numerator);

        return new Diffs().put(Diffs.A, dLogPdAlpha).put(Diffs.B, dLogPdBeta).put(Diffs.X, dLogPdX);
    }
}

