/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.AdditionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DifferenceVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.MultiplicationVertex;

/**
 * A uniform distribution on {@code [xMin, xMax]} whose hard edges are replaced by smooth cubic
 * "shoulders", so the density and its first derivative are continuous.
 *
 * <p>Geometry (all operations are element-wise on the parameter tensors):
 * <ul>
 *   <li>body width {@code Bw = xMax - xMin};</li>
 *   <li>shoulder width {@code Sw = Bw * edgeSharpness};</li>
 *   <li>body height {@code h = 1 / (Sw + Bw)} on {@code [xMin, xMax]};</li>
 *   <li>a cubic shoulder falling from {@code h} to 0 across {@code (xMin - Sw, xMin)} and across
 *       {@code (xMax, xMax + Sw)};</li>
 *   <li>zero density elsewhere.</li>
 * </ul>
 * With the coefficients used below, each shoulder integrates to {@code Sw / (2 (Sw + Bw))}, so
 * the total mass is 1.
 *
 * <p>NOTE(review): with {@code edgeSharpness == 0} the shoulder coefficients divide by zero, and
 * the mask-by-multiplication arithmetic below then propagates NaN. Presumably callers always pass
 * a positive value; confirm.
 */
public class SmoothUniform implements ContinuousDistribution {
    /** Lower edge of the flat body. */
    private final DoubleTensor xMin;
    /** Upper edge of the flat body. */
    private final DoubleTensor xMax;
    /** Shoulder width expressed as a fraction of the body width {@code xMax - xMin}. */
    private final double edgeSharpness;

    /**
     * Creates a smooth-uniform distribution.
     *
     * @param xMin          lower edge of the flat body
     * @param xMax          upper edge of the flat body
     * @param edgeSharpness shoulder width as a fraction of {@code xMax - xMin}
     * @return the distribution
     */
    public static ContinuousDistribution withParameters(DoubleTensor xMin, DoubleTensor xMax, double edgeSharpness) {
        return new SmoothUniform(xMin, xMax, edgeSharpness);
    }

    private SmoothUniform(DoubleTensor xMin, DoubleTensor xMax, double edgeSharpness) {
        this.xMin = xMin;
        this.xMax = xMax;
        this.edgeSharpness = edgeSharpness;
    }

