package io.improbable.keanu.vertices.dbl.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.continuous.Gaussian;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.SumVertex;
import java.util.Map;
import java.util.Set;

/**
 * A half-Gaussian vertex: a Gaussian whose mean is fixed at zero and whose support is
 * restricted to non-negative values.
 *
 * <p>For {@code x >= 0} the density is twice the zero-mean Gaussian density. This is why the
 * log probability adds {@code log(2)} per element. When any element is negative, the log
 * probability is {@code -infinity}. Samples are the absolute values of zero-mean Gaussian
 * samples.
 */
public class HalfGaussianVertex extends GaussianVertex {

    /** Mean of the underlying Gaussian; also the lower bound of this vertex's support. */
    private static final double MU_ZERO = 0.0d;

    /** log(2): per-element correction for folding the Gaussian density at zero. */
    private static final double LOG_TWO = Math.log(2.0d);

    /**
     * Creates a half-Gaussian vertex with an explicit shape.
     *
     * @param tensorShape the shape of this vertex's value
     * @param sigma       the sigma (scale) parameter of the underlying zero-mean Gaussian
     */
    public HalfGaussianVertex(@LoadShape long[] tensorShape, @LoadVertexParam("sigma") DoubleVertex sigma) {
        super(tensorShape, MU_ZERO, sigma);
    }

    /**
     * Creates a half-Gaussian vertex whose shape is derived by the superclass from {@code sigma}.
     *
     * @param sigma the sigma (scale) parameter of the underlying zero-mean Gaussian
     */
    @ExportVertexToPythonBindings
    public HalfGaussianVertex(DoubleVertex sigma) {
        super(MU_ZERO, sigma);
    }

    /**
     * Creates a half-Gaussian vertex with a constant sigma.
     *
     * @param sigma the sigma (scale) parameter of the underlying zero-mean Gaussian
     */
    public HalfGaussianVertex(double sigma) {
        super(MU_ZERO, new ConstantDoubleVertex(sigma));
    }

    /**
     * Creates a half-Gaussian vertex with an explicit shape and a constant sigma.
     *
     * @param tensorShape the shape of this vertex's value
     * @param sigma       the sigma (scale) parameter of the underlying zero-mean Gaussian
     */
    public HalfGaussianVertex(long[] tensorShape, double sigma) {
        super(tensorShape, MU_ZERO, new ConstantDoubleVertex(sigma));
    }

    /**
     * Computes the log probability of {@code value}.
     *
     * @param value the value to evaluate
     * @return the zero-mean Gaussian log probability plus {@code log(2)} per element, or
     *         {@code -infinity} if any element is negative (or otherwise not {@code >= 0})
     */
    @Override
    public double logProb(DoubleTensor value) {
        if (!isWithinSupport(value)) {
            return Double.NEGATIVE_INFINITY;
        }
        return super.logProb(value) + LOG_TWO * value.getLength();
    }

    /**
     * Builds a graph computing this vertex's log probability from placeholder inputs.
     *
     * @return a graph whose inputs are this vertex, its mu and its sigma, and whose output is
     *         the summed element-wise half-Gaussian log density
     */
    @Override
    public LogProbGraph logProbGraph() {
        final DoublePlaceholderVertex xPlaceholder = new DoublePlaceholderVertex(getShape());
        final DoublePlaceholderVertex muPlaceholder = new DoublePlaceholderVertex(getMu().getShape());
        final DoublePlaceholderVertex sigmaPlaceholder = new DoublePlaceholderVertex(getSigma().getShape());

        // Gaussian log density shifted by log(2) per element; elements below zero are forced to
        // -infinity so they fall outside the support, then everything is summed.
        final SumVertex logProbOutput = Gaussian.logProbOutput(xPlaceholder, muPlaceholder, sigmaPlaceholder)
            .plus(LOG_TWO)
            .setWithMask(xPlaceholder.toLessThanMask(MU_ZERO), Double.NEGATIVE_INFINITY)
            .sum();

        // The mean of this vertex is always zero, so its placeholder is given that value up front.
        muPlaceholder.setValue(MU_ZERO);

        return LogProbGraph.builder()
            .input(this, xPlaceholder)
            .input(getMu(), muPlaceholder)
            .input(getSigma(), sigmaPlaceholder)
            .logProbOutput(logProbOutput)
            .build();
    }

    /**
     * Computes the gradients of the log probability with respect to the requested vertices.
     *
     * <p>The gradients are those of the underlying Gaussian. Where {@code value} has elements
     * below zero, the corresponding gradient entries are set to zero.
     *
     * @param value         the value at which to evaluate the gradients
     * @param withRespectTo the vertices to differentiate with respect to
     * @return the gradient map produced by the superclass, masked in place when needed
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(DoubleTensor value, Set<? extends Vertex> withRespectTo) {
        final Map<Vertex, DoubleTensor> gradients = super.dLogProb(value, withRespectTo);
        if (isWithinSupport(value)) {
            return gradients;
        }

        // Build the mask once rather than once per gradient entry.
        // NOTE(review): assumes each gradient tensor can be masked with a mask shaped like
        // `value` - confirm for parameters whose shape differs from this vertex's shape.
        final DoubleTensor outsideSupportMask = value.lessThanMask(DoubleTensor.scalar(MU_ZERO));
        gradients.replaceAll((vertex, gradient) -> gradient.setWithMaskInPlace(outsideSupportMask, MU_ZERO));
        return gradients;
    }

    /**
     * Draws a sample of the given shape.
     *
     * @param shape  the shape of the sample
     * @param random the random source
     * @return the absolute value of a zero-mean Gaussian sample
     */
    @Override
    public DoubleTensor sampleWithShape(long[] shape, KeanuRandom random) {
        // Folding a zero-mean Gaussian sample at zero yields a half-Gaussian sample.
        return (DoubleTensor) super.sampleWithShape(shape, random).absInPlace();
    }

    /** Returns true when every element of {@code value} is {@code >= 0}. */
    private static boolean isWithinSupport(DoubleTensor value) {
        return value.greaterThanOrEqual(MU_ZERO).allTrue();
    }
}
