/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.AdditionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DifferenceVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DivisionVertex;
import io.improbable.keanu.vertices.intgr.IntegerPlaceholderVertex;

/**
 * Chi-squared distribution parameterised by an integer degrees-of-freedom tensor {@code k}.
 *
 * <p>Instances are obtained through {@link #withParameters(IntegerTensor)}; the constructor is
 * private.
 */
public class ChiSquared implements ContinuousDistribution {

    /** ln(2), computed once and reused in the normalising term of the log density. */
    private static final double LN_2 = Math.log(2.0);

    /** The {@code k} tensor supplied at construction; converted to doubles on every use. */
    private final IntegerTensor degreesOfFreedom;

    private ChiSquared(IntegerTensor degreesOfFreedom) {
        this.degreesOfFreedom = degreesOfFreedom;
    }

    /**
     * Creates a chi-squared distribution.
     *
     * @param k degrees-of-freedom tensor (stored as-is, not copied or validated)
     * @return a new {@code ChiSquared} instance
     */
    public static ContinuousDistribution withParameters(IntegerTensor k) {
        return new ChiSquared(k);
    }

    /** Returns a fresh double tensor holding {@code k / 2}. */
    private DoubleTensor halfDegreesOfFreedom() {
        return degreesOfFreedom.toDouble().div(2.0);
    }

    /**
     * Draws samples by delegating to {@code random.nextGamma} with a scalar {@code 2.0} and
     * {@code k / 2} as its parameters.
     * NOTE(review): presumably these are scale θ = 2 and shape = k/2; confirm against
     * {@code KeanuRandom.nextGamma}.
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        return random.nextGamma(shape, DoubleTensor.scalar(2.0), halfDegreesOfFreedom());
    }

    /**
     * Computes {@code (k/2 - 1) * log(x) - x/2 - ((k/2) * ln 2 + logGamma(k/2))}.
     *
     * <p>There is no guard for {@code x <= 0}; what {@code log(x)} yields there is left to the
     * tensor implementation.
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        DoubleTensor halfK = halfDegreesOfFreedom();

        // Unnormalised part: (k/2 - 1) * log(x) - x/2. The in-place ops only touch fresh temporaries.
        DoubleTensor powerTerm = (DoubleTensor) halfK.minus(1.0).timesInPlace(x.log());
        DoubleTensor unnormalised = powerTerm.minusInPlace(x.div(2.0));

        // Log normaliser: (k/2) * ln 2 + logGamma(k/2).
        DoubleTensor scaledLn2 = (DoubleTensor) halfK.times(LN_2);
        DoubleTensor logNormaliser = (DoubleTensor) scaledLn2.plusInPlace(halfK.logGamma());

        return unnormalised.minusInPlace(logNormaliser);
    }

    /**
     * Builds the same log-density expression as {@link #logProb(DoubleTensor)} as a vertex
     * graph over the given placeholder inputs.
     *
     * @param x placeholder for the evaluation point
     * @param k placeholder for the degrees of freedom
     * @return the vertex computing the log density
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, IntegerPlaceholderVertex k) {
        DivisionVertex halfK = k.toDouble().div(2.0);
        // Construction order is kept as numerator-then-normaliser to match node creation order.
        DifferenceVertex unnormalised = halfK.minus(1.0).times(x.log()).minus(x.div(2.0));
        AdditionVertex logNormaliser = halfK.times(LN_2).plus(halfK.logGamma());
        return unnormalised.minus(logNormaliser);
    }

    /** Gradient of the log density is not implemented for this distribution. */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        throw new UnsupportedOperationException();
    }
}

