/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.intgr.probabilistic;

import com.google.common.base.Preconditions;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.discrete.Multinomial;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import io.improbable.keanu.vertices.intgr.probabilistic.ProbabilisticInteger;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;

/**
 * Integer-valued probabilistic vertex backed by {@link Multinomial}, parameterised by an integer
 * vertex {@code n} and a double vertex {@code p}.
 *
 * <p>The last dimension of {@code p} is the category count {@code k}, which must be at least 2.
 * The remaining (batch) dimensions of {@code p} must be broadcastable with the shape of {@code n}.
 * The vertex shape is that broadcast shape with {@code k} appended.
 */
public class MultinomialVertex
extends IntegerVertex
implements ProbabilisticInteger,
SamplableWithManyScalars<IntegerTensor> {
    private static final String P_NAME = "p";
    private static final String N_NAME = "n";

    private final DoubleVertex p;
    private final IntegerVertex n;
    /** Flag forwarded to {@link Multinomial#withParameters}; defaults to {@code true}. */
    private boolean validationEnabled;

    /**
     * Creates the vertex with an explicit shape, which must equal the shape derived from
     * {@code n} and {@code p}. This constructor is also used when loading a saved network.
     *
     * @param tensorShape the vertex shape; must equal the broadcast batch shape of {@code n}
     *                    and {@code p} with {@code k} appended
     * @param n           the {@code n} parameter vertex
     * @param p           the {@code p} parameter vertex; its last dimension is {@code k}
     * @throws IllegalArgumentException if the parameter shapes are incompatible, or if
     *                                  {@code tensorShape} does not match the derived shape
     */
    public MultinomialVertex(@LoadShape long[] tensorShape, @LoadVertexParam(value=N_NAME) IntegerVertex n, @LoadVertexParam(value=P_NAME) DoubleVertex p) {
        super(tensorShape);
        long[] expectedShape = MultinomialVertex.calculateExpectedShape(n.getShape(), p.getShape());
        // Message template is only formatted when the check fails.
        Preconditions.checkArgument(
            Arrays.equals(expectedShape, tensorShape),
            "Tensor shape %s does not match the shape %s derived from n and p",
            Arrays.toString(tensorShape),
            Arrays.toString(expectedShape));
        this.p = p;
        this.n = n;
        this.validationEnabled = true;
        this.setParents(p, n);
    }

    /**
     * Derives the vertex shape from the shapes of {@code n} and {@code p}.
     *
     * @param nShape shape of {@code n}
     * @param pShape shape of {@code p}; its last dimension is {@code k}
     * @return the broadcast of {@code nShape} with {@code pShape} minus its last dimension,
     *         followed by {@code k}
     * @throws IllegalArgumentException if {@code p} is a scalar, if {@code k < 2}, or if the
     *                                  batch shapes are not broadcastable
     */
    private static long[] calculateExpectedShape(long[] nShape, long[] pShape) {
        int pRank = pShape.length;
        // Without this guard a scalar p would fail below with ArrayIndexOutOfBoundsException.
        Preconditions.checkArgument(
            pRank >= 1,
            "p must have at least one dimension (the k dimension) but has shape %s",
            Arrays.toString(pShape));
        long k = pShape[pRank - 1];
        Preconditions.checkArgument(k >= 2L, "K value of " + k + " must be greater than 1");
        long[] pBatchShape = TensorShape.selectDimensions(0, pRank - 1, pShape);
        if (!TensorShapeValidation.isBroadcastable(nShape, pBatchShape)) {
            throw new IllegalArgumentException("The shape of n " + Arrays.toString(nShape) + " must be broadcastable with the shape of p excluding the k dimension " + Arrays.toString(pBatchShape));
        }
        return TensorShape.concat(TensorShapeValidation.checkIsBroadcastable(nShape, pBatchShape), new long[]{k});
    }

    /**
     * Creates the vertex with the shape derived from {@code n} and {@code p}.
     *
     * @param n the {@code n} parameter vertex
     * @param p the {@code p} parameter vertex; its last dimension is {@code k}
     * @throws IllegalArgumentException if the parameter shapes are incompatible
     */
    @ExportVertexToPythonBindings
    public MultinomialVertex(IntegerVertex n, DoubleVertex p) {
        this(MultinomialVertex.calculateExpectedShape(n.getShape(), p.getShape()), n, p);
    }

    /**
     * Creates the vertex from a scalar {@code n}. The vertex takes the shape of {@code p}.
     *
     * @param n scalar value wrapped in a constant vertex
     * @param p the {@code p} parameter vertex
     */
    public MultinomialVertex(int n, DoubleVertex p) {
        this(p.getShape(), ConstantVertex.of(IntegerTensor.scalar(n)), p);
    }

    /**
     * Creates the vertex from a scalar {@code n} and a constant {@code p}. The vertex takes the
     * shape of {@code p}.
     *
     * @param n scalar value wrapped in a constant vertex
     * @param p tensor wrapped in a constant vertex
     */
    public MultinomialVertex(int n, DoubleTensor p) {
        this(p.getShape(), ConstantVertex.of(IntegerTensor.scalar(n)), ConstantVertex.of(p));
    }

    /**
     * Creates the vertex from a constant {@code n} tensor.
     *
     * @param n tensor wrapped in a constant vertex
     * @param p the {@code p} parameter vertex
     */
    public MultinomialVertex(IntegerTensor n, DoubleVertex p) {
        this(ConstantVertex.of(n), p);
    }

    /**
     * Creates the vertex from constant {@code n} and {@code p} tensors.
     *
     * @param n tensor wrapped in a constant vertex
     * @param p tensor wrapped in a constant vertex
     */
    public MultinomialVertex(IntegerTensor n, DoubleTensor p) {
        this(ConstantVertex.of(n), (DoubleVertex)ConstantVertex.of(p));
    }

    /**
     * Computes the log probability of {@code xTensor}, summed over all elements of the
     * per-element log-probability tensor.
     *
     * @param xTensor the value to evaluate
     * @return the summed log probability
     */
    @Override
    public double logProb(IntegerTensor xTensor) {
        return (Double)Multinomial.withParameters((IntegerTensor)this.n.getValue(), (DoubleTensor)this.p.getValue(), this.validationEnabled).logProb(xTensor).sum();
    }

    /**
     * Not supported for this vertex.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(IntegerTensor value, Set<? extends Vertex> withRespectTo) {
        throw new UnsupportedOperationException();
    }

    /**
     * Draws a sample of the requested shape from a distribution built with the current values
     * of {@code n} and {@code p}.
     *
     * @param shape  the shape to sample
     * @param random the random source
     * @return the sampled tensor
     */
    @Override
    public IntegerTensor sampleWithShape(long[] shape, KeanuRandom random) {
        return Multinomial.withParameters((IntegerTensor)this.n.getValue(), (DoubleTensor)this.p.getValue(), this.validationEnabled).sample(shape, random);
    }

    /** @return the {@code p} parameter vertex */
    @SaveVertexParam(value=P_NAME)
    public DoubleVertex getP() {
        return this.p;
    }

    /** @return the {@code n} parameter vertex */
    @SaveVertexParam(value=N_NAME)
    public IntegerVertex getN() {
        return this.n;
    }

    /** @return the flag forwarded to {@link Multinomial#withParameters} */
    public boolean isValidationEnabled() {
        return this.validationEnabled;
    }

    /**
     * Sets the flag forwarded to {@link Multinomial#withParameters} on later calls to
     * {@link #logProb} and {@link #sampleWithShape}.
     *
     * @param validationEnabled the new flag value
     */
    public void setValidationEnabled(boolean validationEnabled) {
        this.validationEnabled = validationEnabled;
    }
}

