/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import com.google.common.base.Preconditions;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.AdditionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.LogVertex;

/**
 * Cauchy distribution parameterised by a {@code location} tensor and a {@code scale} tensor.
 *
 * <p>The log-density evaluated by this class is
 * {@code -ln(scale) - ln(pi) - ln(1 + ((x - location) / scale)^2)}.
 *
 * <p>Instances are created through {@link #withParameters(DoubleTensor, DoubleTensor)}. The
 * parameter tensors are stored by reference (no defensive copy is taken).
 */
public class Cauchy implements ContinuousDistribution {

    /** The constant {@code -ln(pi)} term of the log-density. */
    private static final double NEG_LOG_PI = -Math.log(Math.PI);

    private final DoubleTensor location;
    private final DoubleTensor scale;

    /**
     * Creates a Cauchy distribution with the given parameters.
     *
     * @param location location parameter of the distribution
     * @param scale    scale parameter of the distribution; it is not validated here, only
     *                 {@link #sample(long[], KeanuRandom)} checks that it is strictly positive
     * @return a new distribution backed by the supplied tensors
     */
    public static ContinuousDistribution withParameters(DoubleTensor location, DoubleTensor scale) {
        return new Cauchy(location, scale);
    }

    private Cauchy(DoubleTensor location, DoubleTensor scale) {
        this.location = location;
        this.scale = scale;
    }

    /**
     * Draws samples by transforming the values returned by {@code random.nextDouble(shape)}:
     * {@code location + scale * tan(pi * (u - 0.5))}.
     *
     * @param shape  shape passed to {@code random.nextDouble}
     * @param random source of the base random values {@code u} (presumably uniform on [0, 1) -
     *               confirm against {@code KeanuRandom})
     * @return the transformed samples
     * @throws IllegalArgumentException if any element of {@code scale} is not greater than 0
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        Preconditions.checkArgument(
            this.scale.greaterThan(0.0).allTrue(),
            "scale must be greater than 0. scale: " + this.scale
        );

        // The freshly drawn tensor is owned by this method, so it is transformed in place.
        DoubleTensor unityCauchy = random.nextDouble(shape);
        return unityCauchy
            .minusInPlace(0.5)
            .timesInPlace(Math.PI)
            .tanInPlace()
            .timesInPlace(this.scale)
            .plusInPlace(this.location);
    }

    /**
     * Evaluates the log-density {@code -ln(scale) - ln(pi) - ln(1 + ((x - location) / scale)^2)}.
     *
     * @param x points at which to evaluate the log-density
     * @return the log-density at {@code x}
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        // In-place operations are only applied to temporaries produced by log() / minus(),
        // never directly to x or to the parameter tensors.
        DoubleTensor negLnScaleMinusLnPi = this.scale.log().unaryMinusInPlace().plusInPlace(NEG_LOG_PI);
        DoubleTensor xMinusLocationOverScalePow2Plus1 = x.minus(this.location)
            .divInPlace(this.scale)
            .powInPlace(2.0)
            .plusInPlace(1.0);
        DoubleTensor lnXMinusLocationOverScalePow2Plus1 = xMinusLocationOverScalePow2Plus1.logInPlace();

        return negLnScaleMinusLnPi.minusInPlace(lnXMinusLocationOverScalePow2Plus1);
    }

    /**
     * Builds the vertex expression for the same log-density as {@link #logProb(DoubleTensor)}:
     * {@code -ln(scale) - ln(pi) - ln(1 + ((x - location) / scale)^2)}.
     *
     * @param x        placeholder for the evaluation point
     * @param location placeholder for the location parameter
     * @param scale    placeholder for the scale parameter
     * @return a vertex representing the log-density expression
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex location, DoublePlaceholderVertex scale) {
        AdditionVertex negLnScaleMinusLnPi = scale.log().unaryMinus().plus(NEG_LOG_PI);
        AdditionVertex xMinusLocationOverScalePow2Plus1 = x.minus(location).div(scale).pow(2.0).plus(1.0);
        LogVertex lnXMinusLocationOverScalePow2Plus1 = xMinusLocationOverScalePow2Plus1.log();

        return negLnScaleMinusLnPi.minus(lnXMinusLocationOverScalePow2Plus1);
    }

    /**
     * Computes the partial derivatives of the log-density with respect to the location, the
     * scale and {@code x}. With {@code d = x - location}:
     *
     * <ul>
     *   <li>{@code dlogP/dlocation = 2d / (scale^2 + d^2)}</li>
     *   <li>{@code dlogP/dscale = (d^2 - scale^2) / (scale * (scale^2 + d^2))}</li>
     *   <li>{@code dlogP/dx = -2d / (scale^2 + d^2)}</li>
     * </ul>
     *
     * @param x points at which to evaluate the derivatives
     * @return the derivatives stored under {@code Diffs.L}, {@code Diffs.S} and {@code Diffs.X}
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        DoubleTensor xMinusLocation = x.minus(this.location);
        DoubleTensor xMinusLocationPow2 = xMinusLocation.pow(2.0);
        DoubleTensor scalePow2 = this.scale.pow(2.0);

        // Shared denominator scale^2 + (x - location)^2. It is built from the difference
        // (x - location) rather than the expanded quadratic location^2 - 2*location*x + x^2,
        // which loses precision through cancellation when x and location are large and close.
        // It is only used as an operand below (never as the receiver of an in-place call),
        // so it can safely be reused for all three derivatives.
        DoubleTensor scalePow2PlusXMinusLocationPow2 = scalePow2.plus(xMinusLocationPow2);

        DoubleTensor dLogPdlocation = xMinusLocation.times(2.0)
            .divInPlace(scalePow2PlusXMinusLocationPow2);
        DoubleTensor dLogPdscale = xMinusLocationPow2.minus(scalePow2)
            .divInPlace(this.scale.times(scalePow2PlusXMinusLocationPow2));
        DoubleTensor dLogPdx = xMinusLocation.times(-2.0)
            .divInPlace(scalePow2PlusXMinusLocationPow2);

        return new Diffs()
            .put(Diffs.L, dLogPdlocation)
            .put(Diffs.S, dLogPdscale)
            .put(Diffs.X, dLogPdx);
    }
}

