/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.sampling;

import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.MathUtils;

/**
 * Static helpers that fill nd4j arrays with random draws, one element at a time.
 *
 * <p>Every method walks its input element by element and writes exactly one draw per element,
 * using the caller-supplied {@link RandomGenerator} as the source of randomness. Inputs are only
 * read; results are written into a copy or a freshly created array.
 */
public class Sampling {
    /** Inverse-cumulative-probability accuracy handed to every {@link NormalDistribution}. */
    private static final double NORMAL_INVERSE_CUM_ACCURACY = 1.0E-9;

    /**
     * Draws one Gaussian sample per element, with a per-element mean and per-element spread.
     *
     * <p>Element {@code k} is sampled with mean {@code mean.get(k)} and standard deviation
     * {@code sqrt(sigma.ravel().get(k))}. Because of that square root, {@code sigma} effectively
     * holds variances rather than standard deviations.
     *
     * @param rng   random source consumed by every draw
     * @param mean  per-element means; the result has this array's shape
     * @param sigma per-element variances, read through {@code ravel()} and indexed from 0 up to
     *              {@code mean.length() - 1} (no length check is performed here)
     * @return a new array shaped like {@code mean} holding the samples
     */
    public static INDArray normal(RandomGenerator rng, INDArray mean, INDArray sigma) {
        // A duplicated row-vector copy receives the draws, so the caller's mean is never written.
        INDArray samples = mean.reshape(new int[]{1, mean.length()}).dup();
        INDArray variances = sigma.ravel();
        final int count = samples.length();
        for (int k = 0; k < count; k++) {
            // NOTE(review): mean is indexed directly while sigma goes through ravel();
            // confirm mean.get(k) is a linear index when mean is multi-dimensional.
            double draw = drawGaussian(rng, mean.get(k), FastMath.sqrt(variances.get(k)));
            samples.putScalar(k, (Number) draw);
        }
        return samples.reshape(mean.shape());
    }

    /**
     * Draws one Gaussian sample per element of {@code mean}, all sharing a single spread.
     *
     * <p>The standard deviation used for every element is {@code sqrt(sigma)}, so {@code sigma}
     * is effectively a variance.
     *
     * @param rng   random source consumed by every draw
     * @param mean  per-element means, read through {@code linearView()}
     * @param sigma common variance applied to every element
     * @return a freshly created array of {@code mean}'s shape holding the samples
     */
    public static INDArray normal(RandomGenerator rng, INDArray mean, double sigma) {
        INDArray result = Nd4j.create(mean.shape());
        INDArray means = mean.linearView();
        // Draws are written through this view; returning result relies on the view sharing
        // storage with it (assumed linearView() semantics — confirm for non-contiguous arrays).
        INDArray resultFlat = result.linearView();
        double sd = FastMath.sqrt(sigma); // identical for every element, so computed once
        final int count = means.length();
        for (int k = 0; k < count; k++) {
            resultFlat.putScalar(k, (Number) drawGaussian(rng, means.get(k), sd));
        }
        return result;
    }

    /**
     * Replaces every element of a copy of {@code p} with {@code MathUtils.binomial(rng, n, p_k)}.
     *
     * <p>Presumably each call yields a Binomial(n, p_k) draw; MathUtils is not visible here, so
     * confirm its exact semantics there.
     *
     * @param p   per-element probabilities; only read, never written
     * @param n   trial count passed unchanged to every {@code MathUtils.binomial} call
     * @param rng random source handed to {@code MathUtils.binomial}
     * @return a duplicate of {@code p} whose elements were overwritten with the draws
     */
    public static INDArray binomial(INDArray p, int n, RandomGenerator rng) {
        INDArray draws = p.dup();
        // Writes go through the linear view; returning draws relies on that view sharing storage.
        INDArray flat = draws.linearView();
        final int count = draws.length();
        for (int k = 0; k < count; k++) {
            Number outcome = MathUtils.binomial(rng, n, flat.get(k));
            flat.putScalar(k, outcome);
        }
        return draws;
    }

    /**
     * Builds a normal distribution on {@code rng} and returns a single sample from it.
     *
     * <p>commons-math's {@link NormalDistribution} constructor rejects a non-positive standard
     * deviation; that exception propagates unchanged to the public callers.
     */
    private static double drawGaussian(RandomGenerator rng, double mu, double sd) {
        return new NormalDistribution(rng, mu, sd, NORMAL_INVERSE_CUM_ACCURACY).sample();
    }
}

