/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.initializer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.initializer.Initializer;

/**
 * An {@link Initializer} that fills an array with random values whose spread is
 * {@code sqrt(magnitude / factor)}.
 *
 * <p>The {@code factor} is derived from the requested shape:
 * <ul>
 *   <li>{@code fanIn  = shape.get(1) * hwScale}</li>
 *   <li>{@code fanOut = shape.head() * hwScale}</li>
 * </ul>
 * where {@code hwScale} is {@code 1} for a 2-D shape and the size of
 * {@code shape.slice(2)} otherwise. Which of the two fans (or their average) is
 * used is selected by {@link FactorType}; the distribution is selected by
 * {@link RandomType}.
 *
 * <p>Instances hold no mutable state: all configuration is fixed at construction.
 */
public class XavierInitializer
implements Initializer {
    private final RandomType randomType;
    private final FactorType factorType;
    private final double magnitude;

    /**
     * Creates an initializer with an explicit configuration.
     *
     * @param randomType the distribution to sample from
     * @param factorType how the scaling factor is derived from fan-in / fan-out
     * @param magnitude the numerator of {@code sqrt(magnitude / factor)}
     */
    public XavierInitializer(RandomType randomType, FactorType factorType, double magnitude) {
        this.randomType = randomType;
        this.factorType = factorType;
        this.magnitude = magnitude;
    }

    /**
     * Creates an initializer using {@link RandomType#UNIFORM}, {@link FactorType#AVG}
     * and a magnitude of {@code 3.0}.
     */
    public XavierInitializer() {
        this(RandomType.UNIFORM, FactorType.AVG, 3.0);
    }

    /**
     * Creates a randomly initialized array of the given shape.
     *
     * @param manager the manager used to create the array; its device is used for the result
     * @param shape the shape of the array to create; must have at least 2 dimensions
     * @param dataType the data type of the array to create
     * @return the array returned by {@code manager.randomUniform} or {@code manager.randomNormal}
     * @throws IllegalArgumentException if {@code shape} has fewer than 2 dimensions, or the
     *     configured factor type or random type is not one of the handled constants
     * @throws IllegalStateException if the computed factor is 0, or the resulting scale is
     *     NaN or infinite (e.g. because {@code magnitude / factor} is negative)
     */
    @Override
    public NDArray initialize(NDManager manager, Shape shape, DataType dataType) {
        long dimension = shape.dimension();
        if (dimension < 2L) {
            throw new IllegalArgumentException("XavierInitializer cannot be applied to Shape with dimension: " + dimension + ", it requires shape to be at least 2D.");
        }

        // Every axis from index 2 onward contributes equally to both fans.
        double hwScale = dimension == 2L ? 1.0 : (double)shape.slice(2).size();
        double fanIn = (double)shape.get(1) * hwScale;
        double fanOut = (double)shape.head() * hwScale;

        double factor = this.computeFactor(fanIn, fanOut);
        if (factor == 0.0) {
            throw new IllegalStateException("Xavier initializer factor is 0, please check your input shape.");
        }

        double scale = StrictMath.sqrt(this.magnitude / factor);
        // sqrt of a negative or NaN ratio is NaN; fail loudly instead of handing
        // unusable bounds to the random generators.
        if (Double.isNaN(scale) || Double.isInfinite(scale)) {
            throw new IllegalStateException("Xavier initializer scale is not a finite number (magnitude: " + this.magnitude + ", factor: " + factor + "), please check your magnitude and input shape.");
        }

        switch (this.randomType) {
            case UNIFORM: {
                return manager.randomUniform(-scale, scale, shape, dataType, manager.getDevice());
            }
            case GAUSSIAN: {
                return manager.randomNormal(0, scale, shape, dataType, manager.getDevice());
            }
            default: {
                throw new IllegalArgumentException("Invalid randomType");
            }
        }
    }

    /** Selects the scaling factor from the two fans according to the configured factor type. */
    private double computeFactor(double fanIn, double fanOut) {
        switch (this.factorType) {
            case AVG: {
                return (fanIn + fanOut) / 2.0;
            }
            case IN: {
                return fanIn;
            }
            case OUT: {
                return fanOut;
            }
            default: {
                throw new IllegalArgumentException("Invalid factor type, valid types are: avg, in, out");
            }
        }
    }

    /** Selects which fan value is used as the scaling factor. */
    public enum FactorType {
        /** Use the average of fan-in and fan-out. */
        AVG,
        /** Use fan-in only. */
        IN,
        /** Use fan-out only. */
        OUT
    }

    /** Selects the random generator used to fill the array. */
    public enum RandomType {
        /** Sample via {@code randomUniform(-scale, scale, ...)}. */
        UNIFORM,
        /** Sample via {@code randomNormal(0, scale, ...)}. */
        GAUSSIAN
    }
}

