/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.ocnn;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for the OCNN output layer ({@link OCNNOutputLayer}).
 *
 * <p>The layer's flattened parameter row vector is partitioned as {@code [ w | v | r ]}:
 * <ul>
 *   <li>{@code w} ({@link #W_KEY}): {@code hiddenSize} values, viewed as a {@code 1 x hiddenSize} matrix;</li>
 *   <li>{@code v} ({@link #V_KEY}): {@code nIn * hiddenSize} values, viewed as an
 *       {@code nIn x hiddenSize} matrix in 'f' order;</li>
 *   <li>{@code r} ({@link #R_KEY}): a single value stored right after {@code v}.</li>
 * </ul>
 * All three are reported as weight parameters; this initializer reports no bias parameters.
 */
public class OCNNParamInitializer extends DefaultParamInitializer {
    private static final OCNNParamInitializer INSTANCE = new OCNNParamInitializer();

    public static final String NU_KEY = "nu";
    public static final String K_KEY = "k";
    public static final String V_KEY = "v";
    public static final String W_KEY = "w";
    public static final String R_KEY = "r";

    /** Number of elements occupied by the {@code r} parameter. */
    private static final int R_LENGTH = 1;

    // Wrapped read-only: Arrays.asList alone still permits set(), and these shared
    // static lists are handed straight to callers by paramKeys/weightKeys.
    private static final List<String> WEIGHT_KEYS =
            Collections.unmodifiableList(Arrays.asList(W_KEY, V_KEY, R_KEY));
    private static final List<String> PARAM_KEYS =
            Collections.unmodifiableList(Arrays.asList(W_KEY, V_KEY, R_KEY));

    /**
     * Returns the shared initializer instance.
     *
     * @return the singleton {@code OCNNParamInitializer}
     */
    public static OCNNParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Returns the index of {@code r} in the flattened parameter vector, i.e. the combined
     * length of {@code w} ({@code hiddenSize}) and {@code v} ({@code nIn * hiddenSize}).
     *
     * @throws ArithmeticException if the offset does not fit in an {@code int}
     */
    private static int rOffset(int nIn, int hiddenSize) {
        return Math.addExact(hiddenSize, Math.multiplyExact(nIn, hiddenSize));
    }

    /**
     * Returns the number of parameters of the configuration's layer.
     *
     * @param conf configuration whose layer must be an {@link OCNNOutputLayer}
     * @return {@code hiddenSize + nIn * hiddenSize + 1}
     */
    @Override
    public int numParams(NeuralNetConfiguration conf) {
        return numParams(conf.getLayer());
    }

    /**
     * Returns the number of parameters: {@code hiddenSize} for {@code w},
     * {@code nIn * hiddenSize} for {@code v} and one for {@code r}.
     *
     * @param layer layer configuration; must be an {@link OCNNOutputLayer}
     * @return total parameter count
     * @throws ClassCastException if {@code layer} is not an {@link OCNNOutputLayer}
     * @throws ArithmeticException if the count overflows an {@code int}
     */
    @Override
    public int numParams(Layer layer) {
        OCNNOutputLayer ocnnOutputLayer = (OCNNOutputLayer) layer;
        int offset = rOffset(ocnnOutputLayer.getNIn(), ocnnOutputLayer.getHiddenSize());
        return Math.addExact(offset, R_LENGTH);
    }

    /** Returns the read-only list of parameter keys ({@code w}, {@code v}, {@code r}). */
    @Override
    public List<String> paramKeys(Layer layer) {
        return PARAM_KEYS;
    }

    /** Returns the read-only list of weight keys ({@code w}, {@code v}, {@code r}). */
    @Override
    public List<String> weightKeys(Layer layer) {
        return WEIGHT_KEYS;
    }

    /** Returns an empty list: this initializer declares no bias parameters. */
    @Override
    public List<String> biasKeys(Layer layer) {
        return Collections.emptyList();
    }

    /** Returns whether {@code key} is one of {@code w}, {@code v} or {@code r}. */
    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return WEIGHT_KEYS.contains(key);
    }

    /** Always returns {@code false}: no parameter of this layer is a bias. */
    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return false;
    }

    /**
     * Creates views of {@code w}, {@code v} and {@code r} into {@code paramsView}, optionally
     * initializing their values, and registers each key as a variable on {@code conf}.
     *
     * @param conf configuration whose layer must be an {@link OCNNOutputLayer}
     * @param paramsView flattened parameter row vector of length {@link #numParams(Layer)}
     * @param initializeParams whether to initialize values via the layer's weight-init scheme
     * @return insertion-ordered, synchronized map from parameter key to its view
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        OCNNOutputLayer ocnnOutputLayer = (OCNNOutputLayer) conf.getLayer();
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        int nIn = ocnnOutputLayer.getNIn();
        int hiddenSize = ocnnOutputLayer.getHiddenSize();
        int rIndex = rOffset(nIn, hiddenSize);

        INDArray wView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, hiddenSize))
                .reshape(1, hiddenSize);
        INDArray vView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(hiddenSize, rIndex))
                .reshape('f', nIn, hiddenSize);
        // r sits immediately after v; derived from the layout rather than paramsView.length()
        // so it always agrees with numParams().
        INDArray rView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.point(rIndex));

        params.put(W_KEY, createWeightMatrix(conf, wView, initializeParams));
        conf.addVariable(W_KEY);
        params.put(V_KEY, createWeightMatrix(conf, vView, initializeParams));
        conf.addVariable(V_KEY);
        params.put(R_KEY, createWeightMatrix(conf, rView, initializeParams));
        conf.addVariable(R_KEY);
        return params;
    }

    /**
     * Splits the flattened gradient row vector into per-parameter views using the same
     * {@code [ w | v | r ]} layout as {@link #init}.
     *
     * @param conf configuration whose layer must be an {@link OCNNOutputLayer}
     * @param gradientView flattened gradient row vector
     * @return insertion-ordered, synchronized map from parameter key to its gradient view
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        OCNNOutputLayer ocnnOutputLayer = (OCNNOutputLayer) conf.getLayer();
        Map<String, INDArray> gradients = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        int nIn = ocnnOutputLayer.getNIn();
        int hiddenSize = ocnnOutputLayer.getHiddenSize();
        int rIndex = rOffset(nIn, hiddenSize);

        INDArray wView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, hiddenSize))
                .reshape('f', 1, hiddenSize);
        INDArray vView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(hiddenSize, rIndex))
                .reshape('f', nIn, hiddenSize);
        INDArray rView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.point(rIndex));

        gradients.put(W_KEY, wView);
        gradients.put(V_KEY, vView);
        gradients.put(R_KEY, rView);
        return gradients;
    }

    /**
     * Initializes {@code weightParamView} in place with the layer's weight-init scheme and
     * distribution, or, when {@code initializeParameters} is false, only passes it through
     * {@code WeightInitUtil.reshapeWeights} without drawing new values.
     *
     * <p>NOTE(review): fan-in/fan-out are read from {@code size(0)}/{@code size(1)}; this
     * assumes every view (including the single-element {@code r} view) is rank 2 — confirm
     * against the ND4J version in use.
     */
    @Override
    protected INDArray createWeightMatrix(NeuralNetConfiguration configuration, INDArray weightParamView, boolean initializeParameters) {
        if (!initializeParameters) {
            return WeightInitUtil.reshapeWeights(weightParamView.shape(), weightParamView);
        }
        OCNNOutputLayer ocnnOutputLayer = (OCNNOutputLayer) configuration.getLayer();
        WeightInit weightInit = ocnnOutputLayer.getWeightInit();
        // Only built when values are actually drawn.
        Distribution dist = Distributions.createDistribution(ocnnOutputLayer.getDist());
        return WeightInitUtil.initWeights(weightParamView.size(0), weightParamView.size(1),
                weightParamView.shape(), weightInit, dist, weightParamView);
    }
}

