/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.intgr.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.discrete.UniformInt;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.intgr.IntegerPlaceholderVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import io.improbable.keanu.vertices.intgr.nonprobabilistic.ConstantIntegerVertex;
import io.improbable.keanu.vertices.intgr.probabilistic.ProbabilisticInteger;
import java.util.Map;
import java.util.Set;

/**
 * An integer vertex whose value is distributed according to {@link UniformInt}, parameterised by
 * the current values of its {@code min} and {@code max} parent vertices.
 *
 * <p>This class only forwards the parent values to {@link UniformInt}; whether {@code max} is an
 * inclusive or exclusive bound is decided there (NOTE(review): confirm against UniformInt).
 */
public class UniformIntVertex extends IntegerVertex
    implements ProbabilisticInteger, SamplableWithManyScalars<IntegerTensor>, LogProbGraphSupplier {

    // Parameter names used both when loading (constructor) and saving (getters) this vertex,
    // so the two sides cannot drift apart.
    private static final String MIN_NAME = "min";
    private static final String MAX_NAME = "max";

    private final IntegerVertex min;
    private final IntegerVertex max;

    /**
     * Creates a uniform integer vertex with an explicit shape.
     *
     * <p>The shapes of {@code min} and {@code max} are validated against {@code shape} with
     * {@link TensorShapeValidation#checkTensorsMatchNonLengthOneShapeOrAreLengthOne}.
     *
     * @param shape the shape of this vertex's value
     * @param min   the vertex supplying the lower bound; registered as a parent
     * @param max   the vertex supplying the upper bound; registered as a parent
     */
    public UniformIntVertex(@LoadShape long[] shape,
                            @LoadVertexParam(value = MIN_NAME) IntegerVertex min,
                            @LoadVertexParam(value = MAX_NAME) IntegerVertex max) {
        super(shape);
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(shape, min.getShape(), max.getShape());
        this.min = min;
        this.max = max;
        this.setParents(min, max);
    }

    /**
     * Creates a vertex of the given shape with constant scalar bounds.
     *
     * @param shape the shape of this vertex's value
     * @param min   constant lower bound
     * @param max   constant upper bound
     */
    public UniformIntVertex(long[] shape, int min, int max) {
        this(shape, new ConstantIntegerVertex(min), new ConstantIntegerVertex(max));
    }

    /**
     * Creates a vertex of the given shape with constant tensor bounds.
     *
     * @param shape the shape of this vertex's value
     * @param min   constant lower bound tensor
     * @param max   constant upper bound tensor
     */
    public UniformIntVertex(long[] shape, IntegerTensor min, IntegerTensor max) {
        this(shape, new ConstantIntegerVertex(min), new ConstantIntegerVertex(max));
    }

    /**
     * Creates a vertex of the given shape with a vertex lower bound and a constant upper bound.
     *
     * @param shape the shape of this vertex's value
     * @param min   the vertex supplying the lower bound
     * @param max   constant upper bound
     */
    public UniformIntVertex(long[] shape, IntegerVertex min, int max) {
        this(shape, min, new ConstantIntegerVertex(max));
    }

    /**
     * Creates a vertex of the given shape with a constant lower bound and a vertex upper bound.
     *
     * @param shape the shape of this vertex's value
     * @param min   constant lower bound
     * @param max   the vertex supplying the upper bound
     */
    public UniformIntVertex(long[] shape, int min, IntegerVertex max) {
        this(shape, new ConstantIntegerVertex(min), max);
    }

    /**
     * Creates a vertex whose shape is derived from the bound vertices' shapes via
     * {@link TensorShapeValidation#checkHasOneNonLengthOneShapeOrAllLengthOne}.
     *
     * @param min the vertex supplying the lower bound
     * @param max the vertex supplying the upper bound
     */
    @ExportVertexToPythonBindings
    public UniformIntVertex(IntegerVertex min, IntegerVertex max) {
        this(TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne(min.getShape(), max.getShape()), min, max);
    }

    /**
     * Creates a vertex shaped like {@code min}, with a constant upper bound.
     *
     * @param min the vertex supplying the lower bound (also determines this vertex's shape)
     * @param max constant upper bound
     */
    public UniformIntVertex(IntegerVertex min, int max) {
        this(min.getShape(), min, new ConstantIntegerVertex(max));
    }

    /**
     * Creates a vertex shaped like {@code max}, with a constant lower bound.
     *
     * @param min constant lower bound
     * @param max the vertex supplying the upper bound (also determines this vertex's shape)
     */
    public UniformIntVertex(int min, IntegerVertex max) {
        this(max.getShape(), new ConstantIntegerVertex(min), max);
    }

    /**
     * Creates a scalar-shaped vertex with constant bounds.
     *
     * @param min constant lower bound
     * @param max constant upper bound
     */
    public UniformIntVertex(int min, int max) {
        this(Tensor.SCALAR_SHAPE, new ConstantIntegerVertex(min), new ConstantIntegerVertex(max));
    }

    /** @return the parent vertex supplying the lower bound */
    @SaveVertexParam(value = MIN_NAME)
    public IntegerVertex getMin() {
        return this.min;
    }

    /** @return the parent vertex supplying the upper bound */
    @SaveVertexParam(value = MAX_NAME)
    public IntegerVertex getMax() {
        return this.max;
    }

    /**
     * Computes the log probability of {@code value} under the current parent values, summed
     * over all elements of the per-element log-probability tensor.
     *
     * @param value the value to evaluate
     * @return the summed log probability
     */
    @Override
    public double logProb(IntegerTensor value) {
        return (Double) UniformInt.withParameters((IntegerTensor) this.min.getValue(), (IntegerTensor) this.max.getValue())
            .logProb(value)
            .sum();
    }

    /**
     * Builds a log-probability graph in which this vertex and its two parents are replaced by
     * placeholders of matching shape, feeding {@link UniformInt#logProbOutput}.
     *
     * @return the log-probability graph for this vertex
     */
    @Override
    public LogProbGraph logProbGraph() {
        IntegerPlaceholderVertex valuePlaceholder = new IntegerPlaceholderVertex(this.getShape());
        IntegerPlaceholderVertex minPlaceholder = new IntegerPlaceholderVertex(this.min.getShape());
        IntegerPlaceholderVertex maxPlaceholder = new IntegerPlaceholderVertex(this.max.getShape());
        return LogProbGraph.builder()
            .input(this, valuePlaceholder)
            .input(this.min, minPlaceholder)
            .input(this.max, maxPlaceholder)
            .logProbOutput(UniformInt.logProbOutput(valuePlaceholder, minPlaceholder, maxPlaceholder))
            .build();
    }

    /**
     * Gradients of the log probability are not supported by this vertex.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(IntegerTensor value, Set<? extends Vertex> withRespectTo) {
        throw new UnsupportedOperationException("dLogProb is not supported for UniformIntVertex");
    }

    /**
     * Draws a sample of the requested shape from {@link UniformInt} using the current parent values.
     *
     * @param shape  the shape of the sample to draw
     * @param random the random source to sample with
     * @return the sampled tensor
     */
    @Override
    public IntegerTensor sampleWithShape(long[] shape, KeanuRandom random) {
        return (IntegerTensor) UniformInt.withParameters((IntegerTensor) this.min.getValue(), (IntegerTensor) this.max.getValue())
            .sample(shape, random);
    }
}

