/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.continuous.Gaussian;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.AdditionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.SumVertex;
import io.improbable.keanu.vertices.dbl.probabilistic.GaussianVertex;
import java.util.Map;
import java.util.Set;

/**
 * A {@link GaussianVertex} whose mean is fixed at zero and whose support is
 * restricted to non-negative values.
 *
 * <p>Samples are the absolute value of samples drawn from the parent Gaussian.
 * The log-density is the parent's log-density plus {@code log(2)} per element,
 * and is {@link Double#NEGATIVE_INFINITY} when any element is negative.
 */
public class HalfGaussianVertex
extends GaussianVertex {
    /** Mean handed to the parent Gaussian by every constructor. */
    private static final double MU_ZERO = 0.0;
    /** Per-element correction added to the parent Gaussian log-density. */
    private static final double LOG_TWO = Math.log(2.0);

    /**
     * Creates a vertex with an explicit shape and a vertex-valued sigma.
     *
     * @param tensorShape shape forwarded to the parent constructor
     * @param sigma       sigma parameter forwarded to the parent Gaussian
     */
    public HalfGaussianVertex(@LoadShape long[] tensorShape, @LoadVertexParam(value="sigma") DoubleVertex sigma) {
        super(tensorShape, MU_ZERO, sigma);
    }

    /**
     * Creates a vertex with a vertex-valued sigma, leaving the shape to the parent constructor.
     *
     * @param sigma sigma parameter forwarded to the parent Gaussian
     */
    @ExportVertexToPythonBindings
    public HalfGaussianVertex(DoubleVertex sigma) {
        super(MU_ZERO, sigma);
    }

    /**
     * Creates a vertex with a constant sigma, wrapped in a {@link ConstantDoubleVertex}.
     *
     * @param sigma sigma value forwarded to the parent Gaussian
     */
    public HalfGaussianVertex(double sigma) {
        super(MU_ZERO, (DoubleVertex)new ConstantDoubleVertex(sigma));
    }

    /**
     * Creates a vertex with an explicit shape and a constant sigma, wrapped in a
     * {@link ConstantDoubleVertex}.
     *
     * @param tensorShape shape forwarded to the parent constructor
     * @param sigma       sigma value forwarded to the parent Gaussian
     */
    public HalfGaussianVertex(long[] tensorShape, double sigma) {
        super(tensorShape, MU_ZERO, (DoubleVertex)new ConstantDoubleVertex(sigma));
    }

    /**
     * Computes the log-density of {@code value}.
     *
     * @param value the point(s) at which to evaluate the density
     * @return the parent Gaussian log-probability plus {@code log(2)} times the
     *         number of elements in {@code value} when every element is
     *         {@code >= 0}; otherwise {@link Double#NEGATIVE_INFINITY}
     */
    @Override
    public double logProb(DoubleTensor value) {
        if (value.greaterThanOrEqual(0.0).allTrue()) {
            return super.logProb(value) + LOG_TWO * (double)value.getLength();
        }
        // At least one element lies outside the support.
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Builds a graph representation of the log-density.
     *
     * <p>The graph takes the Gaussian log-density of the placeholders, adds
     * {@code log(2)}, replaces entries where the value placeholder is below
     * zero with {@link Double#NEGATIVE_INFINITY}, and sums the result.
     *
     * @return a {@link LogProbGraph} whose inputs are this vertex, its mu and its sigma
     */
    @Override
    public LogProbGraph logProbGraph() {
        DoublePlaceholderVertex xPlaceholder = new DoublePlaceholderVertex(this.getShape());
        DoublePlaceholderVertex muPlaceholder = new DoublePlaceholderVertex(this.getMu().getShape());
        DoublePlaceholderVertex sigmaPlaceholder = new DoublePlaceholderVertex(this.getSigma().getShape());

        DoubleVertex gaussianLogProbOutput = Gaussian.logProbOutput(xPlaceholder, muPlaceholder, sigmaPlaceholder);
        AdditionVertex result = gaussianLogProbOutput.plus(LOG_TWO);

        // Entries where x < 0 are outside the support and get -infinity.
        DoubleVertex invalidMask = xPlaceholder.toLessThanMask(0.0);
        SumVertex halfGaussianLogProbOutput = result.setWithMask(invalidMask, Double.NEGATIVE_INFINITY).sum();

        // The mean placeholder is pinned to the fixed mean of this distribution.
        muPlaceholder.setValue(MU_ZERO);

        return LogProbGraph.builder()
            .input(this, xPlaceholder)
            .input(this.getMu(), muPlaceholder)
            .input(this.getSigma(), sigmaPlaceholder)
            .logProbOutput(halfGaussianLogProbOutput)
            .build();
    }

    /**
     * Computes the parent Gaussian's log-density gradients, zeroing the
     * entries that correspond to negative elements of {@code value}.
     *
     * @param value         the point(s) at which to evaluate the gradients
     * @param withRespectTo the vertices for which gradients are requested
     * @return the map produced by the parent implementation; when
     *         {@code value} contains negative elements, each tensor in the map
     *         has been modified in place so those positions are {@code 0.0}
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(DoubleTensor value, Set<? extends Vertex> withRespectTo) {
        Map<Vertex, DoubleTensor> dLogProb = super.dLogProb(value, withRespectTo);
        if (value.greaterThanOrEqual(0.0).allTrue()) {
            return dLogProb;
        }

        // The mask depends only on value, so build it once rather than per entry.
        DoubleTensor negativeMask = value.lessThanMask(DoubleTensor.scalar(0.0));
        for (Map.Entry<Vertex, DoubleTensor> entry : dLogProb.entrySet()) {
            // Update through the entry instead of calling put() on the map being iterated.
            entry.setValue(entry.getValue().setWithMaskInPlace(negativeMask, 0.0));
        }
        return dLogProb;
    }

    /**
     * Draws a sample by taking the absolute value of a parent Gaussian sample.
     *
     * @param shape  the shape of the sample to draw
     * @param random the random source passed to the parent sampler
     * @return the parent sample with the absolute value applied in place
     */
    @Override
    public DoubleTensor sampleWithShape(long[] shape, KeanuRandom random) {
        return (DoubleTensor)super.sampleWithShape(shape, random).absInPlace();
    }
}

