package io.improbable.keanu.distributions.continuous;

import com.google.common.base.Preconditions;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.MultiplicationVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogGammaVertex;

/* loaded from: input_file:io/improbable/keanu/distributions/continuous/Beta.class */
/**
 * Beta distribution over {@link DoubleTensor} values, parameterised by two shape
 * tensors ({@code alpha}, {@code beta}) and a pair of bounds ({@code xMin}, {@code xMax}).
 *
 * <p>The bounds are only read by {@link #sample(long[], KeanuRandom)}; {@link #logProb(DoubleTensor)}
 * and {@link #dLogProb(DoubleTensor)} do not use them.
 *
 * <p>NOTE(review): the calls to {@code log2()}, {@code minus2(double)} and {@code plus2(double)}
 * are kept exactly as they appear in this source. The {@code 2} suffix looks like a
 * decompiler-generated rename of an overloaded {@code log()/minus()/plus()} rather than, e.g.,
 * a base-2 logarithm - confirm against the tensor/vertex API before relying on either reading.
 */
public class Beta implements ContinuousDistribution {
    // Boxed scalars passed to the tensor API's scalar overloads. The previous code cast
    // Double.valueOf(..) to DoubleTensor, which is an inconvertible cast (a Double can never
    // be a DoubleTensor); the boxed value is now passed as-is.
    private static final Double ZERO = 0.0d;
    private static final Double ONE = 1.0d;

    private final DoubleTensor alpha;
    private final DoubleTensor beta;
    private final DoubleTensor xMin;
    private final DoubleTensor xMax;

    /**
     * Static factory for a Beta distribution.
     *
     * @param alpha first shape parameter
     * @param beta  second shape parameter
     * @param xMin  lower bound used when sampling
     * @param xMax  upper bound used when sampling
     * @return a new distribution holding the given tensors (no copies are made)
     */
    public static ContinuousDistribution withParameters(DoubleTensor alpha, DoubleTensor beta, DoubleTensor xMin, DoubleTensor xMax) {
        return new Beta(alpha, beta, xMin, xMax);
    }

    private Beta(DoubleTensor alpha, DoubleTensor beta, DoubleTensor xMin, DoubleTensor xMax) {
        this.alpha = alpha;
        this.beta = beta;
        this.xMin = xMin;
        this.xMax = xMax;
    }

    /**
     * Draws a sample built from two gamma variates {@code gA} (drawn with {@code alpha}) and
     * {@code gB} (drawn with {@code beta}):
     * <ul>
     *   <li>where {@code alpha < beta}: {@code xMax - gB / (gA + gB) * (xMax - xMin)}</li>
     *   <li>otherwise: {@code xMin + gA / (gA + gB) * (xMax - xMin)}</li>
     * </ul>
     * The two cases are combined by multiplying each with its mask and adding the results
     * (presumably the masks are 1 where the comparison holds and 0 elsewhere - TODO confirm).
     *
     * @param shape  shape forwarded to {@code KeanuRandom.nextGamma}
     * @param random source of the gamma variates
     * @return the combined sample tensor
     * @throws IllegalArgumentException if any element of {@code alpha} or {@code beta} is not greater than zero
     */
    @Override // io.improbable.keanu.distributions.Distribution
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        Preconditions.checkArgument(
            this.alpha.greaterThan(ZERO).allTrue() && this.beta.greaterThan(ZERO).allTrue(),
            "alpha and beta must be positive. alpha: " + this.alpha + " beta: " + this.beta);

        DoubleTensor gammaAlpha = random.nextGamma(shape, DoubleTensor.scalar(1.0d), this.alpha);
        DoubleTensor gammaBeta = random.nextGamma(shape, DoubleTensor.scalar(1.0d), this.beta);
        DoubleTensor range = (DoubleTensor) this.xMax.minus(this.xMin);
        DoubleTensor gammaSum = (DoubleTensor) gammaAlpha.plus(gammaBeta);

        // Branch selected where alpha < beta: measured downwards from xMax.
        DoubleTensor alphaLessMask = (DoubleTensor) this.alpha.lessThanMask(this.beta);
        DoubleTensor betaShare = (DoubleTensor) ((DoubleTensor) gammaBeta.div(gammaSum)).timesInPlace(range);
        DoubleTensor fromUpperBound = (DoubleTensor) this.xMax.minus(betaShare);
        DoubleTensor alphaLessPart = (DoubleTensor) alphaLessMask.timesInPlace(fromUpperBound);

        // Branch selected where alpha >= beta: measured upwards from xMin.
        DoubleTensor alphaNotLessMask = (DoubleTensor) this.alpha.greaterThanOrEqualToMask(this.beta);
        DoubleTensor alphaShare = (DoubleTensor) ((DoubleTensor) gammaAlpha.div(gammaSum)).timesInPlace(range);
        DoubleTensor fromLowerBound = (DoubleTensor) this.xMin.plus(alphaShare);
        DoubleTensor alphaNotLessPart = (DoubleTensor) alphaNotLessMask.timesInPlace(fromLowerBound);

