package io.improbable.keanu.vertices.intgr.probabilistic;

import com.google.common.base.Preconditions;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.discrete.Multinomial;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:io/improbable/keanu/vertices/intgr/probabilistic/MultinomialVertex.class */
/**
 * An integer-valued probabilistic vertex backed by a {@link Multinomial} distribution with parameters
 * {@code n} and {@code p}.
 * <p>
 * The last dimension of {@code p} is the category count {@code k} (which must be at least 2). The remaining
 * dimensions of {@code p} must be broadcastable with the shape of {@code n}. The shape of this vertex is that
 * broadcast shape with {@code k} appended.
 */
public class MultinomialVertex extends IntegerVertex implements ProbabilisticInteger, SamplableWithManyScalars<IntegerTensor> {

    private static final String P_NAME = "p";
    private static final String N_NAME = "n";

    private final DoubleVertex p;
    private final IntegerVertex n;

    // Forwarded to Multinomial.withParameters in logProb/sampleWithShape; presumably toggles
    // parameter validation inside the distribution (TODO confirm against Multinomial).
    private boolean validationEnabled;

    /**
     * Creates a Multinomial vertex with an explicitly supplied shape.
     *
     * @param jArr          the shape of this vertex; must equal the shape implied by {@code n} and {@code p}
     * @param integerVertex the {@code n} parameter
     * @param doubleVertex  the {@code p} parameter; its last dimension is the category count {@code k}
     * @throws IllegalArgumentException if {@code p} is a scalar, if {@code k < 2}, if {@code n} and {@code p}
     *                                  (excluding {@code k}) are not broadcastable, or if {@code jArr}
     *                                  differs from the implied shape
     */
    public MultinomialVertex(@LoadShape long[] jArr,
                             @LoadVertexParam(N_NAME) IntegerVertex integerVertex,
                             @LoadVertexParam(P_NAME) DoubleVertex doubleVertex) {
        super(jArr);
        long[] expectedShape = calculateExpectedShape(integerVertex.getShape(), doubleVertex.getShape());
        if (!Arrays.equals(expectedShape, jArr)) {
            throw new IllegalArgumentException("Shape " + Arrays.toString(jArr)
                + " does not match the shape " + Arrays.toString(expectedShape) + " implied by n and p");
        }
        this.p = doubleVertex;
        this.n = integerVertex;
        this.validationEnabled = true;
        setParents(doubleVertex, integerVertex);
    }

    /**
     * Computes the vertex shape implied by the shapes of {@code n} and {@code p}.
     * <p>
     * That shape is the broadcast of {@code nShape} with {@code pShape} minus its last (k) dimension,
     * followed by {@code k}.
     *
     * @param nShape shape of {@code n}
     * @param pShape shape of {@code p}; must have at least one dimension
     * @return the expected vertex shape
     * @throws IllegalArgumentException if {@code p} is a scalar, if {@code k < 2}, or if the shapes are
     *                                  not broadcastable
     */
    private static long[] calculateExpectedShape(long[] nShape, long[] pShape) {
        int pRank = pShape.length;
        // A scalar p has no k dimension; without this guard pShape[pRank - 1] would throw
        // ArrayIndexOutOfBoundsException instead of a descriptive argument error.
        if (pRank == 0) {
            throw new IllegalArgumentException("p must have at least one dimension (the k dimension) but is a scalar");
        }
        long k = pShape[pRank - 1];
        Preconditions.checkArgument(k >= 2, "K value of " + k + " must be greater than 1");

        long[] pShapeWithoutK = TensorShape.selectDimensions(0, pRank - 1, pShape);
        if (!TensorShapeValidation.isBroadcastable(nShape, pShapeWithoutK)) {
            throw new IllegalArgumentException("The shape of n " + Arrays.toString(nShape)
                + " must be broadcastable with the shape of p excluding the k dimension "
                + Arrays.toString(pShapeWithoutK));
        }
        return TensorShape.concat(TensorShapeValidation.checkIsBroadcastable(nShape, pShapeWithoutK), new long[]{k});
    }

