/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.rng.distribution.impl;

import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.exception.util.Localizable;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.special.Erf;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.rng.distribution.BaseDistribution;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Normal (Gaussian) distribution with a fixed standard deviation.
 *
 * <p>An instance operates in one of two modes:
 * <ul>
 *   <li><b>Scalar mean</b> — a single {@code mean} is used for every query and sample.</li>
 *   <li><b>Per-element means</b> — an {@link INDArray} of means is supplied. Only
 *       {@link #sample(int[])} is supported, and it draws each element around the mean
 *       stored at the same index. Every scalar query ({@link #density},
 *       {@link #cumulativeProbability(double)}, {@link #inverseCumulativeProbability},
 *       {@link #probability}, {@link #sample()}) throws {@link IllegalStateException}
 *       in this mode.</li>
 * </ul>
 */
public class NormalDistribution
extends BaseDistribution {
    /** Default absolute accuracy used by the inverse-CDF solver. */
    public static final double DEFAULT_INVERSE_ABSOLUTE_ACCURACY = 1.0E-9;
    private static final long serialVersionUID = 8589540077390120676L;
    /** sqrt(2 * pi), the normalising constant of the Gaussian density. */
    private static final double SQRT2PI = FastMath.sqrt(2 * Math.PI);
    /** sqrt(2), used to map the normal CDF onto the error function. */
    private static final double SQRT2 = FastMath.sqrt(2.0);
    private final double standardDeviation;
    /** Scalar mean; stays 0.0 when a per-element {@link #means} array is used instead. */
    private double mean;
    /** Per-element means; {@code null} in scalar-mean mode. */
    private INDArray means;
    private double solverAbsoluteAccuracy;

    /**
     * Creates a per-element-means distribution using the given random generator.
     * The standard deviation is not validated here, unlike the scalar constructors.
     *
     * @param rng               random generator used for sampling
     * @param standardDeviation standard deviation applied to every element
     * @param means             per-element means; sampled shapes should index into it
     */
    public NormalDistribution(Random rng, double standardDeviation, INDArray means) {
        super(rng);
        this.standardDeviation = standardDeviation;
        this.means = means;
    }

    /**
     * Creates a per-element-means distribution.
     *
     * <p>NOTE(review): the random generator is whatever {@code BaseDistribution}'s no-arg
     * constructor provides; confirm it is non-null.
     *
     * @param standardDeviation standard deviation applied to every element
     * @param means             per-element means
     */
    public NormalDistribution(double standardDeviation, INDArray means) {
        this.standardDeviation = standardDeviation;
        this.means = means;
    }

    /** Creates a standard normal distribution (mean 0, standard deviation 1). */
    public NormalDistribution() {
        this(0.0, 1.0);
    }

    /**
     * Creates a scalar-mean distribution with the default inverse-CDF accuracy.
     *
     * @param mean the mean
     * @param sd   the standard deviation; must be strictly positive
     * @throws NotStrictlyPositiveException if {@code sd <= 0}
     */
    public NormalDistribution(double mean, double sd) throws NotStrictlyPositiveException {
        this(mean, sd, DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
    }

    /**
     * Creates a scalar-mean distribution backed by {@code Nd4j.getRandom()}.
     *
     * @param mean               the mean
     * @param sd                 the standard deviation; must be strictly positive
     * @param inverseCumAccuracy absolute accuracy reported to the inverse-CDF solver
     * @throws NotStrictlyPositiveException if {@code sd <= 0}
     */
    public NormalDistribution(double mean, double sd, double inverseCumAccuracy) throws NotStrictlyPositiveException {
        this(Nd4j.getRandom(), mean, sd, inverseCumAccuracy);
    }

    /**
     * Creates a scalar-mean distribution.
     *
     * @param rng                random generator used for sampling
     * @param mean               the mean
     * @param sd                 the standard deviation; must be strictly positive
     * @param inverseCumAccuracy absolute accuracy reported to the inverse-CDF solver
     * @throws NotStrictlyPositiveException if {@code sd <= 0}
     */
    public NormalDistribution(Random rng, double mean, double sd, double inverseCumAccuracy) throws NotStrictlyPositiveException {
        super(rng);
        if (sd <= 0.0) {
            throw new NotStrictlyPositiveException(LocalizedFormats.STANDARD_DEVIATION, sd);
        }
        this.mean = mean;
        this.standardDeviation = sd;
        this.solverAbsoluteAccuracy = inverseCumAccuracy;
    }

    /**
     * Creates a per-element-means distribution backed by {@code Nd4j.getRandom()}.
     * The standard deviation is not validated here.
     *
     * @param mean per-element means
     * @param std  standard deviation applied to every element
     */
    public NormalDistribution(INDArray mean, double std) {
        this.means = mean;
        this.standardDeviation = std;
        this.random = Nd4j.getRandom();
    }

    /**
     * Returns the scalar mean. This is 0.0 when the instance was built with a per-element
     * means array.
     */
    public double getMean() {
        return this.mean;
    }

    /** Returns the standard deviation. */
    public double getStandardDeviation() {
        return this.standardDeviation;
    }

    /**
     * Guards scalar-only operations. They are undefined when a per-element means array is
     * configured.
     *
     * @throws IllegalStateException if {@link #means} is set
     */
    private void requireScalarMean() {
        if (this.means != null) {
            throw new IllegalStateException("Unable to sample from more than one mean");
        }
    }

    /**
     * Gaussian probability density at {@code x}.
     *
     * @throws IllegalStateException in per-element-means mode
     */
    @Override
    public double density(double x) {
        requireScalarMean();
        double z = (x - this.mean) / this.standardDeviation;
        return FastMath.exp(-0.5 * z * z) / (this.standardDeviation * SQRT2PI);
    }

    /**
     * Cumulative probability P(X &lt;= x), computed through the error function.
     * Points more than 40 standard deviations from the mean return exactly 0 or 1.
     *
     * @throws IllegalStateException in per-element-means mode
     */
    @Override
    public double cumulativeProbability(double x) {
        requireScalarMean();
        double dev = x - this.mean;
        if (FastMath.abs(dev) > 40.0 * this.standardDeviation) {
            // Far in the tails erf saturates; return the exact limit.
            return dev < 0.0 ? 0.0 : 1.0;
        }
        return 0.5 * (1.0 + Erf.erf(dev / (this.standardDeviation * SQRT2)));
    }

    /**
     * Quantile function: the x such that P(X &lt;= x) = p, computed with the inverse
     * error function.
     *
     * @throws OutOfRangeException   if {@code p} is outside [0, 1]
     * @throws IllegalStateException in per-element-means mode
     */
    @Override
    public double inverseCumulativeProbability(double p) throws OutOfRangeException {
        if (p < 0.0 || p > 1.0) {
            throw new OutOfRangeException(p, 0, 1);
        }
        requireScalarMean();
        return this.mean + this.standardDeviation * SQRT2 * Erf.erfInv(2.0 * p - 1.0);
    }

    /**
     * Legacy alias for {@link #probability(double, double)}.
     *
     * @deprecated use {@link #probability(double, double)}
     */
    @Override
    @Deprecated
    public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException {
        return this.probability(x0, x1);
    }

    /**
     * Probability mass between {@code x0} and {@code x1}, computed as
     * {@code 0.5 * Erf.erf(v0, v1)} on the standardised endpoints.
     *
     * @throws NumberIsTooLargeException if {@code x0 > x1}
     * @throws IllegalStateException     in per-element-means mode, consistent with the other
     *                                   scalar queries (previously this silently used mean 0)
     */
    @Override
    public double probability(double x0, double x1) throws NumberIsTooLargeException {
        if (x0 > x1) {
            throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0, x1, true);
        }
        requireScalarMean();
        double denom = this.standardDeviation * SQRT2;
        double v0 = (x0 - this.mean) / denom;
        double v1 = (x1 - this.mean) / denom;
        return 0.5 * Erf.erf(v0, v1);
    }

    /** Absolute accuracy for the inverse-CDF solver; 0.0 for the INDArray-based constructors. */
    @Override
    protected double getSolverAbsoluteAccuracy() {
        return this.solverAbsoluteAccuracy;
    }

    /** Returns {@link #getMean()}. */
    @Override
    public double getNumericalMean() {
        return this.getMean();
    }

    /** Returns the variance, i.e. the squared standard deviation. */
    @Override
    public double getNumericalVariance() {
        double s = this.getStandardDeviation();
        return s * s;
    }

    @Override
    public double getSupportLowerBound() {
        return Double.NEGATIVE_INFINITY;
    }

    @Override
    public double getSupportUpperBound() {
        return Double.POSITIVE_INFINITY;
    }

    @Override
    public boolean isSupportLowerBoundInclusive() {
        return false;
    }

    @Override
    public boolean isSupportUpperBoundInclusive() {
        return false;
    }

    @Override
    public boolean isSupportConnected() {
        return true;
    }

    /**
     * Draws a single value as {@code standardDeviation * nextGaussian() + mean}.
     *
     * @throws IllegalStateException in per-element-means mode
     */
    @Override
    public double sample() {
        requireScalarMean();
        return this.standardDeviation * this.random.nextGaussian() + this.mean;
    }

    /**
     * Fills a new array of the given shape with Gaussian samples. Each element is centred
     * on the scalar mean, or in per-element-means mode on {@code means} at the same index.
     *
     * <p>NOTE(review): {@code means} is assumed to be indexable by every index of
     * {@code shape}; this is not checked here.
     *
     * @param shape shape of the returned array
     * @return a freshly created array of samples
     */
    @Override
    public INDArray sample(int[] shape) {
        INDArray ret = Nd4j.create(shape);
        NdIndexIterator idxIter = new NdIndexIterator(shape);
        int len = ret.length();
        for (int i = 0; i < len; ++i) {
            int[] idx = (int[]) idxIter.next();
            // Draw before reading the centre, matching the original evaluation order.
            double draw = this.standardDeviation * this.random.nextGaussian();
            double centre = this.means != null ? this.means.getDouble(idx) : this.mean;
            ret.putScalar(idx, draw + centre);
        }
        return ret;
    }
}

