/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import com.google.common.base.Preconditions;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DifferenceVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.MultiplicationVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogGammaVertex;

/**
 * Beta distribution whose parameters are held as {@link DoubleTensor}s.
 *
 * <p>{@code alpha} and {@code beta} are the two shape parameters. {@code xMin} and
 * {@code xMax} are only read by {@link #sample(long[], KeanuRandom)}, where they bound the
 * interval onto which the drawn values are mapped; {@link #logProb(DoubleTensor)} and
 * {@link #dLogProb(DoubleTensor)} do not reference them.
 */
public class Beta implements ContinuousDistribution {

    private final DoubleTensor alpha;
    private final DoubleTensor beta;
    private final DoubleTensor xMin;
    private final DoubleTensor xMax;

    private Beta(DoubleTensor alpha, DoubleTensor beta, DoubleTensor xMin, DoubleTensor xMax) {
        this.alpha = alpha;
        this.beta = beta;
        this.xMin = xMin;
        this.xMax = xMax;
    }

    /**
     * Creates a Beta distribution from its parameters.
     *
     * @param alpha first shape parameter
     * @param beta  second shape parameter
     * @param xMin  lower bound used when sampling
     * @param xMax  upper bound used when sampling
     * @return a distribution backed by the given tensors (the tensors are stored, not copied)
     */
    public static ContinuousDistribution withParameters(DoubleTensor alpha, DoubleTensor beta, DoubleTensor xMin, DoubleTensor xMax) {
        return new Beta(alpha, beta, xMin, xMax);
    }