    /**
     * Draws samples by rejection with reflection.
     *
     * <p>The algorithm proceeds as follows:
     * <ol>
     *   <li>A proposal is drawn uniformly from {@code [xMin - Sw/2, xMax + Sw/2)}. Its density
     *       equals the body height.</li>
     *   <li>Proposals inside {@code [xMin, xMax]} are kept.</li>
     *   <li>A proposal on the inner half of a shoulder is kept with probability
     *       {@code shoulderDensity / bodyHeight}.</li>
     *   <li>Otherwise it is reflected onto the outer half of the same shoulder. The reflected
     *       point lies as far from the outer cut-off as the proposal was from the body edge.</li>
     * </ol>
     * The normalised shoulder {@code s(t) = 3t^2 - 2t^3} satisfies {@code s(t) + s(1 - t) = 1}.
     * The rejected mass therefore reproduces the outer-half density exactly.
     *
     * @param shape  shape of the returned sample tensor
     * @param random source of uniform random numbers (two draws of {@code shape} are consumed)
     * @return samples of the given shape
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        final DoubleTensor r1 = random.nextDouble(shape);
        final DoubleTensor r2 = random.nextDouble(shape);

        final DoubleTensor bodyWidth = this.xMax.minus(this.xMin);
        final DoubleTensor shoulderWidth = bodyWidth.times(this.edgeSharpness);
        final DoubleTensor bodyHeight = bodyHeight(shoulderWidth, bodyWidth);

        // Uniform proposal on [xMin - Sw/2, xMax + Sw/2); r1 is overwritten in place.
        final DoubleTensor proposal = r1.timesInPlace(bodyWidth.plus(shoulderWidth))
            .plusInPlace(this.xMin.minus(shoulderWidth.div(2.0)));

        // 1 where the proposal lies in [xMin, xMax], 0 elsewhere; plus its complement.
        final DoubleTensor inBody = proposal.greaterThanOrEqualToMask(this.xMin)
            .timesInPlace(proposal.lessThanOrEqualToMask(this.xMax));
        final DoubleTensor outsideBody = DoubleTensor.ones(inBody.getShape()).minusInPlace(inBody);

        // Left shoulder: the spill is how far the proposal overshoots xMin.
        final DoubleTensor leftMask = proposal.lessThanMask(this.xMin);
        final DoubleTensor leftSpill = this.xMin.minus(proposal);
        final DoubleTensor leftAccept = shoulder(shoulderWidth, bodyWidth, shoulderWidth.minus(leftSpill))
            .div(bodyHeight);
        // leftKeep must be computed before leftMask is mutated by the in-place product below.
        final DoubleTensor leftKeep = leftMask.times(r2.lessThanOrEqualToMask(leftAccept));
        final DoubleTensor leftReflect = leftMask.timesInPlace(r2.greaterThanMask(leftAccept));
        final DoubleTensor leftReflected = this.xMin.minus(shoulderWidth).plusInPlace(leftSpill);

        // Right shoulder. This mask is also 1 inside the body; outsideBody zeroes those entries.
        final DoubleTensor rightMask = proposal.greaterThanOrEqualToMask(this.xMin);
        final DoubleTensor rightSpill = proposal.minus(this.xMax);
        final DoubleTensor rightAccept = shoulder(shoulderWidth, bodyWidth, shoulderWidth.minus(rightSpill))
            .divInPlace(bodyHeight);
        final DoubleTensor rightKeep = rightMask.times(r2.lessThanOrEqualToMask(rightAccept));
        final DoubleTensor rightReflect = rightMask.timesInPlace(r2.greaterThanMask(rightAccept));
        final DoubleTensor rightReflected = shoulderWidth.plus(this.xMax).minusInPlace(rightSpill);

        // Exactly one term is non-zero per element; outsideBody is consumed in place by the last term.
        return inBody.timesInPlace(proposal)
            .plusInPlace(outsideBody.times(leftKeep).timesInPlace(proposal))
            .plusInPlace(outsideBody.times(leftReflect).timesInPlace(leftReflected))
            .plusInPlace(outsideBody.times(rightKeep).timesInPlace(proposal))
            .plusInPlace(outsideBody.timesInPlace(rightReflect).timesInPlace(rightReflected));
    }

    /**
     * Element-wise log density.
     *
     * @param x points at which to evaluate
     * @return {@code log p(x)}; {@code -Infinity} outside {@code (xMin - Sw, xMax + Sw)}
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        final DoubleTensor bodyWidth = this.xMax.minus(this.xMin);
        final DoubleTensor shoulderWidth = bodyWidth.times(this.edgeSharpness);
        final DoubleTensor rightCutoff = this.xMax.plus(shoulderWidth);
        final DoubleTensor leftCutoff = this.xMin.minus(shoulderWidth);

        final DoubleTensor inBody = x.greaterThanOrEqualToMask(this.xMin)
            .timesInPlace(x.lessThanOrEqualToMask(this.xMax));
        final DoubleTensor bodyDensity = bodyHeight(shoulderWidth, bodyWidth);

        // Left shoulder: distance measured inward from leftCutoff.
        final DoubleTensor onLeftShoulder = x.lessThanMask(this.xMin)
            .timesInPlace(x.greaterThanMask(leftCutoff));
        final DoubleTensor leftDensity = shoulder(shoulderWidth, bodyWidth, x.minus(leftCutoff));

        // Right shoulder: distance measured inward from rightCutoff, i.e. Sw - x + xMax.
        final DoubleTensor onRightShoulder = x.greaterThanMask(this.xMax)
            .timesInPlace(x.lessThanMask(rightCutoff));
        final DoubleTensor rightDensity = shoulder(shoulderWidth, bodyWidth, shoulderWidth.minus(x).plusInPlace(this.xMax));

        final DoubleTensor density = inBody.timesInPlace(bodyDensity)
            .plusInPlace(onLeftShoulder.timesInPlace(leftDensity))
            .plusInPlace(onRightShoulder.timesInPlace(rightDensity));
        return (DoubleTensor) density.logInPlace();
    }

    /**
     * Builds a vertex graph that computes the same log density as {@link #logProb(DoubleTensor)},
     * with the inputs supplied through placeholder vertices.
     *
     * @param x             placeholder for the evaluation points
     * @param xMin          placeholder for the lower body edge
     * @param xMax          placeholder for the upper body edge
     * @param edgeSharpness shoulder width as a fraction of {@code xMax - xMin}
     * @return a vertex whose value is {@code log p(x)}
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex xMin, DoublePlaceholderVertex xMax, double edgeSharpness) {
        final DifferenceVertex bodyWidth = xMax.minus(xMin);
        final MultiplicationVertex shoulderWidth = bodyWidth.times(edgeSharpness);
        final AdditionVertex rightCutoff = xMax.plus(shoulderWidth);
        final DifferenceVertex leftCutoff = xMin.minus(shoulderWidth);

        final MultiplicationVertex inBody = x.toGreaterThanOrEqualToMask(xMin).times(x.toLessThanOrEqualToMask(xMax));
        final DoubleVertex bodyDensity = bodyHeightVertex(shoulderWidth, bodyWidth);

        final MultiplicationVertex onLeftShoulder = x.toLessThanMask(xMin).times(x.toGreaterThanMask(leftCutoff));
        final DoubleVertex leftDensity = shoulderVertex(shoulderWidth, bodyWidth, x.minus(leftCutoff));

        final MultiplicationVertex onRightShoulder = x.toGreaterThanMask(xMax).times(x.toLessThanMask(rightCutoff));
        final DoubleVertex rightDensity = shoulderVertex(shoulderWidth, bodyWidth, shoulderWidth.minus(x).plus(xMax));

        return inBody.times(bodyDensity)
            .plus(onLeftShoulder.times(leftDensity))
            .plus(onRightShoulder.times(rightDensity))
            .log();
    }

    /**
     * Element-wise slope of the density with respect to {@code x}, stored under {@link Diffs#X}.
     *
     * <p>The slope is 0 on the body and outside the support. It is {@code dShoulder(u)} on the
     * left shoulder and {@code -dShoulder(u)} on the right shoulder, where {@code u} is the
     * distance inward from the respective cut-off. This matches the arguments used by
     * {@link #logProb(DoubleTensor)}.
     *
     * <p>NOTE(review): the value is the derivative of the density itself, not of its logarithm.
     * It is not divided by the density here. Confirm whether callers perform that division.
     *
     * @param x points at which to evaluate
     * @return diffs holding the slope under {@link Diffs#X}
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        final DoubleTensor bodyWidth = this.xMax.minus(this.xMin);
        final DoubleTensor shoulderWidth = bodyWidth.times(this.edgeSharpness);
        final DoubleTensor leftCutoff = this.xMin.minus(shoulderWidth);
        final DoubleTensor rightCutoff = this.xMax.plus(shoulderWidth);

        final DoubleTensor onLeftShoulder = x.lessThanMask(this.xMin)
            .timesInPlace(x.greaterThanMask(leftCutoff));
        final DoubleTensor leftSlope = dShoulder(shoulderWidth, bodyWidth, x.minus(leftCutoff));

        final DoubleTensor onRightShoulder = x.greaterThanMask(this.xMax)
            .timesInPlace(x.lessThanMask(rightCutoff));
        // The distance inward from rightCutoff is Sw - x + xMax, the same argument logProb uses.
        // Measuring from rightCutoff instead would overshoot by Sw and give the wrong slope/sign.
        // The minus sign accounts for d(distance)/dx = -1.
        final DoubleTensor rightSlope = (DoubleTensor) dShoulder(shoulderWidth, bodyWidth, shoulderWidth.minus(x).plusInPlace(this.xMax))
            .unaryMinusInPlace();

        return new Diffs().put(Diffs.X, onLeftShoulder.timesInPlace(leftSlope).plusInPlace(onRightShoulder.timesInPlace(rightSlope)));
    }

    /**
     * Shoulder density {@code A*u^3 + B*u^2}, where {@code u} is the distance inward from the
     * outer cut-off.
     *
     * <p>With the coefficients below, the value is 0 at {@code u = 0}. At {@code u = Sw} it equals
     * the body height {@code 1/(Sw + Bw)} with zero slope. Callers mask out {@code u} outside
     * {@code (0, Sw)}.
     */
    private static DoubleTensor shoulder(DoubleTensor Sw, DoubleTensor Bw, DoubleTensor x) {
        final DoubleTensor A = getCubeCoefficient(Sw, Bw);
        final DoubleTensor B = getSquareCoefficient(Sw, Bw);
        return x.pow(3.0).timesInPlace(A).plusInPlace(x.pow(2.0).timesInPlace(B));
    }

