/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.AdditionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DifferenceVertex;

/**
 * Pareto distribution over {@code x}, parameterised by a {@code location} tensor (the lower bound
 * of the support) and a {@code scale} tensor (the power-law exponent). The log density is
 *
 * <pre>
 *   log p(x) = log(scale) + scale * log(location) - (scale + 1) * log(x)
 * </pre>
 *
 * <p>{@link #logProb} reports negative infinity for {@code x <= location}, and everywhere when any
 * parameter is non-positive.
 */
public final class Pareto implements ContinuousDistribution {

    /** Lower bound of the support; values of {@code x <= location} get log-probability -inf. */
    private final DoubleTensor location;
    /** Exponent of the power law. */
    private final DoubleTensor scale;

    /**
     * Creates a Pareto distribution.
     *
     * @param location lower bound of the support; must be strictly positive everywhere for
     *     {@link #logProb} to return anything other than -inf
     * @param scale power-law exponent; must be strictly positive everywhere for {@link #logProb}
     *     to return anything other than -inf
     * @return a distribution backed by the given tensors (they are stored, not copied)
     */
    public static ContinuousDistribution withParameters(DoubleTensor location, DoubleTensor scale) {
        return new Pareto(location, scale);
    }

    private Pareto(DoubleTensor location, DoubleTensor scale) {
        this.location = location;
        this.scale = scale;
    }

    /**
     * Partial derivatives of the unmasked log density with respect to each input:
     *
     * <ul>
     *   <li>{@code X}: -(scale + 1) / x
     *   <li>{@code L}: scale / location, broadcast to the shape of {@code x}
     *   <li>{@code S}: 1 / scale + log(location) - log(x)
     * </ul>
     *
     * <p>NOTE(review): unlike {@link #logProb}, these are not masked for {@code x <= location} or
     * for non-positive parameters. Confirm that callers only request gradients inside the support.
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        DoubleTensor dLogPdx = scale.plus(1.0).divInPlace(x).unaryMinusInPlace();
        DoubleTensor dLogPdLocation = scale.div(location).broadcast(x.getShape());
        DoubleTensor dLogPdScale = scale.reciprocal()
            .plusInPlace(location.log())
            .minusInPlace(x.log());

        return new Diffs()
            .put(Diffs.X, dLogPdx)
            .put(Diffs.L, dLogPdLocation)
            .put(Diffs.S, dLogPdScale);
    }

    /**
     * Draws samples by inverting the CDF: {@code location / (1 - u)^(1 / scale)}, where {@code u}
     * comes from {@code random.nextDouble(shape)}.
     *
     * <p>NOTE(review): this presumably relies on {@code nextDouble} being uniform on [0, 1) —
     * TODO confirm. If so, {@code u == 0} yields exactly {@code location}, and {@link #logProb}
     * treats that value as outside the support. Decide whether the boundary should be inclusive.
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        // 1 - u, computed in place on the freshly drawn tensor.
        DoubleTensor oneMinusU = random.nextDouble(shape).unaryMinusInPlace().plusInPlace(1.0);
        return oneMinusU
            .powInPlace(scale.reciprocal())
            .reciprocalInPlace()
            .timesInPlace(location);
    }

    /**
     * Element-wise log density at {@code x}.
     *
     * @return -inf everywhere if any {@code location} or {@code scale} entry is non-positive;
     *     otherwise the log density, with -inf wherever {@code x <= location}
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        if (!checkParamsAreValid()) {
            return DoubleTensor.create(Double.NEGATIVE_INFINITY, x.getShape());
        }
        DoubleTensor result = scale.log()
            .plusInPlace(location.log().timesInPlace(scale))
            .minusInPlace(scale.plus(1.0).timesInPlace(x.log()));
        return setProbToZeroForInvalidX(x, result);
    }

    /**
     * Builds the same log density as {@link #logProb} as a vertex graph over placeholder inputs.
     *
     * <p>Unlike {@link #logProb}, the parameter check here is element-wise. An entry is set to -inf
     * when {@code x <= location}, {@code location <= 0} or {@code scale <= 0} at that position.
     * This assumes the {@code toGreaterThanMask} vertices evaluate to 1.0 where the comparison
     * holds and to 0.0 otherwise.
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex location, DoublePlaceholderVertex scale) {
        // Product of the validity masks is 1 only where every condition holds; 1 - product marks
        // the entries that must become -inf.
        DoubleVertex invalidXMask = x.toGreaterThanMask(location)
            .times(location.toGreaterThanMask(0.0))
            .times(scale.toGreaterThanMask(0.0))
            .unaryMinus()
            .plus(1.0);
        DifferenceVertex ifValid = scale.log()
            .plus(location.log().times(scale))
            .minus(scale.plus(1.0).times(x.log()));
        return ifValid.setWithMask(invalidXMask, Double.NEGATIVE_INFINITY);
    }

    /** True only when every entry of both {@code location} and {@code scale} is strictly positive. */
    private boolean checkParamsAreValid() {
        return location.greaterThan(0.0).allTrue() && scale.greaterThan(0.0).allTrue();
    }

    /**
     * Overwrites {@code result} with -inf wherever {@code x <= location}. The boundary is
     * exclusive, matching {@code toGreaterThanMask} in {@link #logProbOutput}.
     *
     * @return the masked tensor. This is the value returned by {@code setWithMaskInPlace}, used
     *     the same way as every other in-place call in this class, rather than assuming the
     *     operation always hands back {@code result} itself.
     */
    private DoubleTensor setProbToZeroForInvalidX(DoubleTensor x, DoubleTensor result) {
        DoubleTensor invalids = x.lessThanOrEqualToMask(location);
        return result.setWithMaskInPlace(invalids, Double.NEGATIVE_INFINITY);
    }
}