    /**
     * Draws samples by combining two gamma draws, {@code g1} (using {@code alpha}) and
     * {@code g2} (using {@code beta}), and mapping the ratio onto {@code [xMin, xMax]}.
     *
     * <p>Two algebraically equal forms of the result are computed,
     * {@code xMax - g2 / (g1 + g2) * (xMax - xMin)} and
     * {@code xMin + g1 / (g1 + g2) * (xMax - xMin)}; the first is selected elementwise where
     * {@code alpha < beta} and the second elsewhere.
     *
     * @param shape  shape passed through to {@code random.nextGamma}
     * @param random source of the gamma draws
     * @return the sampled tensor
     * @throws IllegalArgumentException if any element of {@code alpha} or {@code beta} is not
     *                                  greater than zero
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        boolean shapeParametersPositive = alpha.greaterThan(0.0).allTrue() && beta.greaterThan(0.0).allTrue();
        Preconditions.checkArgument(
            shapeParametersPositive,
            "alpha and beta must be positive. alpha: " + alpha + " beta: " + beta
        );

        // The alpha draw is taken before the beta draw so the random stream is consumed in a fixed order.
        DoubleTensor gammaFromAlpha = random.nextGamma(shape, DoubleTensor.scalar(1.0), alpha);
        DoubleTensor gammaFromBeta = random.nextGamma(shape, DoubleTensor.scalar(1.0), beta);

        DoubleTensor width = xMax.minus(xMin);
        DoubleTensor gammaTotal = gammaFromAlpha.plus(gammaFromBeta);

        DoubleTensor offsetFromMax = gammaFromBeta.div(gammaTotal).timesInPlace(width);
        DoubleTensor measuredFromMax = xMax.minus(offsetFromMax);

        DoubleTensor offsetFromMin = gammaFromAlpha.div(gammaTotal).timesInPlace(width);
        DoubleTensor measuredFromMin = xMin.plus(offsetFromMin);

        // The two masks are complementary, so every element takes exactly one of the two forms.
        DoubleTensor whereAlphaSmaller = alpha.lessThanMask(beta);
        DoubleTensor whereAlphaNotSmaller = alpha.greaterThanOrEqualToMask(beta);

        DoubleTensor takenFromMax = whereAlphaSmaller.timesInPlace(measuredFromMax);
        DoubleTensor takenFromMin = whereAlphaNotSmaller.timesInPlace(measuredFromMin);
        return takenFromMax.plusInPlace(takenFromMin);
    }

    /**
     * Evaluates the log density
     * {@code (alpha - 1) * ln(x) + (beta - 1) * ln(1 - x)
     * - (lnGamma(alpha) + lnGamma(beta) - lnGamma(alpha + beta))}.
     *
     * @param x points at which to evaluate the log density
     * @return the log density at {@code x}
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        DoubleTensor logGammaOfAlpha = alpha.logGamma();
        DoubleTensor logGammaOfBeta = beta.logGamma();
        DoubleTensor logGammaOfSum = alpha.plus(beta).logGammaInPlace();

        DoubleTensor alphaTerm = x.log().timesInPlace(alpha.minus(1.0));

        DoubleTensor oneMinusX = x.unaryMinus().plusInPlace(1.0);
        DoubleTensor betaTerm = oneMinusX.logInPlace().timesInPlace(beta.minus(1.0));

        // Log of the normalising Beta function B(alpha, beta).
        DoubleTensor logNormaliser = logGammaOfAlpha.plusInPlace(logGammaOfBeta).minusInPlace(logGammaOfSum);

        return alphaTerm.plusInPlace(betaTerm).minusInPlace(logNormaliser);
    }

    /**
     * Builds the same log-density expression as {@link #logProb(DoubleTensor)}, but as a graph
     * of vertices over the supplied placeholders.
     *
     * @param x     placeholder for the evaluation point
     * @param alpha placeholder for the first shape parameter
     * @param beta  placeholder for the second shape parameter
     * @return the vertex holding the log density
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex alpha, DoublePlaceholderVertex beta) {
        // Vertices are created in a fixed order: normaliser inputs, then the two x terms, then the combination.
        LogGammaVertex logGammaOfAlpha = alpha.logGamma();
        LogGammaVertex logGammaOfBeta = beta.logGamma();
        LogGammaVertex logGammaOfSum = alpha.plus(beta).logGamma();

        MultiplicationVertex alphaTerm = x.log().times(alpha.minus(1.0));
        MultiplicationVertex betaTerm = x.unaryMinus().plus(1.0).log().times(beta.minus(1.0));

        DifferenceVertex logNormaliser = logGammaOfAlpha.plus(logGammaOfBeta).minus(logGammaOfSum);
        return alphaTerm.plus(betaTerm).minus(logNormaliser);
    }

    /**
     * Computes the partial derivatives of the log density.
     *
     * <ul>
     *   <li>{@code Diffs.A}: {@code ln(x) + digamma(alpha + beta) - digamma(alpha)}</li>
     *   <li>{@code Diffs.B}: {@code ln(1 - x) + digamma(alpha + beta) - digamma(beta)}</li>
     *   <li>{@code Diffs.X}: {@code (alpha - 1) / x - (beta - 1) / (1 - x)}</li>
     * </ul>
     *
     * @param x points at which to evaluate the derivatives
     * @return the derivatives keyed by {@code Diffs.A}, {@code Diffs.B} and {@code Diffs.X}
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        DoubleTensor oneMinusX = x.unaryMinus().plusInPlace(1.0);
        DoubleTensor digammaOfSum = alpha.plus(beta).digammaInPlace();

        DoubleTensor alphaMinusOneOverX = x.reciprocal().timesInPlace(alpha.minus(1.0));
        DoubleTensor betaMinusOneOverOneMinusX = oneMinusX.reciprocal().timesInPlace(beta.minus(1.0));
        DoubleTensor wrtX = alphaMinusOneOverX.minusInPlace(betaMinusOneOverOneMinusX);

        DoubleTensor wrtAlpha = x.log().plusInPlace(digammaOfSum.minus(alpha.digamma()));

        // Kept last: logInPlace/minusInPlace are applied to oneMinusX and digammaOfSum here,
        // so both must already have been read in their original form by the lines above.
        DoubleTensor wrtBeta = oneMinusX.logInPlace().plusInPlace(digammaOfSum.minusInPlace(beta.digamma()));

        return new Diffs()
            .put(Diffs.A, wrtAlpha)
            .put(Diffs.B, wrtBeta)
            .put(Diffs.X, wrtX);
    }
}

