package io.improbable.keanu.vertices.intgr.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.discrete.UniformInt;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.intgr.IntegerPlaceholderVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import io.improbable.keanu.vertices.intgr.nonprobabilistic.ConstantIntegerVertex;
import java.util.Map;
import java.util.Set;

/**
 * A probabilistic integer vertex whose value is distributed according to {@link UniformInt},
 * parameterised by a {@code min} parent vertex and a {@code max} parent vertex.
 *
 * <p>On construction the shapes of {@code min} and {@code max} are validated against this
 * vertex's shape via {@link TensorShapeValidation}. Gradients of the log probability are not
 * supported: {@link #dLogProb(IntegerTensor, Set)} always throws
 * {@link UnsupportedOperationException}.
 */
public class UniformIntVertex extends IntegerVertex
    implements ProbabilisticInteger, SamplableWithManyScalars<IntegerTensor>, LogProbGraphSupplier {

    private static final String MIN_NAME = "min";
    private static final String MAX_NAME = "max";

    private final IntegerVertex min;
    private final IntegerVertex max;

    /**
     * Primary constructor; all other constructors delegate here.
     *
     * @param tensorShape the shape of this vertex's value
     * @param min         the parent vertex supplying the {@code min} parameter
     * @param max         the parent vertex supplying the {@code max} parameter
     */
    public UniformIntVertex(@LoadShape long[] tensorShape,
                            @LoadVertexParam(MIN_NAME) IntegerVertex min,
                            @LoadVertexParam(MAX_NAME) IntegerVertex max) {
        super(tensorShape);
        // The parameter shapes form an array of shapes, i.e. a long[][] (one long[] per parent).
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(
            tensorShape, new long[][]{min.getShape(), max.getShape()});
        this.min = min;
        this.max = max;
        setParents(min, max);
    }

    public UniformIntVertex(long[] tensorShape, int min, int max) {
        this(tensorShape, new ConstantIntegerVertex(min), new ConstantIntegerVertex(max));
    }

    public UniformIntVertex(long[] tensorShape, IntegerTensor min, IntegerTensor max) {
        this(tensorShape, new ConstantIntegerVertex(min), new ConstantIntegerVertex(max));
    }

    public UniformIntVertex(long[] tensorShape, IntegerVertex min, int max) {
        this(tensorShape, min, new ConstantIntegerVertex(max));
    }

    public UniformIntVertex(long[] tensorShape, int min, IntegerVertex max) {
        this(tensorShape, new ConstantIntegerVertex(min), max);
    }

    /**
     * Creates a vertex whose shape is derived from the parameter shapes via
     * {@link TensorShapeValidation#checkHasOneNonLengthOneShapeOrAllLengthOne}.
     *
     * @param min the parent vertex supplying the {@code min} parameter
     * @param max the parent vertex supplying the {@code max} parameter
     */
    @ExportVertexToPythonBindings
    public UniformIntVertex(IntegerVertex min, IntegerVertex max) {
        this(TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne(
            new long[][]{min.getShape(), max.getShape()}), min, max);
    }

    /** Creates a vertex shaped like {@code min}, with a constant {@code max}. */
    public UniformIntVertex(IntegerVertex min, int max) {
        this(min.getShape(), min, new ConstantIntegerVertex(max));
    }

    /** Creates a vertex shaped like {@code max}, with a constant {@code min}. */
    public UniformIntVertex(int min, IntegerVertex max) {
        this(max.getShape(), new ConstantIntegerVertex(min), max);
    }

    /** Creates a scalar-shaped vertex with constant {@code min} and {@code max}. */
    public UniformIntVertex(int min, int max) {
        this(Tensor.SCALAR_SHAPE, new ConstantIntegerVertex(min), new ConstantIntegerVertex(max));
    }

    /** @return the parent vertex supplying the {@code min} parameter */
    @SaveVertexParam(MIN_NAME)
    public IntegerVertex getMin() {
        return min;
    }

    /** @return the parent vertex supplying the {@code max} parameter */
    @SaveVertexParam(MAX_NAME)
    public IntegerVertex getMax() {
        return max;
    }

    /**
     * Computes the log probability of {@code value} under the distribution built from the
     * current values of the {@code min} and {@code max} parents, summed over all elements.
     */
    @Override
    public double logProb(IntegerTensor value) {
        return UniformInt.withParameters(min.getValue(), max.getValue()).logProb(value).sum();
    }

    /**
     * Builds a log-probability graph in which this vertex and its two parents are replaced by
     * placeholders of matching shape.
     */
    @Override
    public LogProbGraph logProbGraph() {
        IntegerPlaceholderVertex valuePlaceholder = new IntegerPlaceholderVertex(getShape());
        IntegerPlaceholderVertex minPlaceholder = new IntegerPlaceholderVertex(min.getShape());
        IntegerPlaceholderVertex maxPlaceholder = new IntegerPlaceholderVertex(max.getShape());

        return LogProbGraph.builder()
            .input(this, valuePlaceholder)
            .input(min, minPlaceholder)
            .input(max, maxPlaceholder)
            .logProbOutput(UniformInt.logProbOutput(valuePlaceholder, minPlaceholder, maxPlaceholder))
            .build();
    }

    /**
     * Gradients are not supported for this discrete vertex.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(IntegerTensor atValue, Set<? extends Vertex> withRespectTo) {
        throw new UnsupportedOperationException();
    }

    /**
     * Draws a sample of the given shape from the distribution built from the current values of
     * the {@code min} and {@code max} parents.
     */
    @Override
    public IntegerTensor sampleWithShape(long[] shape, KeanuRandom random) {
        return UniformInt.withParameters(min.getValue(), max.getValue()).sample(shape, random);
    }
}
