/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link CenterLossOutputLayer}.
 *
 * <p>The flattened parameter (and gradient) view is read from row 0 and is split into three
 * contiguous segments:
 * <pre>
 *   [ W : nIn*nOut | b : nOut | cL : nIn*nOut ]
 * </pre>
 * {@code W} and {@code b} are built with the inherited {@code createWeightMatrix} /
 * {@code createBias}; {@code cL} is exposed as a 'c'-order {@code [nOut, nIn]} matrix and is
 * zero-filled when parameters are initialized.
 */
public class CenterLossParamInitializer extends DefaultParamInitializer {
    private static final CenterLossParamInitializer INSTANCE = new CenterLossParamInitializer();

    /** Key under which the weight matrix is stored in the parameter/gradient maps. */
    public static final String WEIGHT_KEY = "W";
    /** Key under which the bias vector is stored in the parameter/gradient maps. */
    public static final String BIAS_KEY = "b";
    /** Key under which the center-loss matrix is stored in the parameter/gradient maps. */
    public static final String CENTER_KEY = "cL";

    // Positions inside the array returned by segmentEnds(..).
    private static final int W_END = 0;
    private static final int B_END = 1;
    private static final int C_END = 2;

    /** Returns the shared initializer instance. */
    public static CenterLossParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Returns the total parameter count: weights ({@code nIn*nOut}) + bias ({@code nOut})
     * + center matrix ({@code nIn*nOut}).
     *
     * @param conf configuration whose layer is (at least) a {@link FeedForwardLayer}
     * @return number of parameters in the flattened view
     */
    @Override
    public long numParams(NeuralNetConfiguration conf) {
        FeedForwardLayer layerConf = (FeedForwardLayer) conf.getLayer();
        return segmentEnds(layerConf.getNIn(), layerConf.getNOut())[C_END];
    }

    /**
     * Splits {@code paramsView} into weight, bias and center-loss views, optionally initializes
     * them, and registers the three variables on {@code conf}.
     *
     * @param conf             configuration whose layer must be a {@link CenterLossOutputLayer}
     * @param paramsView       flattened parameter view; segments are taken from row 0
     *                         (presumably a {@code [1, numParams]} row vector — confirm at caller)
     * @param initializeParams whether to (re)initialize the values in the views
     * @return synchronized, insertion-ordered map {@code W -> b -> cL}
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        CenterLossOutputLayer layerConf = (CenterLossOutputLayer) conf.getLayer();
        long nIn = layerConf.getNIn();
        long nOut = layerConf.getNOut();
        long[] ends = segmentEnds(nIn, nOut);

        INDArray weightView = segment(paramsView, 0L, ends[W_END]);
        INDArray biasView = segment(paramsView, ends[W_END], ends[B_END]);
        INDArray centerLossView = segment(paramsView, ends[B_END], ends[C_END]).reshape('c', nOut, nIn);

        // Insertion order (W, b, cL) mirrors the layout of the flattened view.
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<>());
        params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
        params.put(BIAS_KEY, createBias(conf, biasView, initializeParams));
        params.put(CENTER_KEY, createCenterLossMatrix(conf, centerLossView, initializeParams));
        conf.addVariable(WEIGHT_KEY);
        conf.addVariable(BIAS_KEY);
        conf.addVariable(CENTER_KEY);
        return params;
    }

    /**
     * Splits a flattened gradient view into per-parameter views using the same layout as
     * {@link #init}.
     *
     * @param conf         configuration whose layer must be a {@link CenterLossOutputLayer}
     * @param gradientView flattened gradient view; segments are taken from row 0
     * @return insertion-ordered map: {@code W} as an 'f'-order {@code [nIn, nOut]} view,
     *         {@code b} as the raw slice, {@code cL} as a 'c'-order {@code [nOut, nIn]} view
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        CenterLossOutputLayer layerConf = (CenterLossOutputLayer) conf.getLayer();
        long nIn = layerConf.getNIn();
        long nOut = layerConf.getNOut();
        long[] ends = segmentEnds(nIn, nOut);

        INDArray weightGradientView = segment(gradientView, 0L, ends[W_END]).reshape('f', nIn, nOut);
        INDArray biasView = segment(gradientView, ends[W_END], ends[B_END]);
        INDArray centerLossView = segment(gradientView, ends[B_END], ends[C_END]).reshape('c', nOut, nIn);

        Map<String, INDArray> out = new LinkedHashMap<>();
        out.put(WEIGHT_KEY, weightGradientView);
        out.put(BIAS_KEY, biasView);
        out.put(CENTER_KEY, centerLossView);
        return out;
    }

    /**
     * Returns the center-loss view, zero-filling it first when {@code initializeParameters} is
     * true.
     *
     * @param conf                 layer configuration (not consulted by this implementation)
     * @param centerLossView       view over the center-loss segment
     * @param initializeParameters whether to overwrite the view with zeros
     * @return {@code centerLossView} itself (same instance)
     */
    protected INDArray createCenterLossMatrix(NeuralNetConfiguration conf, INDArray centerLossView, boolean initializeParameters) {
        if (initializeParameters) {
            centerLossView.assign(0.0);
        }
        return centerLossView;
    }

    /**
     * Computes the exclusive end offsets of the W, b and cL segments in the flattened view,
     * indexed by {@link #W_END}, {@link #B_END} and {@link #C_END}.
     */
    private static long[] segmentEnds(long nIn, long nOut) {
        long wEnd = nIn * nOut;
        long bEnd = wEnd + nOut;
        long cEnd = bEnd + nIn * nOut;
        return new long[] {wEnd, bEnd, cEnd};
    }

    /** Returns a view of columns {@code [start, end)} of row 0 of {@code flat}. */
    private static INDArray segment(INDArray flat, long start, long end) {
        return flat.get(NDArrayIndex.point(0L), NDArrayIndex.interval(start, end));
    }
}

