/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.discrete;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.DiscreteDistribution;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.AdditionVertex;
import io.improbable.keanu.vertices.intgr.IntegerPlaceholderVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;

/**
 * Geometric distribution parameterised by a success probability tensor {@code p}.
 *
 * <p>The log-mass evaluated here is {@code (k - 1) * log(1 - p) + log(p)}, with
 * every entry where {@code k < 1} forced to negative infinity. The tensor based
 * {@link #logProb(IntegerTensor)} additionally returns negative infinity
 * everywhere unless all elements of {@code p} lie strictly between 0 and 1.
 */
public class Geometric implements DiscreteDistribution {
    private final DoubleTensor p;

    /**
     * Static factory; the only way to obtain an instance since the constructor is private.
     *
     * @param p success probability tensor; stored as-is, no validation happens here
     * @return a new distribution wrapping {@code p}
     */
    public static DiscreteDistribution withParameters(DoubleTensor p) {
        return new Geometric(p);
    }

    private Geometric(DoubleTensor p) {
        this.p = p;
    }

    /**
     * Draws samples as {@code floor(log(u) / log(1 - p)) + 1}, where {@code u} is
     * the tensor produced by {@code random.nextDouble(shape)}.
     *
     * <p>Unlike {@link #logProb(IntegerTensor)}, this method does not check
     * {@code p} before using it.
     *
     * @param shape  shape handed to {@code random.nextDouble}
     * @param random source of the draws
     * @return the computed samples as an integer tensor
     */
    @Override
    public IntegerTensor sample(long[] shape, KeanuRandom random) {
        DoubleTensor logOfDraw = (DoubleTensor) random.nextDouble(shape).logInPlace();

        // log(1 - p), built on the tensor returned by unaryMinus() rather than on p itself.
        DoubleTensor negatedP = (DoubleTensor) this.p.unaryMinus();
        DoubleTensor failureProbability = (DoubleTensor) ((Object) negatedP.plusInPlace(1.0));
        DoubleTensor logOfFailureProbability = (DoubleTensor) failureProbability.logInPlace();

        DoubleTensor flooredRatio =
            (DoubleTensor) logOfDraw.divInPlace(logOfFailureProbability).floorInPlace();
        return (IntegerTensor) ((Object) flooredRatio.toInteger().plusInPlace(1));
    }

    /**
     * Evaluates the log-mass at {@code k}.
     *
     * @param k values at which to evaluate
     * @return the per-element log-mass, or a tensor of {@code k}'s shape filled with
     *         negative infinity when {@code p} fails {@link #checkParameterIsValid()}
     */
    @Override
    public DoubleTensor logProb(IntegerTensor k) {
        if (this.checkParameterIsValid()) {
            return this.calculateLogProb(k);
        }
        return DoubleTensor.create(Double.NEGATIVE_INFINITY, k.getShape());
    }

    /**
     * Vertex (graph) counterpart of {@link #logProb(IntegerTensor)}.
     *
     * <p>The invalid-parameter mask is computed as
     * {@code 1 - (p > 0 mask) * (p < 1 mask)}; the comparison constants take
     * {@code k}'s shape. NOTE(review): this assumes the mask vertices hold 1 where
     * the comparison is satisfied - confirm against the vertex API.
     *
     * @param k placeholder for the evaluation points
     * @param p placeholder for the success probability
     * @return a vertex for the log-mass with masked entries set to negative infinity
     */
    public static DoubleVertex logProbOutput(IntegerPlaceholderVertex k, DoublePlaceholderVertex p) {
        ConstantDoubleVertex zeroes = ConstantVertex.of(DoubleTensor.zeros(k.getShape()));
        ConstantDoubleVertex ones = ConstantVertex.of(DoubleTensor.ones(k.getShape()));

        DoubleVertex invalidParameterMask = p.toGreaterThanMask(zeroes)
            .times(p.toLessThanMask(ones))
            .unaryMinus()
            .plus(ones);

        DoubleVertex unmaskedLogProb = Geometric.calculateLogProb(k, p);
        return unmaskedLogProb.setWithMask(invalidParameterMask, Double.NEGATIVE_INFINITY);
    }

    /**
     * Computes {@code (k - 1) * log(1 - p) + log(p)} on tensors and then masks
     * entries with {@code k < 1}.
     */
    private DoubleTensor calculateLogProb(IntegerTensor k) {
        DoubleTensor trials = k.toDouble();

        DoubleTensor negatedP = (DoubleTensor) this.p.unaryMinus();
        DoubleTensor failureProbability = (DoubleTensor) ((Object) negatedP.plusInPlace(1.0));

        DoubleTensor failuresBeforeSuccess = (DoubleTensor) ((Object) trials.minusInPlace(1.0));
        DoubleTensor failureTerm =
            (DoubleTensor) failuresBeforeSuccess.timesInPlace(failureProbability.logInPlace());
        DoubleTensor unmaskedLogProb = (DoubleTensor) failureTerm.plusInPlace(this.p.log());

        return this.setProbToZeroForInvalidK(k, unmaskedLogProb);
    }

    /**
     * Computes {@code (k - 1) * log(1 - p) + log(p)} on vertices and then masks
     * entries with {@code k < 1}.
     */
    private static DoubleVertex calculateLogProb(IntegerVertex k, DoubleVertex p) {
        DoubleVertex trials = k.toDouble();
        DoubleVertex failureProbability = p.unaryMinus().plus(1.0);
        DoubleVertex unmaskedLogProb = trials.minus(1.0)
            .times(failureProbability.log())
            .plus(p.log());
        return Geometric.setProbToZeroForInvalidK(k, unmaskedLogProb);
    }

    /**
     * Writes negative infinity (log of a zero probability) into {@code results}
     * using the mask of entries where {@code k} is less than 1.
     */
    private DoubleTensor setProbToZeroForInvalidK(IntegerTensor k, DoubleTensor results) {
        IntegerTensor smallestValidK = IntegerTensor.create(1, k.getShape());
        IntegerTensor belowSupportMask = k.lessThanMask(smallestValidK);
        return results.setWithMaskInPlace(belowSupportMask.toDouble(), Double.NEGATIVE_INFINITY);
    }

    /**
     * Vertex counterpart of the tensor overload: applies the {@code k < 1} mask
     * with negative infinity as the replacement value.
     */
    private static DoubleVertex setProbToZeroForInvalidK(IntegerVertex k, DoubleVertex results) {
        DoubleVertex trials = k.toDouble();
        DoubleVertex belowSupportMask = trials.toLessThanMask(1.0);
        return results.setWithMask(belowSupportMask, Double.NEGATIVE_INFINITY);
    }

    /**
     * @return {@code true} only when every element of {@code p} is greater than 0
     *         and every element is less than 1
     */
    private boolean checkParameterIsValid() {
        if (!this.p.greaterThan(0.0).allTrue()) {
            return false;
        }
        return this.p.lessThan(1.0).allTrue();
    }
}