        return (DoubleTensor) alphaLessPart.plusInPlace(alphaNotLessPart);
    }

    /**
     * Log-density at {@code x}:
     * {@code (alpha - 1) * log(x) + (beta - 1) * log(1 - x)
     *        - (logGamma(alpha) + logGamma(beta) - logGamma(alpha + beta))}.
     * {@code xMin} and {@code xMax} are not used here.
     *
     * @param x point(s) at which to evaluate the log-density
     * @return tensor of log-density values
     */
    @Override // io.improbable.keanu.distributions.Distribution
    public DoubleTensor logProb(DoubleTensor x) {
        DoubleTensor logGammaAlpha = this.alpha.logGamma();
        DoubleTensor logGammaBeta = this.beta.logGamma();
        DoubleTensor logGammaSum = ((DoubleTensor) this.alpha.plus(this.beta)).logGammaInPlace();

        // (alpha - 1) * log(x)
        DoubleTensor alphaTerm = (DoubleTensor) x.log2().timesInPlace(this.alpha.minus2(1.0d));

        // (beta - 1) * log(1 - x); 1 - x is built as (-x) + 1 on the negated tensor.
        DoubleTensor oneMinusX = (DoubleTensor) ((DoubleTensor) x.unaryMinus()).plusInPlace(ONE);
        DoubleTensor betaTerm = (DoubleTensor) oneMinusX.logInPlace().timesInPlace(this.beta.minus2(1.0d));

        DoubleTensor weightedLogs = (DoubleTensor) alphaTerm.plusInPlace(betaTerm);

        // Log of the normalising term: logGamma(alpha) + logGamma(beta) - logGamma(alpha + beta).
        DoubleTensor logNormaliser = (DoubleTensor) ((DoubleTensor) logGammaAlpha.plusInPlace(logGammaBeta)).minusInPlace(logGammaSum);

        return (DoubleTensor) weightedLogs.minusInPlace(logNormaliser);
    }

    /**
     * Builds the same expression as {@link #logProb(DoubleTensor)} out of vertices.
     *
     * @param xPlaceholder     placeholder for the evaluation point
     * @param alphaPlaceholder placeholder for the first shape parameter
     * @param betaPlaceholder  placeholder for the second shape parameter
     * @return vertex representing the log-density expression
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex xPlaceholder, DoublePlaceholderVertex alphaPlaceholder, DoublePlaceholderVertex betaPlaceholder) {
        LogGammaVertex logGammaAlpha = alphaPlaceholder.logGamma();
        LogGammaVertex logGammaBeta = betaPlaceholder.logGamma();
        LogGammaVertex logGammaSum = alphaPlaceholder.plus((DoubleVertex) betaPlaceholder).logGamma();

        // (alpha - 1) * log(x)
        MultiplicationVertex alphaTerm = xPlaceholder.log2().times((DoubleVertex) alphaPlaceholder.minus2(1.0d));
        // (beta - 1) * log(1 - x)
        MultiplicationVertex betaTerm = xPlaceholder.unaryMinus().plus2(1.0d).log2().times((DoubleVertex) betaPlaceholder.minus2(1.0d));

        return alphaTerm.plus((DoubleVertex) betaTerm)
            .minus((DoubleVertex) logGammaAlpha.plus((DoubleVertex) logGammaBeta).minus((DoubleVertex) logGammaSum));
    }

    /**
     * Partial derivatives of the log-density at {@code x}:
     * <ul>
     *   <li>{@code Diffs.A}: {@code log(x) + digamma(alpha + beta) - digamma(alpha)}</li>
     *   <li>{@code Diffs.B}: {@code log(1 - x) + digamma(alpha + beta) - digamma(beta)}</li>
     *   <li>{@code Diffs.X}: {@code (alpha - 1) / x - (beta - 1) / (1 - x)}</li>
     * </ul>
     *
     * @param x point(s) at which to evaluate the derivatives
     * @return a {@link Diffs} holding the three entries above
     */
    @Override // io.improbable.keanu.distributions.ContinuousDistribution
    public Diffs dLogProb(DoubleTensor x) {
        DoubleTensor oneMinusX = (DoubleTensor) ((DoubleTensor) x.unaryMinus()).plusInPlace(ONE);
        DoubleTensor digammaSum = ((DoubleTensor) this.alpha.plus(this.beta)).digammaInPlace();

        // Must be computed before oneMinusX is overwritten by logInPlace() below.
        DoubleTensor dLogPdX = (DoubleTensor) ((DoubleTensor) x.reciprocal().timesInPlace(this.alpha.minus2(1.0d)))
            .minusInPlace((DoubleTensor) oneMinusX.reciprocal().timesInPlace(this.beta.minus2(1.0d)));

        // Uses the non-mutating minus(..) so digammaSum is still intact for the beta derivative.
        DoubleTensor dLogPdAlpha = (DoubleTensor) x.log2().plusInPlace((DoubleTensor) digammaSum.minus(this.alpha.digamma()));

        // Last use of both oneMinusX and digammaSum, so both are updated in place.
        DoubleTensor dLogPdBeta = (DoubleTensor) oneMinusX.logInPlace().plusInPlace((DoubleTensor) digammaSum.minusInPlace(this.beta.digamma()));

        return new Diffs()
            .put(Diffs.A, dLogPdAlpha)
            .put(Diffs.B, dLogPdBeta)
            .put(Diffs.X, dLogPdX);
    }
}
