/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import com.google.common.base.Preconditions;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DifferenceVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DivisionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.MultiplicationVertex;

/**
 * Triangular distribution with lower bound {@code xMin}, upper bound {@code xMax} and mode {@code c}.
 *
 * <p>The density, evaluated element-wise, is
 * <pre>
 *   2 (x - xMin) / ((xMax - xMin) (c - xMin))   for xMin &lt; x &lt;= c
 *   2 (xMax - x) / ((xMax - xMin) (xMax - c))   for c &lt; x &lt;= xMax
 *   0                                           otherwise
 * </pre>
 * so the mode itself has density {@code 2 / (xMax - xMin)}.
 *
 * <p>Parameters are only validated in {@link #sample}, which requires {@code xMin <= c <= xMax}.
 * NOTE(review): the parameter tensors are assumed to broadcast against the sampled shape / the
 * {@code x} passed to {@link #logProb} — confirm against the tensor library's broadcasting rules.
 */
public class Triangular implements ContinuousDistribution {
    private final DoubleTensor xMin;
    private final DoubleTensor xMax;
    private final DoubleTensor c;

    /**
     * Creates a triangular distribution.
     *
     * @param xMin lower bound of the support
     * @param xMax upper bound of the support
     * @param c    mode; expected to satisfy {@code xMin <= c <= xMax}
     * @return the distribution
     */
    public static ContinuousDistribution withParameters(DoubleTensor xMin, DoubleTensor xMax, DoubleTensor c) {
        return new Triangular(xMin, xMax, c);
    }

    private Triangular(DoubleTensor xMin, DoubleTensor xMax, DoubleTensor c) {
        this.xMin = xMin;
        this.xMax = xMax;
        this.c = c;
    }

    /**
     * Draws samples by inverse-transform sampling.
     *
     * <p>With {@code p ~ U(0, 1)} and {@code F(c) = (c - xMin) / (xMax - xMin)} the sample is
     * {@code xMin + sqrt((xMax - xMin)(c - xMin) p)} when {@code p <= F(c)}, and
     * {@code xMax - sqrt((xMax - xMin)(xMax - c)(1 - p))} otherwise.
     *
     * @param shape  shape of the tensor of uniforms drawn from {@code random}
     * @param random source of uniform randomness
     * @return the samples
     * @throws IllegalArgumentException if any element of {@code c} lies outside {@code [xMin, xMax]}
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        // Template form defers the (potentially large) tensor toString() calls to the failure path.
        Preconditions.checkArgument(
            c.greaterThanOrEqual(xMin).allTrue() && c.lessThanOrEqual(xMax).allTrue(),
            "center must be between xMin and xMax. c: %s xMin: %s xMax: %s", c, xMin, xMax);

        DoubleTensor p = random.nextDouble(shape);
        DoubleTensor q = p.unaryMinus().plusInPlace(1.0);
        DoubleTensor range = xMax.minus(xMin);

        // CDF value at the mode: uniforms at or below it map onto the rising edge.
        DoubleTensor modeQuantile = c.minus(xMin).divInPlace(range);

        DoubleTensor risingSample = xMin.plus(range.times(c.minus(xMin).timesInPlace(p)).sqrtInPlace());
        DoubleTensor fallingSample = xMax.minus(range.times(xMax.minus(c).timesInPlace(q)).sqrtInPlace());

        DoubleTensor risingMask = p.lessThanOrEqualToMask(modeQuantile);
        DoubleTensor fallingMask = p.greaterThanMask(modeQuantile);
        return risingSample.timesInPlace(risingMask).plusInPlace(fallingSample.timesInPlace(fallingMask));
    }

    /**
     * Element-wise log density of {@code x}.
     *
     * <p>Points outside {@code (xMin, xMax]} get {@code -Infinity}. The mode {@code x == c} is
     * treated as part of the rising edge, so it gets {@code log(2 / (xMax - xMin))} rather than
     * {@code -Infinity}.
     *
     * @param x values at which to evaluate the log density
     * @return the log density, one entry per element of the broadcast result
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        DoubleTensor twoOverRange = xMax.minus(xMin).reciprocalInPlace().timesInPlace(2.0);

        // Each "outside" mask is exactly 0 where its branch applies and >= 1 elsewhere. Adding it to
        // the branch denominator leaves in-branch values unchanged (adding 0.0 is exact) while keeping
        // masked-out entries finite, so a degenerate triangle (c == xMin or c == xMax) does not yield
        // 0 * Infinity = NaN. Assumes xMin <= c <= xMax so the shifted denominators are non-zero.
        DoubleTensor risingMask = x.greaterThanMask(xMin).timesInPlace(x.lessThanOrEqualToMask(c));
        DoubleTensor outsideRising = x.lessThanOrEqualToMask(xMin).plusInPlace(x.greaterThanMask(c));
        DoubleTensor rising = twoOverRange.times(x.minus(xMin)).divInPlace(c.minus(xMin).plus(outsideRising));

        DoubleTensor fallingMask = x.greaterThanMask(c).timesInPlace(x.lessThanOrEqualToMask(xMax));
        DoubleTensor outsideFalling = x.lessThanOrEqualToMask(c).plusInPlace(x.greaterThanMask(xMax));
        DoubleTensor falling = twoOverRange.times(xMax.minus(x)).divInPlace(xMax.minus(c).plus(outsideFalling));

        return rising.timesInPlace(risingMask).plusInPlace(falling.timesInPlace(fallingMask)).logInPlace();
    }

    /**
     * Builds a vertex graph computing the same log density as {@link #logProb} from placeholders.
     *
     * @param x    placeholder for the evaluation point
     * @param xMin placeholder for the lower bound
     * @param xMax placeholder for the upper bound
     * @param c    placeholder for the mode
     * @return vertex whose value is the element-wise log density
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex xMin, DoublePlaceholderVertex xMax, DoublePlaceholderVertex c) {
        DifferenceVertex range = xMax.minus(xMin);
        DoubleVertex twoOverRange = range.reverseDiv(1.0).times(2.0);

        // Same branch masks and denominator shift as logProb: the mode belongs to the rising edge,
        // and masked-out entries get a non-zero denominator to avoid 0 * Infinity = NaN.
        MultiplicationVertex risingMask = x.toGreaterThanMask(xMin).times(x.toLessThanOrEqualToMask(c));
        DoubleVertex outsideRising = x.toLessThanOrEqualToMask(xMin).plus(x.toGreaterThanMask(c));
        DivisionVertex rising = twoOverRange.times(x.minus(xMin)).div(c.minus(xMin).plus(outsideRising));

        MultiplicationVertex fallingMask = x.toGreaterThanMask(c).times(x.toLessThanOrEqualToMask(xMax));
        DoubleVertex outsideFalling = x.toLessThanOrEqualToMask(c).plus(x.toGreaterThanMask(xMax));
        DivisionVertex falling = twoOverRange.times(xMax.minus(x)).div(xMax.minus(c).plus(outsideFalling));

        return rising.times(risingMask).plus(falling.times(fallingMask)).log();
    }

    /**
     * Gradients of the log density are not implemented for this distribution.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        throw new UnsupportedOperationException();
    }
}

