package io.improbable.keanu.distributions.continuous;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.ContinuousDistribution;
import io.improbable.keanu.distributions.hyperparam.Diffs;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoublePlaceholderVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DivisionVertex;

/* loaded from: input_file:io/improbable/keanu/distributions/continuous/Logistic.class */
/**
 * Logistic distribution with location {@code mu} and scale {@code s}.
 *
 * <p>With {@code z = (x - mu) / s}, the density is {@code p(x) = e^z / (s * (1 + e^z)^2)},
 * so {@code log p(x) = z - log(s) - 2 * log(1 + e^z)}.
 */
public class Logistic implements ContinuousDistribution {

    /** Location parameter. */
    private final DoubleTensor mu;
    /** Scale parameter; every formula below divides by it. */
    private final DoubleTensor s;

    /**
     * Creates a logistic distribution.
     *
     * @param mu location parameter
     * @param s  scale parameter
     * @return the distribution
     */
    public static ContinuousDistribution withParameters(DoubleTensor mu, DoubleTensor s) {
        return new Logistic(mu, s);
    }

    private Logistic(DoubleTensor mu, DoubleTensor s) {
        this.mu = mu;
        this.s = s;
    }

    /**
     * Draws samples by inverting the CDF: {@code x = mu - s * log(1/u - 1)}, which equals
     * {@code mu + s * log(u / (1 - u))}.
     *
     * <p>NOTE(review): assumes {@code random.nextDouble(shape)} returns Uniform(0, 1) draws — confirm.
     *
     * @param shape  shape of the uniform tensor to draw
     * @param random source of randomness
     * @return the samples
     */
    @Override
    public DoubleTensor sample(long[] shape, KeanuRandom random) {
        // log(1/u - 1) = log((1 - u) / u) = -logit(u).
        DoubleTensor negativeLogit = random.nextDouble(shape)
            .reciprocalInPlace()
            .minusInPlace(1.0)
            .logInPlace();
        // mu - s * (-logit(u)) = mu + s * logit(u): scale by s, then shift by mu.
        return mu.minus(negativeLogit.timesInPlace(s));
    }

    /**
     * Computes the log density {@code z + log(1/s) - 2 * log(1 + e^z)} with {@code z = (x - mu) / s}.
     *
     * @param x points at which to evaluate the density
     * @return the log density
     */
    @Override
    public DoubleTensor logProb(DoubleTensor x) {
        DoubleTensor z = x.minus(mu).divInPlace(s);
        DoubleTensor logDensity = z.plus(s.reciprocal().logInPlace());
        // z is not read again after this point, so it is exponentiated in place.
        // NOTE(review): e^z overflows to +inf once z exceeds ~709, which makes the result -inf there.
        DoubleTensor twoLogOnePlusExpZ = z.expInPlace().plusInPlace(1.0).logInPlace().timesInPlace(2.0);
        return logDensity.minusInPlace(twoLogOnePlusExpZ);
    }

    /**
     * Builds a vertex expression for the log density from placeholder vertices.
     *
     * <p>The expression mirrors {@link #logProb(DoubleTensor)}: {@code z + log(1/s) - 2 * log(1 + e^z)}
     * with {@code z = (x - mu) / s}. NOTE(review): the {@code *2} method names ({@code reverseDiv2},
     * {@code log2}, {@code exp2}, {@code plus2}, {@code times2}) look like decompiler renames of
     * reverseDiv/log/exp/plus/times — confirm against the vertex API.
     *
     * @param x  placeholder for the point at which the density is evaluated
     * @param mu placeholder for the location parameter
     * @param s  placeholder for the scale parameter
     * @return the log-density vertex
     */
    public static DoubleVertex logProbOutput(DoublePlaceholderVertex x,
                                             DoublePlaceholderVertex mu,
                                             DoublePlaceholderVertex s) {
        DivisionVertex z = x.minus((DoubleVertex) mu).div((DoubleVertex) s);
        // Kept as one expression so the vertices are created in the same order as before.
        return z.plus((DoubleVertex) s.reverseDiv2(1.0d).log2())
            .minus((DoubleVertex) z.exp2().plus2(1.0d).log2().times2(2.0d));
    }

    /**
     * Computes the partial derivatives of {@link #logProb(DoubleTensor)}.
     *
     * <p>With {@code z = (x - mu) / s}:
     * <ul>
     *   <li>{@code dlogp/dmu = tanh(z / 2) / s}</li>
     *   <li>{@code dlogp/dx  = -dlogp/dmu}</li>
     *   <li>{@code dlogp/ds  = z * dlogp/dmu - 1 / s}</li>
     * </ul>
     *
     * @param x points at which to evaluate the derivatives
     * @return derivatives keyed by {@link Diffs#MU}, {@link Diffs#S} and {@link Diffs#X}
     */
    @Override
    public Diffs dLogProb(DoubleTensor x) {
        DoubleTensor z = x.minus(mu).divInPlace(s);
        // tanh(z / 2) is written as 1 - 2 / (1 + e^z). It depends only on z and saturates to
        // +/-1 when e^z overflows or underflows, rather than forming inf / inf = NaN as happens
        // when mu / s and x / s are exponentiated separately and either exceeds ~709.
        DoubleTensor dLogPdmu = x.minus(mu).divInPlace(s)
            .expInPlace()
            .plusInPlace(1.0)
            .reciprocalInPlace()
            .timesInPlace(-2.0)
            .plusInPlace(1.0)
            .divInPlace(s);
        DoubleTensor dLogPdx = dLogPdmu.unaryMinus();
        DoubleTensor dLogPds = z.timesInPlace(dLogPdmu).minusInPlace(s.reciprocal());

        return new Diffs()
            .put(Diffs.MU, dLogPdmu)
            .put(Diffs.S, dLogPds)
            .put(Diffs.X, dLogPdx);
    }
}