    /**
     * Creates a Multinomial vertex whose shape is derived from the shapes of {@code n} and {@code p}.
     *
     * @param integerVertex the {@code n} parameter
     * @param doubleVertex  the {@code p} parameter; its last dimension is the category count {@code k}
     */
    @ExportVertexToPythonBindings
    public MultinomialVertex(IntegerVertex integerVertex, DoubleVertex doubleVertex) {
        this(calculateExpectedShape(integerVertex.getShape(), doubleVertex.getShape()), integerVertex, doubleVertex);
    }

    /**
     * Creates a Multinomial vertex with a constant scalar {@code n}; the vertex takes the shape of {@code p}.
     *
     * @param i            constant scalar {@code n}
     * @param doubleVertex the {@code p} parameter
     */
    public MultinomialVertex(int i, DoubleVertex doubleVertex) {
        this(doubleVertex.getShape(), ConstantVertex.of(IntegerTensor.scalar(i)), doubleVertex);
    }

    /**
     * Creates a Multinomial vertex with a constant scalar {@code n} and constant {@code p}.
     *
     * @param i            constant scalar {@code n}
     * @param doubleTensor constant {@code p}
     */
    public MultinomialVertex(int i, DoubleTensor doubleTensor) {
        this(doubleTensor.getShape(), ConstantVertex.of(IntegerTensor.scalar(i)), ConstantVertex.of(doubleTensor));
    }

    /**
     * Creates a Multinomial vertex with a constant {@code n} tensor.
     *
     * @param integerTensor constant {@code n}
     * @param doubleVertex  the {@code p} parameter
     */
    public MultinomialVertex(IntegerTensor integerTensor, DoubleVertex doubleVertex) {
        this(ConstantVertex.of(integerTensor), doubleVertex);
    }

    /**
     * Creates a Multinomial vertex with constant {@code n} and {@code p} tensors.
     *
     * @param integerTensor constant {@code n}
     * @param doubleTensor  constant {@code p}
     */
    public MultinomialVertex(IntegerTensor integerTensor, DoubleTensor doubleTensor) {
        this(ConstantVertex.of(integerTensor), ConstantVertex.of(doubleTensor));
    }

    /**
     * Returns the log probability of {@code integerTensor}, summed over all of its elements.
     *
     * @param integerTensor the value to evaluate
     * @return the summed log probability under the current values of {@code n} and {@code p}
     */
    @Override
    public double logProb(IntegerTensor integerTensor) {
        return ((Double) Multinomial.withParameters(this.n.getValue(), this.p.getValue(), this.validationEnabled)
            .logProb(integerTensor)
            .sum()).doubleValue();
    }

    /**
     * Gradients of the log probability are not supported for this vertex.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(IntegerTensor integerTensor, Set<? extends Vertex> set) {
        throw new UnsupportedOperationException();
    }

    /**
     * Alias for {@link #dLogProb(IntegerTensor, Set)}. It is kept only for compatibility with code
     * compiled against the decompiled name.
     *
     * @throws UnsupportedOperationException always
     * @deprecated use {@link #dLogProb(IntegerTensor, Set)} instead
     */
    @Deprecated
    public Map<Vertex, DoubleTensor> dLogProb2(IntegerTensor integerTensor, Set<? extends Vertex> set) {
        return dLogProb(integerTensor, set);
    }

    /**
     * Draws a sample of the given shape from the Multinomial distribution defined by the current values
     * of {@code n} and {@code p}.
     *
     * @param jArr        the requested sample shape
     * @param keanuRandom the random source
     * @return the sampled tensor
     */
    @Override
    public IntegerTensor sampleWithShape(long[] jArr, KeanuRandom keanuRandom) {
        return Multinomial.withParameters(this.n.getValue(), this.p.getValue(), this.validationEnabled).sample(jArr, keanuRandom);
    }

    /** @return the {@code p} parameter vertex */
    @SaveVertexParam(P_NAME)
    public DoubleVertex getP() {
        return this.p;
    }

    /** @return the {@code n} parameter vertex */
    @SaveVertexParam(N_NAME)
    public IntegerVertex getN() {
        return this.n;
    }

    /** @return whether validation is passed as enabled to the underlying distribution */
    public boolean isValidationEnabled() {
        return this.validationEnabled;
    }

    /**
     * Sets the validation flag that is forwarded to {@link Multinomial#withParameters}.
     *
     * @param z the new flag value
     */
    public void setValidationEnabled(boolean z) {
        this.validationEnabled = z;
    }
}
