package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.SumVertex;

/* loaded from: input_file:io/improbable/keanu/distributions/continuous/Dirichlet.class */
public class Dirichlet implements ContinuousDistribution {

    /** Tolerance used when checking that the values passed to the log-density sum to 1. */
    private static final double EPSILON = 1.0E-5d;

    /** Concentration parameters (alpha) of the distribution. */
    private final DoubleTensor concentration;

    /**
     * Creates a Dirichlet distribution with the given concentration parameters.
     *
     * @param concentration the concentration parameters (alpha); must not be null
     * @return a Dirichlet distribution over vectors shaped like {@code concentration}
     * @throws NullPointerException if {@code concentration} is null
     */
    public static ContinuousDistribution withParameters(DoubleTensor concentration) {
        return new Dirichlet(concentration);
    }

    private Dirichlet(DoubleTensor concentration) {
        // Fail fast at construction instead of with an NPE on the first sample/logProb call.
        this.concentration = java.util.Objects.requireNonNull(concentration, "concentration");
    }

    /**
     * Draws a sample by generating one Gamma variate per concentration entry and normalising
     * the result so it sums to 1.
     *
     * <p>NOTE(review): the Gamma is built with an all-ones tensor of shape {@code shape} as its
     * first parameter and the concentrations as its second (presumably scale = 1,
     * shape = alpha; confirm against {@code Gamma.withParameters}). The sample itself is always
     * drawn with the concentration's shape.
     *
     * @param shape  shape of the all-ones first Gamma parameter
     * @param random source of randomness
     * @return a tensor whose elements sum to 1
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        DoubleTensor gammaSamples = Gamma.withParameters(DoubleTensor.ones(shape), this.concentration)
            .sample(this.concentration.getShape(), random);
        return normalise(gammaSamples);
    }

    /**
     * Computes the Dirichlet log-density
     * {@code sum((alpha_i - 1) * log(x_i)) - sum(logGamma(alpha_i)) + logGamma(sum(alpha_i))}.
     *
     * @param x point at which to evaluate; its elements must sum to 1 (within {@code EPSILON})
     * @return a scalar tensor holding the log-density
     * @throws IllegalArgumentException if the elements of {@code x} do not sum to 1, including
     *                                  when the sum is NaN
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        double total = sumOf(x);
        // Negated "<=" so that a NaN sum is rejected as well, consistent with the
        // lessThanOrEqualTo assertion used by logProbOutput.
        if (!(Math.abs(total - 1.0d) <= EPSILON)) {
            throw new IllegalArgumentException("Sum of values to calculate Dirichlet likelihood for must equal 1");
        }
        // minus2 returns a fresh tensor, so the in-place multiply does not mutate the field.
        double weightedLogX = sumOf((DoubleTensor) this.concentration.minus2(1.0d).timesInPlace(x.log2()));
        double sumLogGammaConcentration = sumOf(this.concentration.logGamma());
        double logGammaSumConcentration =
            org.apache.commons.math3.special.Gamma.logGamma(sumOf(this.concentration));
        return DoubleTensor.scalar(weightedLogX - sumLogGammaConcentration + logGammaSumConcentration);
    }

    /**
     * Builds the log-density as a computation graph of vertices.
     *
     * <p>An assertion vertex checking {@code |sum(x) - 1| <= EPSILON} is created. It is not
     * returned; NOTE(review): confirm that {@code assertTrue} registers it in the graph.
     *
     * @param x             placeholder for the point at which the density is evaluated
     * @param concentration placeholder for the concentration parameters (alpha)
     * @return a vertex computing
     *         {@code sum((alpha - 1) * log(x)) - sum(logGamma(alpha)) + logGamma(sum(alpha))}
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex concentration) {
        x.sum().minus2(1.0d).abs().lessThanOrEqualTo(ConstantVertex.of(EPSILON))
            .assertTrue("Sum of values to calculate Dirichlet likelihood for must equal 1");
        SumVertex weightedLogX = concentration.minus2(1.0d).times((DoubleVertex) x.log2()).sum();
        SumVertex sumLogGammaConcentration = concentration.logGamma().sum();
        return weightedLogX.minus((DoubleVertex) sumLogGammaConcentration)
            .plus((DoubleVertex) concentration.sum().logGamma());
    }

    /**
     * Computes the gradient of the log-density with respect to the concentrations and to x.
     *
     * @param x point at which to evaluate the gradient
     * @return diffs with
     *         {@code C = log(x_i) - digamma(alpha_i) + digamma(sum(alpha))} and
     *         {@code X = (alpha_i - 1) / x_i}
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        double digammaOfTotal = org.apache.commons.math3.special.Gamma.digamma(sumOf(this.concentration));
        // The decompiled original cast a boxed Double to DoubleTensor here, which cannot
        // compile. The scalar digamma term is added element-wise instead.
        DoubleTensor dLogPdc = (DoubleTensor) ((DoubleTensor) x.log2().minusInPlace(this.concentration.digamma()))
            .plusInPlace(digammaOfTotal);
        DoubleTensor dLogPdx = (DoubleTensor) this.concentration.minus2(1.0d).divInPlace(x);
        return new Diffs().put(Diffs.C, dLogPdc).put(Diffs.X, dLogPdx);
    }

    /** Divides every element of {@code tensor} by the sum of all its elements. */
    private static DoubleTensor normalise(DoubleTensor tensor) {
        return tensor.div2(sumOf(tensor));
    }

    /** Sums all elements of {@code tensor}, unboxing the result to a primitive double. */
    private static double sumOf(DoubleTensor tensor) {
        return ((Double) tensor.sum()).doubleValue();
    }
}
