package io.improbable.keanu.vertices.dbl.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.continuous.Triangular;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import java.util.Map;
import java.util.Set;

/**
 * A probabilistic {@link DoubleVertex} whose value is drawn from a {@link Triangular} distribution
 * parameterized by the vertices {@code xMin}, {@code xMax} and {@code c}.
 *
 * <p>On construction, the parameter shapes are validated against this vertex's shape; each parameter
 * must either be length-one or match the non-length-one shape.
 *
 * <p>Log-probability gradients are not implemented: {@link #dLogProb(DoubleTensor, Set)} always
 * throws {@link UnsupportedOperationException}.
 */
public class TriangularVertex extends DoubleVertex implements Differentiable, ProbabilisticDouble, SamplableWithManyScalars<DoubleTensor>, LogProbGraphSupplier {

    // Keys used when saving/loading this vertex's parameters (see @SaveVertexParam / @LoadVertexParam).
    private static final String X_MIN_NAME = "xMin";
    private static final String X_MAX_NAME = "xMax";
    private static final String C_NAME = "c";

    private final DoubleVertex xMin;
    private final DoubleVertex xMax;
    private final DoubleVertex c;

    /**
     * Creates a triangular vertex with an explicit shape.
     *
     * @param tensorShape the shape of this vertex's value
     * @param xMin        the {@code xMin} parameter vertex
     * @param xMax        the {@code xMax} parameter vertex
     * @param c           the {@code c} parameter vertex
     */
    public TriangularVertex(@LoadShape long[] tensorShape,
                            @LoadVertexParam(X_MIN_NAME) DoubleVertex xMin,
                            @LoadVertexParam(X_MAX_NAME) DoubleVertex xMax,
                            @LoadVertexParam(C_NAME) DoubleVertex c) {
        super(tensorShape);
        // Each parameter shape is a long[], so the collection of shapes must be a long[][].
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(
            tensorShape, new long[][]{xMin.getShape(), xMax.getShape(), c.getShape()});
        this.xMin = xMin;
        this.xMax = xMax;
        this.c = c;
        setParents(xMin, xMax, c);
    }

    // Convenience overloads with an explicit shape: any double argument is wrapped in a ConstantDoubleVertex.

    public TriangularVertex(long[] tensorShape, DoubleVertex xMin, DoubleVertex xMax, double c) {
        this(tensorShape, xMin, xMax, new ConstantDoubleVertex(c));
    }

    public TriangularVertex(long[] tensorShape, DoubleVertex xMin, double xMax, DoubleVertex c) {
        this(tensorShape, xMin, new ConstantDoubleVertex(xMax), c);
    }

    public TriangularVertex(long[] tensorShape, DoubleVertex xMin, double xMax, double c) {
        this(tensorShape, xMin, new ConstantDoubleVertex(xMax), new ConstantDoubleVertex(c));
    }

    public TriangularVertex(long[] tensorShape, double xMin, DoubleVertex xMax, DoubleVertex c) {
        this(tensorShape, new ConstantDoubleVertex(xMin), xMax, c);
    }

    public TriangularVertex(long[] tensorShape, double xMin, double xMax, DoubleVertex c) {
        this(tensorShape, new ConstantDoubleVertex(xMin), new ConstantDoubleVertex(xMax), c);
    }

    public TriangularVertex(long[] tensorShape, double xMin, double xMax, double c) {
        this(tensorShape, new ConstantDoubleVertex(xMin), new ConstantDoubleVertex(xMax), new ConstantDoubleVertex(c));
    }

    /**
     * Creates a triangular vertex whose shape is inferred from its parameters.
     *
     * <p>The shape comes from
     * {@link TensorShapeValidation#checkHasOneNonLengthOneShapeOrAllLengthOne}, applied to the three
     * parameter shapes.
     *
     * @param xMin the {@code xMin} parameter vertex
     * @param xMax the {@code xMax} parameter vertex
     * @param c    the {@code c} parameter vertex
     */
    @ExportVertexToPythonBindings
    public TriangularVertex(DoubleVertex xMin, DoubleVertex xMax, DoubleVertex c) {
        this(TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne(
            new long[][]{xMin.getShape(), xMax.getShape(), c.getShape()}), xMin, xMax, c);
    }