    /** Vertex-graph counterpart of {@link #shoulder(DoubleTensor, DoubleTensor, DoubleTensor)}. */
    private static DoubleVertex shoulderVertex(DoubleVertex Sw, DoubleVertex Bw, DoubleVertex x) {
        final DoubleVertex A = getCubeCoefficientVertex(Sw, Bw);
        final DoubleVertex B = getSquareCoefficientVertex(Sw, Bw);
        return x.pow(3.0).times(A).plus(x.pow(2.0).times(B));
    }

    /** Derivative {@code 3A*u^2 + 2B*u} of the shoulder density with respect to {@code u}. */
    private static DoubleTensor dShoulder(DoubleTensor Sw, DoubleTensor Bw, DoubleTensor x) {
        // A and B are freshly allocated, so mutating them in place is safe.
        final DoubleTensor A = getCubeCoefficient(Sw, Bw);
        final DoubleTensor B = getSquareCoefficient(Sw, Bw);
        final DoubleTensor cubicTerm = ((DoubleTensor) A.timesInPlace(3.0)).timesInPlace(x.pow(2.0));
        final DoubleTensor squareTerm = (DoubleTensor) B.timesInPlace(x).timesInPlace(2.0);
        return cubicTerm.plusInPlace(squareTerm);
    }

