/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.weights;

import java.util.Arrays;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Static helpers that fill a parameter array with initial weight values according to a
 * {@link WeightInit} scheme, and that reshape flat parameter arrays into weight arrays.
 *
 * <p>The {@code initWeights} methods write into the supplied {@code paramView} in place and
 * return {@code paramView.reshape(order, shape)}.
 */
public class WeightInitUtil {
    /** Default order ('f') used by the overloads that do not take an explicit order. */
    public static final char DEFAULT_WEIGHT_INIT_ORDER = 'f';

    private WeightInitUtil() {
        // Static utility class; not instantiable.
    }

    /**
     * Initializes {@code paramView} using {@link #DEFAULT_WEIGHT_INIT_ORDER}.
     *
     * @deprecated use {@link #initWeights(double, double, long[], WeightInit, Distribution, INDArray)}
     */
    @Deprecated
    public static INDArray initWeights(double fanIn, double fanOut, int[] shape, WeightInit initScheme,
            Distribution dist, INDArray paramView) {
        return initWeights(fanIn, fanOut, ArrayUtil.toLongArray(shape), initScheme, dist,
                DEFAULT_WEIGHT_INIT_ORDER, paramView);
    }

    /**
     * Initializes {@code paramView} using {@link #DEFAULT_WEIGHT_INIT_ORDER}.
     *
     * @see #initWeights(double, double, long[], WeightInit, Distribution, char, INDArray)
     */
    public static INDArray initWeights(double fanIn, double fanOut, long[] shape, WeightInit initScheme,
            Distribution dist, INDArray paramView) {
        return initWeights(fanIn, fanOut, shape, initScheme, dist, DEFAULT_WEIGHT_INIT_ORDER, paramView);
    }

    /**
     * Initializes {@code paramView} with an explicit order.
     *
     * @deprecated use {@link #initWeights(double, double, long[], WeightInit, Distribution, char, INDArray)}
     */
    @Deprecated
    public static INDArray initWeights(double fanIn, double fanOut, int[] shape, WeightInit initScheme,
            Distribution dist, char order, INDArray paramView) {
        return initWeights(fanIn, fanOut, ArrayUtil.toLongArray(shape), initScheme, dist, order, paramView);
    }

