package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.MultiplicationVertex;

/* loaded from: input_file:io/improbable/keanu/distributions/continuous/InverseGamma.class */
/**
 * Inverse-gamma distribution with shape {@code alpha} and scale {@code beta}, evaluated
 * element-wise over tensors.
 *
 * <p>{@link #logProb} evaluates
 * {@code alpha * ln(beta) - lnGamma(alpha) - (alpha + 1) * ln(x) - beta / x}.
 *
 * <p>NOTE(review): this source appears to be decompiled. Names such as {@code log2()},
 * {@code plus2(double)}, {@code minus2(double)} and {@code pow2(double)} look like decompiler
 * collision renames of {@code log()}, {@code plus}, {@code minus} and {@code pow}. The formulas
 * documented here assume {@code log2()} is the natural logarithm; confirm against the tensor API.
 */
public class InverseGamma implements ContinuousDistribution {
    /** Shape parameter. */
    private final DoubleTensor alpha;
    /** Scale parameter. */
    private final DoubleTensor beta;

    /**
     * Creates an inverse-gamma distribution.
     *
     * @param alpha shape parameter tensor (stored by reference, not copied)
     * @param beta  scale parameter tensor (stored by reference, not copied)
     * @return the distribution
     */
    public static ContinuousDistribution withParameters(DoubleTensor alpha, DoubleTensor beta) {
        return new InverseGamma(alpha, beta);
    }

    private InverseGamma(DoubleTensor alpha, DoubleTensor beta) {
        this.alpha = alpha;
        this.beta = beta;
    }

    /**
     * Draws samples by taking reciprocals of gamma samples. If {@code Y ~ Gamma(k = alpha,
     * theta = 1 / beta)}, then {@code 1 / Y ~ InverseGamma(alpha, beta)}.
     *
     * <p>NOTE(review): assumes {@code nextGamma(shape, theta, k)} takes the scale {@code theta}
     * before the shape {@code k}; confirm against {@link KeanuRandom}.
     *
     * @param shape  shape of the tensor to sample
     * @param random source of randomness
     * @return a tensor of inverse-gamma samples
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // io.improbable.keanu.distributions.Distribution
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        return random.nextGamma(shape, this.beta.reciprocal(), this.alpha).reciprocal();
    }

    /**
     * Computes the element-wise log density
     * {@code alpha * ln(beta) - (alpha + 1) * ln(x) - lnGamma(alpha) - beta / x}.
     *
     * @param x points at which to evaluate the log density (not modified)
     * @return a new tensor holding the log density
     */
    @Override // io.improbable.keanu.distributions.Distribution
    public DoubleTensor logProb(DoubleTensor x) {
        final DoubleTensor alphaTimesLnBeta = (DoubleTensor) this.alpha.times(this.beta.log2());
        // -(alpha + 1), built the same way dLogProb builds it. The earlier form subtracted
        // (DoubleTensor) Double.valueOf(1.0d). That is an inconvertible cast from the final
        // class Double to DoubleTensor, and it could never succeed.
        final DoubleTensor negAlphaMinusOne = (DoubleTensor) this.alpha.plus2(1.0d).unaryMinusInPlace();
        // x.log2() yields a fresh tensor, so scaling it in place leaves x untouched.
        final DoubleTensor negAlphaMinusOneTimesLnX = (DoubleTensor) x.log2().timesInPlace(negAlphaMinusOne);
        final DoubleTensor betaOverX = (DoubleTensor) this.beta.div(x);

        final DoubleTensor sum = (DoubleTensor) alphaTimesLnBeta.plus(negAlphaMinusOneTimesLnX);
        final DoubleTensor withoutLnGamma = (DoubleTensor) sum.minusInPlace(this.alpha.logGamma());
        return (DoubleTensor) withoutLnGamma.minusInPlace(betaOverX);
    }

    /**
     * Builds the log density as a vertex graph, mirroring {@link #logProb}:
     * {@code alpha * ln(beta) + ln(x) * (-alpha - 1) - lnGamma(alpha) - beta / x}.
     *
     * @param x     placeholder for the point at which the density is evaluated
     * @param alpha placeholder for the shape parameter
     * @param beta  placeholder for the scale parameter
     * @return a vertex computing the log density
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex alpha, DoublePlaceholderVertex beta) {
        final MultiplicationVertex alphaTimesLnBeta = alpha.times((DoubleVertex) beta.log2());
        final MultiplicationVertex negAlphaMinusOneTimesLnX =
            x.log2().times((DoubleVertex) alpha.unaryMinus().minus2(1.0d));
        return alphaTimesLnBeta
            .plus((DoubleVertex) negAlphaMinusOneTimesLnX)
            .minus((DoubleVertex) alpha.logGamma())
            .minus((DoubleVertex) beta.div((DoubleVertex) x));
    }

    /**
     * Computes the partial derivatives of the log density with respect to {@code alpha},
     * {@code beta} and {@code x}.
     *
     * @param x points at which to evaluate the derivatives (not modified)
     * @return the derivatives keyed by {@link Diffs#A}, {@link Diffs#B} and {@link Diffs#X}
     */
    @Override // io.improbable.keanu.distributions.ContinuousDistribution
    public Diffs dLogProb(DoubleTensor x) {
        // d/d(alpha) = ln(beta) - digamma(alpha) - ln(x)
        final DoubleTensor dLogPdAlpha = (DoubleTensor) ((DoubleTensor) ((DoubleTensor) x.log2().unaryMinusInPlace())
            .minusInPlace(this.alpha.digamma()))
            .plusInPlace(this.beta.log2());
        // d/d(beta) = alpha / beta - 1 / x
        final DoubleTensor dLogPdBeta = (DoubleTensor) ((DoubleTensor) x.reciprocal().unaryMinusInPlace())
            .plusInPlace((DoubleTensor) this.alpha.div(this.beta));
        // d/dx = (beta - (alpha + 1) * x) / x^2
        final DoubleTensor numerator = (DoubleTensor) ((DoubleTensor) x
            .times((DoubleTensor) this.alpha.plus2(1.0d).unaryMinusInPlace()))
            .plusInPlace(this.beta);
        final DoubleTensor dLogPdX = (DoubleTensor) x.pow2(2.0d).reciprocalInPlace().timesInPlace(numerator);
        return new Diffs()
            .put(Diffs.A, dLogPdAlpha)
            .put(Diffs.B, dLogPdBeta)
            .put(Diffs.X, dLogPdX);
    }
}