    // Convenience overloads with an inferred shape: any double argument is wrapped in a ConstantDoubleVertex.

    public TriangularVertex(DoubleVertex xMin, DoubleVertex xMax, double c) {
        this(xMin, xMax, new ConstantDoubleVertex(c));
    }

    public TriangularVertex(DoubleVertex xMin, double xMax, DoubleVertex c) {
        this(xMin, new ConstantDoubleVertex(xMax), c);
    }

    public TriangularVertex(DoubleVertex xMin, double xMax, double c) {
        this(xMin, new ConstantDoubleVertex(xMax), new ConstantDoubleVertex(c));
    }

    public TriangularVertex(double xMin, DoubleVertex xMax, DoubleVertex c) {
        this(new ConstantDoubleVertex(xMin), xMax, c);
    }

    public TriangularVertex(double xMin, double xMax, DoubleVertex c) {
        this(new ConstantDoubleVertex(xMin), new ConstantDoubleVertex(xMax), c);
    }

    public TriangularVertex(double xMin, double xMax, double c) {
        this(new ConstantDoubleVertex(xMin), new ConstantDoubleVertex(xMax), new ConstantDoubleVertex(c));
    }

    /** @return the {@code xMin} parameter vertex */
    @SaveVertexParam(X_MIN_NAME)
    public DoubleVertex getXMin() {
        return xMin;
    }

    /** @return the {@code xMax} parameter vertex */
    @SaveVertexParam(X_MAX_NAME)
    public DoubleVertex getXMax() {
        return xMax;
    }

    /** @return the {@code c} parameter vertex */
    @SaveVertexParam(C_NAME)
    public DoubleVertex getC() {
        return c;
    }

    /**
     * Computes the total log probability of {@code value}.
     *
     * <p>The log probability is evaluated element-wise with the parameters' current values and then
     * summed.
     *
     * @param value the tensor to evaluate
     * @return the sum of the element-wise log probabilities
     */
    @Override
    public double logProb(DoubleTensor value) {
        return Triangular.withParameters(xMin.getValue(), xMax.getValue(), c.getValue())
            .logProb(value)
            .sum();
    }

    /**
     * Builds a symbolic log-probability graph.
     *
     * <p>The graph uses one placeholder for this vertex and one for each parameter. Each placeholder
     * has the shape of the vertex it stands in for.
     */
    @Override
    public LogProbGraph logProbGraph() {
        DoublePlaceholderVertex valuePlaceholder = new DoublePlaceholderVertex(getShape());
        DoublePlaceholderVertex xMinPlaceholder = new DoublePlaceholderVertex(xMin.getShape());
        DoublePlaceholderVertex xMaxPlaceholder = new DoublePlaceholderVertex(xMax.getShape());
        DoublePlaceholderVertex cPlaceholder = new DoublePlaceholderVertex(c.getShape());

        return LogProbGraph.builder()
            .input(this, valuePlaceholder)
            .input(xMin, xMinPlaceholder)
            .input(xMax, xMaxPlaceholder)
            .input(c, cPlaceholder)
            .logProbOutput(Triangular.logProbOutput(valuePlaceholder, xMinPlaceholder, xMaxPlaceholder, cPlaceholder))
            .build();
    }

    /**
     * Gradients of the log probability are not implemented for this vertex.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(DoubleTensor value, Set<? extends Vertex> withRespectTo) {
        throw new UnsupportedOperationException();
    }

    /**
     * Alias created by decompilation, which renamed the typed {@code dLogProb} override.
     *
     * @deprecated call {@link #dLogProb(DoubleTensor, Set)} instead; this simply delegates to it.
     */
    @Deprecated
    public Map<Vertex, DoubleTensor> dLogProb2(DoubleTensor value, Set<? extends Vertex> withRespectTo) {
        return dLogProb(value, withRespectTo);
    }

    /**
     * Samples a tensor of the given shape.
     *
     * <p>The sample is drawn from the triangular distribution with the parameters' current values.
     *
     * @param shape  shape of the returned sample
     * @param random source of randomness
     * @return the sampled tensor
     */
    @Override
    public DoubleTensor sampleWithShape(long[] shape, KeanuRandom random) {
        return Triangular.withParameters(xMin.getValue(), xMax.getValue(), c.getValue()).sample(shape, random);
    }
}