    /**
     * Fills {@code paramView} in place according to {@code initScheme} and returns it reshaped.
     *
     * @param fanIn      fan-in value used by the fan-in / fan-average based schemes
     * @param fanOut     fan-out value used by the fan-out / fan-average based schemes
     * @param shape      target weight shape; also consulted by IDENTITY, XAVIER_LEGACY and the
     *                   orthogonal DISTRIBUTION case
     * @param initScheme the initialization scheme to apply
     * @param dist       distribution to sample from; only used (and required) for DISTRIBUTION
     * @param order      order ('c' or 'f') used when reshaping / flattening
     * @param paramView  array to fill in place; it is reshaped to {@code shape} on return, so its
     *                   element count must match {@code shape}
     * @return {@code paramView.reshape(order, shape)} after initialization
     * @throws IllegalStateException if the scheme is unsupported, if DISTRIBUTION is requested
     *                               without a distribution, if IDENTITY is requested for a
     *                               non-square 2-D shape, or if XAVIER_LEGACY is requested for a
     *                               shape with fewer than two dimensions
     */
    public static INDArray initWeights(double fanIn, double fanOut, long[] shape, WeightInit initScheme,
            Distribution dist, char order, INDArray paramView) {
        switch (initScheme) {
            case DISTRIBUTION:
                if (dist == null) {
                    throw new IllegalStateException("Cannot use DISTRIBUTION init: no distribution was provided (dist is null)");
                }
                if (dist instanceof OrthogonalDistribution) {
                    // Orthogonal sampling depends on the matrix shape, so sample into the shaped array.
                    dist.sample(paramView.reshape(order, shape));
                } else {
                    dist.sample(paramView);
                }
                break;
            case RELU:
                // Gaussian, mean 0, variance 2 / fanIn.
                Nd4j.randn(paramView).muli(FastMath.sqrt(2.0 / fanIn));
                break;
            case RELU_UNIFORM:
                fillUniform(paramView, Math.sqrt(6.0 / fanIn));
                break;
            case SIGMOID_UNIFORM:
                fillUniform(paramView, 4.0 * Math.sqrt(6.0 / (fanIn + fanOut)));
                break;
            case UNIFORM:
                fillUniform(paramView, 1.0 / Math.sqrt(fanIn));
                break;
            case LECUN_UNIFORM:
                // NOTE(review): the common definition (e.g. Keras LecunUniform) uses limit
                // sqrt(3 / fanIn); this uses 3 / sqrt(fanIn). Kept unchanged so existing
                // configurations produce the same initial weights — confirm intent.
                fillUniform(paramView, 3.0 / Math.sqrt(fanIn));
                break;
            case XAVIER:
                // Gaussian, mean 0, variance 2 / (fanIn + fanOut).
                Nd4j.randn(paramView).muli(FastMath.sqrt(2.0 / (fanIn + fanOut)));
                break;
            case XAVIER_UNIFORM:
                fillUniform(paramView, Math.sqrt(6.0) / Math.sqrt(fanIn + fanOut));
                break;
            case LECUN_NORMAL:
            case NORMAL:
            case XAVIER_FAN_IN:
                // Gaussian, mean 0, variance 1 / fanIn.
                Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn));
                break;
            case XAVIER_LEGACY:
                if (shape.length < 2) {
                    throw new IllegalStateException("Cannot use XAVIER_LEGACY init with parameters of shape "
                            + Arrays.toString(shape) + ": at least 2 dimensions are required");
                }
                // Legacy variant: scale derived from the first two shape dimensions, not fanIn/fanOut.
                Nd4j.randn(paramView).divi(FastMath.sqrt((double) (shape[0] + shape[1])));
                break;
            case ZERO:
                paramView.assign(0.0);
                break;
            case ONES:
                paramView.assign(1.0);
                break;
            case IDENTITY: {
                if (shape.length != 2 || shape[0] != shape[1]) {
                    throw new IllegalStateException("Cannot use IDENTITY init with parameters of shape " + Arrays.toString(shape) + ": weights must be a square matrix for identity");
                }
                // Build the identity in the requested order so flattening yields the right layout.
                INDArray identity = order == Nd4j.order().charValue()
                        ? Nd4j.eye(shape[0])
                        : Nd4j.createUninitialized(shape, order).assign(Nd4j.eye(shape[0]));
                paramView.assign(Nd4j.toFlattened(order, new INDArray[] {identity}));
                break;
            }
            case VAR_SCALING_NORMAL_FAN_IN:
                fillTruncatedNormal(paramView, Math.sqrt(1.0 / fanIn));
                break;
            case VAR_SCALING_NORMAL_FAN_OUT:
                fillTruncatedNormal(paramView, Math.sqrt(1.0 / fanOut));
                break;
            case VAR_SCALING_NORMAL_FAN_AVG:
                fillTruncatedNormal(paramView, Math.sqrt(2.0 / (fanIn + fanOut)));
                break;
            // NOTE(review): for the VAR_SCALING_UNIFORM_* cases the usual variance-scaling limit is
            // sqrt(3 / n); these use 3 / sqrt(n). Kept unchanged for compatibility — confirm intent.
            case VAR_SCALING_UNIFORM_FAN_IN:
                fillUniform(paramView, 3.0 / Math.sqrt(fanIn));
                break;
            case VAR_SCALING_UNIFORM_FAN_OUT:
                fillUniform(paramView, 3.0 / Math.sqrt(fanOut));
                break;
            case VAR_SCALING_UNIFORM_FAN_AVG:
                fillUniform(paramView, 3.0 / Math.sqrt((fanIn + fanOut) / 2.0));
                break;
            default:
                throw new IllegalStateException("Illegal weight init value: " + initScheme);
        }
        return paramView.reshape(order, shape);
    }

    /** Fills {@code target} in place with samples from the uniform distribution U(-limit, limit). */
    private static void fillUniform(INDArray target, double limit) {
        Nd4j.rand(target, Nd4j.getDistributions().createUniform(-limit, limit));
    }

    /** Fills {@code target} in place with truncated-normal samples (mean 0, given std). */
    private static void fillTruncatedNormal(INDArray target, double std) {
        // Cast keeps the Op overload of Nd4j.exec selected explicitly.
        Nd4j.exec((Op) new TruncatedNormalDistribution(target, 0.0, std));
    }

    /** Reshapes {@code paramsView} to {@code shape} using {@link #DEFAULT_WEIGHT_INIT_ORDER}. */
    public static INDArray reshapeWeights(int[] shape, INDArray paramsView) {
        return reshapeWeights(shape, paramsView, DEFAULT_WEIGHT_INIT_ORDER);
    }

    /** Reshapes {@code paramsView} to {@code shape} using {@link #DEFAULT_WEIGHT_INIT_ORDER}. */
    public static INDArray reshapeWeights(long[] shape, INDArray paramsView) {
        return reshapeWeights(shape, paramsView, DEFAULT_WEIGHT_INIT_ORDER);
    }

    /**
     * Reshapes {@code paramsView} to {@code shape} in the given order.
     *
     * @return {@code paramsView.reshape(flatteningOrder, shape)}
     */
    public static INDArray reshapeWeights(int[] shape, INDArray paramsView, char flatteningOrder) {
        return paramsView.reshape(flatteningOrder, shape);
    }

    /**
     * Reshapes {@code paramsView} to {@code shape} in the given order.
     *
     * @return {@code paramsView.reshape(flatteningOrder, shape)}
     */
    public static INDArray reshapeWeights(long[] shape, INDArray paramsView, char flatteningOrder) {
        return paramsView.reshape(flatteningOrder, shape);
    }
}