    /** Cubic coefficient {@code A = -2 / (Sw^3 * (Sw + Bw))}. */
    private static DoubleTensor getCubeCoefficient(DoubleTensor Sw, DoubleTensor Bw) {
        final DoubleTensor inverse = (DoubleTensor) Sw.pow(3.0).timesInPlace(Sw.plus(Bw)).reciprocalInPlace();
        return (DoubleTensor) inverse.timesInPlace(-2.0);
    }

    /** Vertex-graph counterpart of {@link #getCubeCoefficient(DoubleTensor, DoubleTensor)}. */
    private static DoubleVertex getCubeCoefficientVertex(DoubleVertex Sw, DoubleVertex Bw) {
        return Sw.pow(3.0).times(Sw.plus(Bw)).reverseDiv(-2.0);
    }

    /** Square coefficient {@code B = 3 / (Sw^2 * (Sw + Bw))}. */
    private static DoubleTensor getSquareCoefficient(DoubleTensor Sw, DoubleTensor Bw) {
        final DoubleTensor inverse = (DoubleTensor) Sw.pow(2.0).timesInPlace(Sw.plus(Bw)).reciprocalInPlace();
        return (DoubleTensor) inverse.timesInPlace(3.0);
    }

    /** Vertex-graph counterpart of {@link #getSquareCoefficient(DoubleTensor, DoubleTensor)}. */
    private static DoubleVertex getSquareCoefficientVertex(DoubleVertex Sw, DoubleVertex Bw) {
        return Sw.pow(2.0).times(Sw.plus(Bw)).reverseDiv(3.0);
    }

    /** Density on the flat body: {@code 1 / (shoulderWidth + bodyWidth)}. */
    private static DoubleTensor bodyHeight(DoubleTensor shoulderWidth, DoubleTensor bodyWidth) {
        return (DoubleTensor) shoulderWidth.plus(bodyWidth).reciprocalInPlace();
    }

    /** Vertex-graph counterpart of {@link #bodyHeight(DoubleTensor, DoubleTensor)}. */
    private static DoubleVertex bodyHeightVertex(DoubleVertex shoulderWidth, DoubleVertex bodyWidth) {
        return shoulderWidth.plus(bodyWidth).reverseDiv(1.0);
    }
}

