/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.continuous.SmoothUniform;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.LogProbGraph;
import io.improbable.keanu.vertices.LogProbGraphSupplier;
import io.improbable.keanu.vertices.SamplableWithManyScalars;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import io.improbable.keanu.vertices.dbl.probabilistic.ProbabilisticDouble;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

/**
 * A probabilistic {@link DoubleVertex} whose distribution is {@link SmoothUniform}, parameterised by
 * two parent vertices ({@code xMin}, {@code xMax}) and a fixed scalar {@code edgeSharpness}.
 *
 * <p>Constructors that do not take {@code edgeSharpness} use {@code DEFAULT_EDGE_SHARPNESS} (0.01).
 * Scalar {@code double} bounds are wrapped in a {@link ConstantDoubleVertex}.
 */
public class SmoothUniformVertex extends DoubleVertex
        implements Differentiable, ProbabilisticDouble, SamplableWithManyScalars<DoubleTensor>, LogProbGraphSupplier {

    /** Edge sharpness used by every constructor that does not take one explicitly. */
    private static final double DEFAULT_EDGE_SHARPNESS = 0.01;

    /** Parameter names used as keys by the {@code @SaveVertexParam} / {@code @LoadVertexParam} annotations. */
    private static final String X_MIN_NAME = "xMin";
    private static final String X_MAX_NAME = "xMax";

    private final DoubleVertex xMin;
    private final DoubleVertex xMax;
    private final double edgeSharpness;

    /**
     * Constructor annotated for loading ({@code @LoadShape} / {@code @LoadVertexParam}).
     *
     * <p>NOTE(review): {@code edgeSharpness} is not a loaded parameter, and {@link #getEdgeSharpness()}
     * carries no {@code @SaveVertexParam}. A vertex built with a non-default edge sharpness therefore
     * comes back with {@code DEFAULT_EDGE_SHARPNESS} after a save/load round trip.
     *
     * @param tensorShape shape of this vertex's value
     * @param xMin        lower-bound parent vertex
     * @param xMax        upper-bound parent vertex
     */
    public SmoothUniformVertex(@LoadShape long[] tensorShape,
                               @LoadVertexParam(X_MIN_NAME) DoubleVertex xMin,
                               @LoadVertexParam(X_MAX_NAME) DoubleVertex xMax) {
        this(tensorShape, xMin, xMax, DEFAULT_EDGE_SHARPNESS);
    }

    /**
     * Creates a vertex with an explicit shape; all other constructors delegate here.
     *
     * @param tensorShape   shape of this vertex's value
     * @param xMin          lower-bound parent vertex
     * @param xMax          upper-bound parent vertex
     * @param edgeSharpness passed unchanged to {@link SmoothUniform#withParameters}
     */
    public SmoothUniformVertex(long[] tensorShape, DoubleVertex xMin, DoubleVertex xMax, double edgeSharpness) {
        super(tensorShape);
        // Parent shapes are validated against tensorShape (each must match it or be length one,
        // per the validator's contract).
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(
            tensorShape, xMin.getShape(), xMax.getShape());
        this.xMin = xMin;
        this.xMax = xMax;
        this.edgeSharpness = edgeSharpness;
        setParents(xMin, xMax);
    }

    /**
     * Creates a vertex whose shape is inferred from the parents' shapes.
     *
     * @param xMin          lower-bound parent vertex
     * @param xMax          upper-bound parent vertex
     * @param edgeSharpness passed unchanged to {@link SmoothUniform#withParameters}
     */
    public SmoothUniformVertex(DoubleVertex xMin, DoubleVertex xMax, double edgeSharpness) {
        this(TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne(xMin.getShape(), xMax.getShape()),
            xMin, xMax, edgeSharpness);
    }

    /** As {@link #SmoothUniformVertex(DoubleVertex, DoubleVertex, double)} with a constant {@code xMax}. */
    public SmoothUniformVertex(DoubleVertex xMin, double xMax, double edgeSharpness) {
        this(xMin, new ConstantDoubleVertex(xMax), edgeSharpness);
    }

    /** As {@link #SmoothUniformVertex(DoubleVertex, DoubleVertex, double)} with a constant {@code xMin}. */
    public SmoothUniformVertex(double xMin, DoubleVertex xMax, double edgeSharpness) {
        this(new ConstantDoubleVertex(xMin), xMax, edgeSharpness);
    }

    /** As {@link #SmoothUniformVertex(DoubleVertex, DoubleVertex, double)} with constant bounds. */
    public SmoothUniformVertex(double xMin, double xMax, double edgeSharpness) {
        this(new ConstantDoubleVertex(xMin), new ConstantDoubleVertex(xMax), edgeSharpness);
    }

    /**
     * Creates a vertex with an inferred shape and the default edge sharpness.
     *
     * @param xMin lower-bound parent vertex
     * @param xMax upper-bound parent vertex
     */
    @ExportVertexToPythonBindings
    public SmoothUniformVertex(DoubleVertex xMin, DoubleVertex xMax) {
        this(xMin, xMax, DEFAULT_EDGE_SHARPNESS);
    }

    /** As {@link #SmoothUniformVertex(DoubleVertex, DoubleVertex)} with a constant {@code xMax}. */
    public SmoothUniformVertex(DoubleVertex xMin, double xMax) {
        this(xMin, new ConstantDoubleVertex(xMax), DEFAULT_EDGE_SHARPNESS);
    }

    /** As {@link #SmoothUniformVertex(DoubleVertex, DoubleVertex)} with a constant {@code xMin}. */
    public SmoothUniformVertex(double xMin, DoubleVertex xMax) {
        this(new ConstantDoubleVertex(xMin), xMax, DEFAULT_EDGE_SHARPNESS);
    }

    /** As {@link #SmoothUniformVertex(DoubleVertex, DoubleVertex)} with constant bounds. */
    public SmoothUniformVertex(double xMin, double xMax) {
        this(new ConstantDoubleVertex(xMin), new ConstantDoubleVertex(xMax), DEFAULT_EDGE_SHARPNESS);
    }

    /** As {@link #SmoothUniformVertex(long[], DoubleVertex, DoubleVertex, double)} with a constant {@code xMax}. */
    public SmoothUniformVertex(long[] tensorShape, DoubleVertex xMin, double xMax, double edgeSharpness) {
        this(tensorShape, xMin, new ConstantDoubleVertex(xMax), edgeSharpness);
    }

    /** As {@link #SmoothUniformVertex(long[], DoubleVertex, DoubleVertex, double)} with a constant {@code xMin}. */
    public SmoothUniformVertex(long[] tensorShape, double xMin, DoubleVertex xMax, double edgeSharpness) {
        this(tensorShape, new ConstantDoubleVertex(xMin), xMax, edgeSharpness);
    }

    /** As {@link #SmoothUniformVertex(long[], DoubleVertex, DoubleVertex, double)} with constant bounds. */
    public SmoothUniformVertex(long[] tensorShape, double xMin, double xMax, double edgeSharpness) {
        this(tensorShape, new ConstantDoubleVertex(xMin), new ConstantDoubleVertex(xMax), edgeSharpness);
    }

    /** Explicit shape, constant {@code xMax}, default edge sharpness. */
    public SmoothUniformVertex(long[] tensorShape, DoubleVertex xMin, double xMax) {
        this(tensorShape, xMin, new ConstantDoubleVertex(xMax), DEFAULT_EDGE_SHARPNESS);
    }

    /** Explicit shape, constant {@code xMin}, default edge sharpness. */
    public SmoothUniformVertex(long[] tensorShape, double xMin, DoubleVertex xMax) {
        this(tensorShape, new ConstantDoubleVertex(xMin), xMax, DEFAULT_EDGE_SHARPNESS);
    }

    /** Explicit shape, constant bounds, default edge sharpness. */
    public SmoothUniformVertex(long[] tensorShape, double xMin, double xMax) {
        this(tensorShape, new ConstantDoubleVertex(xMin), new ConstantDoubleVertex(xMax), DEFAULT_EDGE_SHARPNESS);
    }

    /** @return the lower-bound parent vertex */
    @SaveVertexParam(X_MIN_NAME)
    public DoubleVertex getXMin() {
        return xMin;
    }

    /** @return the upper-bound parent vertex */
    @SaveVertexParam(X_MAX_NAME)
    public DoubleVertex getXMax() {
        return xMax;
    }

    /** @return the edge sharpness this vertex was constructed with */
    public double getEdgeSharpness() {
        return edgeSharpness;
    }

    /**
     * Log-probability of {@code value}, summed over all elements, under the parents' current values.
     *
     * @param value candidate value for this vertex
     * @return the sum of the element-wise log densities returned by {@link SmoothUniform}
     */
    @Override
    public double logProb(DoubleTensor value) {
        final DoubleTensor logDensity =
            SmoothUniform.withParameters(xMin.getValue(), xMax.getValue(), edgeSharpness).logProb(value);
        return logDensity.sum();
    }

    /**
     * Builds a graph that computes this vertex's log-probability from placeholder inputs standing in
     * for this vertex and its two parents.
     */
    @Override
    public LogProbGraph logProbGraph() {
        final DoublePlaceholderVertex xPlaceholder = new DoublePlaceholderVertex(getShape());
        final DoublePlaceholderVertex xMinPlaceholder = new DoublePlaceholderVertex(xMin.getShape());
        final DoublePlaceholderVertex xMaxPlaceholder = new DoublePlaceholderVertex(xMax.getShape());

        return LogProbGraph.builder()
            .input(this, xPlaceholder)
            .input(xMin, xMinPlaceholder)
            .input(xMax, xMaxPlaceholder)
            .logProbOutput(SmoothUniform.logProbOutput(xPlaceholder, xMinPlaceholder, xMaxPlaceholder, edgeSharpness))
            .build();
    }

    /**
     * Element-wise gradient of the log-probability with respect to this vertex's own value.
     *
     * <p>Only a derivative with respect to {@code this} is produced. If {@code withRespectTo} does not
     * contain this vertex, an empty map is returned; no gradient is produced for {@code xMin}/{@code xMax}.
     *
     * @param value         point at which to evaluate the gradient
     * @param withRespectTo vertices for which a gradient is requested
     * @return a singleton map {@code this -> d(log P)/dx}, or an empty map
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(DoubleTensor value, Set<? extends Vertex> withRespectTo) {
        if (!withRespectTo.contains(this)) {
            return Collections.emptyMap();
        }

        final ContinuousDistribution distribution =
            SmoothUniform.withParameters(xMin.getValue(), xMax.getValue(), edgeSharpness);

        // The X entry is used as dP/dx, the derivative of the density itself (hence the division below).
        // NOTE(review): confirm against SmoothUniform.dLogProb.
        final DoubleTensor dPdx = distribution.dLogProb(value).get(Diffs.X).getValue();

        // d(log P)/dx = (dP/dx) / P. distribution.logProb yields log P (logProb() above sums it as a
        // log-probability), so exponentiate it back to the density P before dividing. Dividing by log P
        // instead is not the log-derivative, and it flips the gradient's sign wherever P < 1.
        // NOTE(review): wherever P == 0 this divides by zero (NaN/Infinity). Confirm that callers never
        // evaluate outside the support, or mask those elements.
        final DoubleTensor density = distribution.logProb(value).exp();

        return Collections.singletonMap(this, dPdx.divInPlace(density));
    }

    /**
     * Draws a sample of the given shape using the parents' current values.
     *
     * @param shape  shape of the returned tensor
     * @param random source of randomness
     * @return the sampled tensor
     */
    @Override
    public DoubleTensor sampleWithShape(long[] shape, KeanuRandom random) {
        return SmoothUniform.withParameters(xMin.getValue(), xMax.getValue(), edgeSharpness).sample(shape, random);
    }
}

