/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.continuous.Gamma;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogGammaVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.SumVertex;

public class Dirichlet
implements ContinuousDistribution {
    private static final double EPSILON = 1.0E-5;
    private final DoubleTensor concentration;

    public static ContinuousDistribution withParameters(DoubleTensor concentration) {
        return new Dirichlet(concentration);
    }

    private Dirichlet(DoubleTensor concentration) {
        this.concentration = concentration;
    }

    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        ContinuousDistribution gamma = Gamma.withParameters(DoubleTensor.ones(shape), this.concentration);
        DoubleTensor gammaSamples = (DoubleTensor)gamma.sample(this.concentration.getShape(), random);
        return this.normalise(gammaSamples);
    }

    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        if (Math.abs((Double)x.sum() - 1.0) > 1.0E-5) {
            throw new IllegalArgumentException("Sum of values to calculate Dirichlet likelihood for must equal 1");
        }
        double sumConcentrationLogged = (Double)((DoubleTensor)this.concentration.minus(1.0).timesInPlace(x.log())).sum();
        double sumLogGammaConcentration = (Double)((DoubleTensor)this.concentration.logGamma()).sum();
        double logGammaSumConcentration = org.apache.commons.math3.special.Gamma.logGamma((double)((Double)this.concentration.sum()));
        return DoubleTensor.scalar(sumConcentrationLogged - sumLogGammaConcentration + logGammaSumConcentration);
    }

    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex concentration) {
        BooleanVertex xMinusOneIsLessThanOrEqualToEpsilon = x.sum().minus(1.0).abs().lessThanOrEqualTo(ConstantVertex.of(1.0E-5));
        xMinusOneIsLessThanOrEqualToEpsilon.assertTrue("Sum of values to calculate Dirichlet likelihood for must equal 1");
        SumVertex sumConcentrationLogged = concentration.minus(1.0).times(x.log()).sum();
        SumVertex sumLogGammaConcentration = concentration.logGamma().sum();
        LogGammaVertex logGammaSumConcentration = concentration.sum().logGamma();
        return sumConcentrationLogged.minus(sumLogGammaConcentration).plus(logGammaSumConcentration);
    }

    @Override
    public Diffs dLogProb(DoubleTensor x) {
        DoubleTensor dLogPdc = (DoubleTensor)((Object)((DoubleTensor)((DoubleTensor)x.log()).minusInPlace(this.concentration.digamma())).plusInPlace(org.apache.commons.math3.special.Gamma.digamma((double)((Double)this.concentration.sum()))));
        DoubleTensor dLogPdx = this.concentration.minus(1.0).divInPlace(x);
        return new Diffs().put(Diffs.C, dLogPdc).put(Diffs.X, dLogPdx);
    }

    private DoubleTensor normalise(DoubleTensor gammaSamples) {
        double sum = (Double)gammaSamples.sum();
        return gammaSamples.div(sum);
    }
}

