/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.ocnn;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link OCNNOutputLayer}.
 *
 * <p>The layer's flattened parameter vector (and, identically, its flattened gradient vector) is
 * split into three contiguous segments:
 *
 * <pre>
 *   [ w : hiddenSize ][ v : nIn * hiddenSize ][ r : 1 ]
 * </pre>
 *
 * <ul>
 *   <li>{@code w} is exposed as a {@code [1, hiddenSize]} matrix;</li>
 *   <li>{@code v} is exposed as an {@code [nIn, hiddenSize]} matrix in 'f' (column-major) order;</li>
 *   <li>{@code r} is the single last element of the flat vector.</li>
 * </ul>
 *
 * <p>All three are reported as weight parameters; this initializer reports no bias parameters.
 */
public class OCNNParamInitializer
extends DefaultParamInitializer {
    private static final OCNNParamInitializer INSTANCE = new OCNNParamInitializer();
    public static final String NU_KEY = "nu";
    public static final String K_KEY = "k";
    public static final String V_KEY = "v";
    public static final String W_KEY = "w";
    public static final String R_KEY = "r";

    /** Number of flat-vector entries occupied by the r parameter (always the final entry). */
    private static final int R_LENGTH = 1;

    // These lists are shared static state handed out directly by paramKeys/weightKeys, so they are
    // wrapped as unmodifiable: the fixed-size view from Arrays.asList alone would still permit set().
    private static final List<String> WEIGHT_KEYS =
            Collections.unmodifiableList(Arrays.asList(W_KEY, V_KEY, R_KEY));
    private static final List<String> PARAM_KEYS =
            Collections.unmodifiableList(Arrays.asList(W_KEY, V_KEY, R_KEY));

    /** Returns the shared instance of this initializer. */
    public static OCNNParamInitializer getInstance() {
        return INSTANCE;
    }

    /** Returns the total number of parameters for the layer held by {@code conf}. */
    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return this.numParams(conf.getLayer());
    }

    /**
     * Returns the total flat parameter count, {@code hiddenSize + nIn * hiddenSize + 1}.
     *
     * @param layer the layer configuration; cast unconditionally to {@link OCNNOutputLayer}
     */
    @Override
    public long numParams(Layer layer) {
        OCNNOutputLayer ocnnOutputLayer = (OCNNOutputLayer) layer;
        int hiddenSize = ocnnOutputLayer.getHiddenSize();
        return wLength(hiddenSize) + vLength(ocnnOutputLayer.getNIn(), hiddenSize) + R_LENGTH;
    }

    /** Returns the parameter keys, in flat-vector order: w, v, r. */
    @Override
    public List<String> paramKeys(Layer layer) {
        return PARAM_KEYS;
    }

    /** Returns the weight keys; every parameter of this layer is a weight (w, v, r). */
    @Override
    public List<String> weightKeys(Layer layer) {
        return WEIGHT_KEYS;
    }

    /** Returns an empty list: this initializer reports no bias parameters. */
    @Override
    public List<String> biasKeys(Layer layer) {
        return Collections.emptyList();
    }

    /** Returns true iff {@code key} is one of w, v or r. */
    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return WEIGHT_KEYS.contains(key);
    }

    /** Always returns false: this initializer reports no bias parameters. */
    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return false;
    }

    /**
     * Splits {@code paramsView} into the w, v and r views, passes each through
     * {@link #createWeightMatrix}, and registers each key as a variable on {@code conf}.
     *
     * @param conf             configuration whose layer is an {@link OCNNOutputLayer}
     * @param paramsView       flat view over this layer's parameters, laid out as described on the class
     * @param initializeParams whether to initialize values with the layer's weight-init function
     * @return insertion-ordered (w, v, r) map from parameter key to its (possibly initialized) view
     * @throws IllegalStateException if the layer's hidden size is not positive
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        OCNNOutputLayer ocnnOutputLayer = (OCNNOutputLayer) conf.getLayer();
        long nIn = ocnnOutputLayer.getNIn();
        int hiddenSize = ocnnOutputLayer.getHiddenSize();
        Preconditions.checkState(hiddenSize > 0, "OCNNOutputLayer hidden layer state: must be non-zero.");

        long wLength = wLength(hiddenSize);
        long vLength = vLength(nIn, hiddenSize);
        INDArray wView = paramsView.get(NDArrayIndex.interval(0L, wLength)).reshape(1L, (long) hiddenSize);
        INDArray vView = paramsView.get(NDArrayIndex.interval(wLength, wLength + vLength))
                .reshape('f', new long[]{nIn, hiddenSize});
        // r is addressed from the end of the view; with an exactly-sized view this is offset wLength + vLength.
        INDArray rView = paramsView.get(NDArrayIndex.point(paramsView.length() - R_LENGTH));

        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        this.registerWeight(conf, params, W_KEY, wView, initializeParams);
        this.registerWeight(conf, params, V_KEY, vView, initializeParams);
        this.registerWeight(conf, params, R_KEY, rView, initializeParams);
        return params;
    }

    /**
     * Splits a flat gradient vector into per-parameter views using the same layout as {@link #init}.
     *
     * @param conf         configuration whose layer is an {@link OCNNOutputLayer}
     * @param gradientView flat view over this layer's gradients
     * @return insertion-ordered (w, v, r) map from parameter key to its gradient view
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        OCNNOutputLayer ocnnOutputLayer = (OCNNOutputLayer) conf.getLayer();
        long nIn = ocnnOutputLayer.getNIn();
        int hiddenSize = ocnnOutputLayer.getHiddenSize();
        long wLength = wLength(hiddenSize);
        long vLength = vLength(nIn, hiddenSize);

        Map<String, INDArray> gradients = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        gradients.put(W_KEY, gradientView.get(NDArrayIndex.interval(0L, wLength)).reshape('f', 1, hiddenSize));
        gradients.put(V_KEY, gradientView.get(NDArrayIndex.interval(wLength, wLength + vLength))
                .reshape('f', new long[]{nIn, hiddenSize}));
        gradients.put(R_KEY, gradientView.get(NDArrayIndex.point(gradientView.length() - R_LENGTH)));
        return gradients;
    }

    /**
     * Wraps a parameter view as a weight matrix.
     *
     * <p>When {@code initializeParameters} is true, the layer's {@link IWeightInit} is applied to the
     * view with 'f' ordering, passing {@code size(0)} and {@code size(1)} as the fan-in / fan-out
     * arguments. Otherwise the view is only passed through {@link WeightInitUtil#reshapeWeights}
     * with its own shape.
     */
    @Override
    protected INDArray createWeightMatrix(NeuralNetConfiguration configuration, INDArray weightParamView, boolean initializeParameters) {
        OCNNOutputLayer ocnnOutputLayer = (OCNNOutputLayer) configuration.getLayer();
        if (!initializeParameters) {
            return WeightInitUtil.reshapeWeights(weightParamView.shape(), weightParamView);
        }
        IWeightInit weightInit = ocnnOutputLayer.getWeightInitFn();
        // NOTE(review): for the r parameter the view comes from NDArrayIndex.point(...), i.e. it is a
        // scalar; this relies on size(0)/size(1) being defined for scalars — confirm.
        return weightInit.init(weightParamView.size(0), weightParamView.size(1), weightParamView.shape(), 'f', weightParamView);
    }

    /** Creates the weight matrix for {@code view}, stores it under {@code key}, and registers the key on {@code conf}. */
    private void registerWeight(NeuralNetConfiguration conf, Map<String, INDArray> params, String key,
                                INDArray view, boolean initializeParams) {
        params.put(key, this.createWeightMatrix(conf, view, initializeParams));
        conf.addVariable(key);
    }

    /** Length of the w segment: one entry per hidden unit. */
    private static long wLength(int hiddenSize) {
        return hiddenSize;
    }

    /** Length of the v segment: an {@code nIn x hiddenSize} matrix, computed in long arithmetic. */
    private static long vLength(long nIn, int hiddenSize) {
        return nIn * (long) hiddenSize;
    }
}

