package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;

/* loaded from: input_file:io/improbable/keanu/distributions/continuous/Pareto.class */
/**
 * Pareto (Type I) distribution.
 *
 * <p>The log density implemented by {@link #logProb(DoubleTensor)} is
 * {@code log(scale) + scale * log(location) - (scale + 1) * log(x)} for {@code x > location}
 * and {@code -Infinity} otherwise. {@code location} therefore plays the role of the minimum
 * value x_m, and {@code scale} the role of the shape/tail index alpha.
 */
public class Pareto implements ContinuousDistribution {
    /** Minimum value x_m of the support. */
    private final DoubleTensor location;
    /** Shape (tail index) alpha. */
    private final DoubleTensor scale;

    /**
     * Creates a Pareto distribution.
     *
     * @param location the minimum value x_m; {@link #logProb} returns all {@code -Infinity} unless every element is > 0
     * @param scale    the shape alpha; {@link #logProb} returns all {@code -Infinity} unless every element is > 0
     * @return the distribution
     */
    public static ContinuousDistribution withParameters(DoubleTensor location, DoubleTensor scale) {
        return new Pareto(location, scale);
    }

    private Pareto(DoubleTensor location, DoubleTensor scale) {
        this.location = location;
        this.scale = scale;
    }

    /**
     * Partial derivatives of the log density with respect to x, location and scale.
     *
     * <p>No masking is applied for {@code x <= location}; the analytic expressions are returned
     * for every element.
     *
     * @param x the points at which to differentiate
     * @return diffs keyed by {@code Diffs.X}, {@code Diffs.L} and {@code Diffs.S}
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        // d/dx log p = -(scale + 1) / x
        DoubleTensor dLogPdx =
            (DoubleTensor) ((DoubleTensor) this.scale.plus2(1.0d).divInPlace(x)).unaryMinusInPlace();

        // d/dlocation log p = scale / location, broadcast to the shape of x
        DoubleTensor dLogPdLocation =
            (DoubleTensor) ((DoubleTensor) this.scale.div(this.location)).broadcast(x.getShape());

        // d/dscale log p = 1 / scale + log(location) - log(x)
        DoubleTensor dLogPdScale =
            (DoubleTensor) ((DoubleTensor) this.scale.reciprocal().plusInPlace(this.location.log2()))
                .minusInPlace(x.log2());

        return new Diffs()
            .put(Diffs.X, dLogPdx)
            .put(Diffs.L, dLogPdLocation)
            .put(Diffs.S, dLogPdScale);
    }

    /**
     * Draws samples by inverse-CDF transform: {@code location / (1 - u)^(1 / scale)}.
     *
     * <p>NOTE(review): assumes {@code random.nextDouble(shape)} yields uniform values in [0, 1) —
     * confirm against KeanuRandom.
     *
     * @param shape  shape of the sample tensor
     * @param random source of uniform randomness
     * @return a tensor of samples with the requested shape
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        DoubleTensor uniform = (DoubleTensor) random.nextDouble(shape);
        // 1 - u, built from a tensor of ones: boxing a Double and casting it to DoubleTensor
        // (as the decompiled code did) is not legal Java.
        DoubleTensor oneMinusUniform = (DoubleTensor) DoubleTensor.create(1.0d, shape).minusInPlace(uniform);
        DoubleTensor powered = (DoubleTensor) oneMinusUniform.powInPlace(this.scale.reciprocal());
        return (DoubleTensor) ((DoubleTensor) powered.reciprocalInPlace()).timesInPlace(this.location);
    }

    /**
     * Log density evaluated element-wise at {@code x}.
     *
     * @param x the points at which to evaluate
     * @return {@code log(scale) + scale*log(location) - (scale+1)*log(x)}, with {@code -Infinity}
     *     where {@code x <= location}; every element is {@code -Infinity} if any location or scale
     *     element is not strictly positive
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        if (!checkParamsAreValid()) {
            return DoubleTensor.create(Double.NEGATIVE_INFINITY, x.getShape());
        }
        DoubleTensor scaleTimesLogLocation = (DoubleTensor) this.location.log2().timesInPlace(this.scale);
        DoubleTensor scalePlusOneTimesLogX = (DoubleTensor) this.scale.plus2(1.0d).timesInPlace(x.log2());
        DoubleTensor logPdf =
            (DoubleTensor) ((DoubleTensor) this.scale.log2().plusInPlace(scaleTimesLogLocation))
                .minusInPlace(scalePlusOneTimesLogX);
        return setProbToZeroForInvalidX(x, logPdf);
    }

    /**
     * Builds a computation graph for the Pareto log density over placeholder vertices.
     *
     * <p>The result is set to {@code -Infinity} wherever {@code !(x > location && location > 0 && scale > 0)},
     * evaluated element-wise. NOTE(review): this differs from {@link #logProb}, which returns
     * all {@code -Infinity} if any parameter element is invalid.
     *
     * @param x        placeholder for the evaluation points
     * @param location placeholder for the minimum value x_m
     * @param scale    placeholder for the shape alpha
     * @return a vertex computing the log density
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex location, DoublePlaceholderVertex scale) {
        return scale.log2()
            .plus((DoubleVertex) location.log2().times((DoubleVertex) scale))
            .minus((DoubleVertex) scale.plus2(1.0d).times((DoubleVertex) x.log2()))
            .setWithMask(
                // 1 - (valid mask) => 1 where the point or parameters are invalid
                x.toGreaterThanMask(location)
                    .times(location.toGreaterThanMask(0.0d))
                    .times(scale.toGreaterThanMask(0.0d))
                    .unaryMinus()
                    .plus2(1.0d),
                Double.NEGATIVE_INFINITY);
    }

    /** Returns true when every element of location and of scale is strictly positive. */
    private boolean checkParamsAreValid() {
        // Compare against zero tensors of matching shape; casting a boxed Double to
        // DoubleTensor (as the decompiled code did) does not compile.
        return this.location.greaterThan(DoubleTensor.create(0.0d, this.location.getShape())).allTrue()
            && this.scale.greaterThan(DoubleTensor.create(0.0d, this.scale.getShape())).allTrue();
    }

    /**
     * Sets {@code logPdf} to {@code -Infinity} (probability zero) in place wherever {@code x <= location}.
     *
     * <p>NOTE(review): the boundary {@code x == location} is excluded, even though the sampler can
     * return exactly {@code location} when the uniform draw is 0. Confirm whether the support
     * should be closed.
     *
     * @param x      the evaluation points
     * @param logPdf the unmasked log density; mutated in place
     * @return {@code logPdf}
     */
    private DoubleTensor setProbToZeroForInvalidX(DoubleTensor x, DoubleTensor logPdf) {
        logPdf.setWithMaskInPlace((DoubleTensor) x.lessThanOrEqualToMask(this.location), Double.valueOf(Double.NEGATIVE_INFINITY));
        return logPdf;
    }
}
