/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link VariationalAutoencoder} layers.
 *
 * <p>All parameters live in one flattened view, indexed as row 0 of that view. The view is carved,
 * in order, into a sequence of dense blocks. Each block stores its weights ({@code nIn * nOut}
 * values, exposed as an {@code [nIn, nOut]} matrix) immediately followed by its bias
 * ({@code nOut} values):
 * <ol>
 *   <li>encoder blocks {@code e0 .. e(k-1)}: {@code nIn -> encoderLayerSizes[0] -> ...};</li>
 *   <li>{@code pZXMean}: last encoder size {@code -> nOut};</li>
 *   <li>{@code pZXLogStd2}: last encoder size {@code -> nOut};</li>
 *   <li>decoder blocks {@code d0 .. d(m-1)}: {@code nOut -> decoderLayerSizes[0] -> ...};</li>
 *   <li>{@code pXZ}: last decoder size {@code -> outputDistribution.distributionInputSize(nIn)}.</li>
 * </ol>
 * A block's parameter keys are its name followed by {@link #WEIGHT_KEY_SUFFIX} or
 * {@link #BIAS_KEY_SUFFIX}, e.g. {@code "e0W"}, {@code "e0b"}, {@code "pXZW"}.
 *
 * <p>{@link #numParams}, {@link #init} and {@link #getGradientsFromFlattened} all derive this
 * layout from the same private description, so the three cannot drift apart.
 */
public class VariationalAutoencoderParamInitializer
extends DefaultParamInitializer {
    private static final VariationalAutoencoderParamInitializer INSTANCE = new VariationalAutoencoderParamInitializer();
    /** Suffix appended to a block name to form the key of its weight matrix. */
    public static final String WEIGHT_KEY_SUFFIX = "W";
    /** Suffix appended to a block name to form the key of its bias vector. */
    public static final String BIAS_KEY_SUFFIX = "b";
    public static final String PZX_PREFIX = "pZX";
    public static final String PZX_MEAN_PREFIX = PZX_PREFIX + "Mean";
    public static final String PZX_LOGSTD2_PREFIX = PZX_PREFIX + "LogStd2";
    public static final String PZX_MEAN_W = PZX_MEAN_PREFIX + WEIGHT_KEY_SUFFIX;
    public static final String PZX_MEAN_B = PZX_MEAN_PREFIX + BIAS_KEY_SUFFIX;
    public static final String PZX_LOGSTD2_W = PZX_LOGSTD2_PREFIX + WEIGHT_KEY_SUFFIX;
    public static final String PZX_LOGSTD2_B = PZX_LOGSTD2_PREFIX + BIAS_KEY_SUFFIX;
    public static final String PXZ_PREFIX = "pXZ";
    public static final String PXZ_W = PXZ_PREFIX + WEIGHT_KEY_SUFFIX;
    public static final String PXZ_B = PXZ_PREFIX + BIAS_KEY_SUFFIX;

    /** Returns the shared initializer instance. */
    public static VariationalAutoencoderParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Returns the total number of parameters (weights plus biases) of all dense blocks.
     *
     * @param conf configuration whose layer must be a {@link VariationalAutoencoder}
     * @return length the flattened parameter view must have
     * @throws IllegalArgumentException if the encoder or decoder layer sizes are null or empty
     */
    @Override
    public int numParams(NeuralNetConfiguration conf) {
        VariationalAutoencoder layer = (VariationalAutoencoder)conf.getLayer();
        int paramCount = 0;
        for (DenseLayout dense : denseLayouts(layer)) {
            paramCount += dense.paramCount();
        }
        return paramCount;
    }

    /**
     * Splits {@code paramsView} into per-block weight and bias views.
     *
     * <p>When {@code initializeParams} is true, it also fills the weights according to the layer's
     * weight init and fills the biases with 0. Every key is registered with {@code conf} via
     * {@code addVariable}, in layout order.
     *
     * @param conf configuration whose layer must be a {@link VariationalAutoencoder}
     * @param paramsView flattened parameter view of exactly {@link #numParams} elements
     * @param initializeParams whether to write initial values into the views
     * @return insertion-ordered map from parameter key to its view
     * @throws IllegalArgumentException if {@code paramsView} has the wrong length, or the encoder
     *     or decoder layer sizes are null or empty
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        int expectedLength = this.numParams(conf);
        if (paramsView.length() != expectedLength) {
            throw new IllegalArgumentException("Incorrect paramsView length: Expected length " + expectedLength + ", got length " + paramsView.length());
        }
        VariationalAutoencoder layer = (VariationalAutoencoder)conf.getLayer();
        WeightInit weightInit = layer.getWeightInit();
        Distribution dist = Distributions.createDistribution(layer.getDist());

        Map<String, INDArray> ret = new LinkedHashMap<String, INDArray>();
        int offset = 0;
        for (DenseLayout dense : denseLayouts(layer)) {
            INDArray weightView = rowSlice(paramsView, offset, dense.weightCount());
            INDArray biasView = rowSlice(paramsView, offset + dense.weightCount(), dense.nOut);
            offset += dense.paramCount();

            INDArray weights = this.createWeightMatrix(dense.nIn, dense.nOut, weightInit, dist, weightView, initializeParams);
            // Bias init value is fixed at 0.0 (as in the original implementation).
            INDArray biases = this.createBias(dense.nOut, 0.0, biasView, initializeParams);
            ret.put(dense.weightKey, weights);
            ret.put(dense.biasKey, biases);
            conf.addVariable(dense.weightKey);
            conf.addVariable(dense.biasKey);
        }
        return ret;
    }

    /**
     * Splits a flattened gradient view into per-block weight and bias gradient views.
     *
     * <p>The returned arrays are views into {@code gradientView}; nothing is copied or initialised.
     * Each view is built with the same {@code createWeightMatrix} / {@code createBias} calls that
     * {@link #init} uses (with initialisation disabled). Gradient arrays therefore have the same
     * shape and ordering as the matching parameter arrays.
     *
     * @param conf configuration whose layer must be a {@link VariationalAutoencoder}
     * @param gradientView flattened gradient view laid out like the parameter view
     * @return insertion-ordered map from parameter key to its gradient view
     * @throws IllegalArgumentException if the encoder or decoder layer sizes are null or empty
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        VariationalAutoencoder layer = (VariationalAutoencoder)conf.getLayer();
        Map<String, INDArray> ret = new LinkedHashMap<String, INDArray>();
        int offset = 0;
        for (DenseLayout dense : denseLayouts(layer)) {
            INDArray weightGradView = rowSlice(gradientView, offset, dense.weightCount());
            INDArray biasGradView = rowSlice(gradientView, offset + dense.weightCount(), dense.nOut);
            offset += dense.paramCount();

            ret.put(dense.weightKey, this.createWeightMatrix(dense.nIn, dense.nOut, null, null, weightGradView, false));
            ret.put(dense.biasKey, this.createBias(dense.nOut, 0.0, biasGradView, false));
        }
        return ret;
    }

    /**
     * Builds the ordered list of dense blocks for {@code layer}, in flattened-view order.
     *
     * @throws IllegalArgumentException if the encoder or decoder layer sizes are null or empty
     */
    private static List<DenseLayout> denseLayouts(VariationalAutoencoder layer) {
        int nIn = layer.getNIn();
        int nOut = layer.getNOut();
        int[] encoderLayerSizes = requireLayerSizes(layer.getEncoderLayerSizes(), "encoderLayerSizes");
        int[] decoderLayerSizes = requireLayerSizes(layer.getDecoderLayerSizes(), "decoderLayerSizes");

        List<DenseLayout> layouts = new ArrayList<DenseLayout>(encoderLayerSizes.length + decoderLayerSizes.length + 3);
        int fanIn = nIn;
        for (int i = 0; i < encoderLayerSizes.length; ++i) {
            layouts.add(new DenseLayout("e" + i + WEIGHT_KEY_SUFFIX, "e" + i + BIAS_KEY_SUFFIX, fanIn, encoderLayerSizes[i]));
            fanIn = encoderLayerSizes[i];
        }
        // Both pZX blocks (mean and log-std2) read the last encoder layer and emit nOut values.
        layouts.add(new DenseLayout(PZX_MEAN_W, PZX_MEAN_B, fanIn, nOut));
        layouts.add(new DenseLayout(PZX_LOGSTD2_W, PZX_LOGSTD2_B, fanIn, nOut));
        // The decoder starts from an nOut-sized input.
        fanIn = nOut;
        for (int i = 0; i < decoderLayerSizes.length; ++i) {
            layouts.add(new DenseLayout("d" + i + WEIGHT_KEY_SUFFIX, "d" + i + BIAS_KEY_SUFFIX, fanIn, decoderLayerSizes[i]));
            fanIn = decoderLayerSizes[i];
        }
        int nDistributionParams = layer.getOutputDistribution().distributionInputSize(nIn);
        layouts.add(new DenseLayout(PXZ_W, PXZ_B, fanIn, nDistributionParams));
        return layouts;
    }

    /**
     * Returns {@code sizes} unchanged if it holds at least one layer size.
     *
     * @throws IllegalArgumentException if {@code sizes} is null or empty (the pZX/pXZ blocks need a
     *     preceding layer to read from)
     */
    private static int[] requireLayerSizes(int[] sizes, String name) {
        if (sizes == null || sizes.length == 0) {
            throw new IllegalArgumentException("Invalid VariationalAutoencoder configuration: " + name + " must contain at least one layer size");
        }
        return sizes;
    }

    /** Returns the view of row 0 of {@code flat} covering columns {@code [offset, offset + length)}. */
    private static INDArray rowSlice(INDArray flat, int offset, int length) {
        return flat.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + length)});
    }

    /** One dense block of the flattened view: its parameter keys and its fan-in/fan-out. */
    private static final class DenseLayout {
        final String weightKey;
        final String biasKey;
        final int nIn;
        final int nOut;

        DenseLayout(String weightKey, String biasKey, int nIn, int nOut) {
            this.weightKey = weightKey;
            this.biasKey = biasKey;
            this.nIn = nIn;
            this.nOut = nOut;
        }

        /** Number of weight values, {@code nIn * nOut}. */
        int weightCount() {
            return nIn * nOut;
        }

        /** Total values in this block: the weights followed by one bias per output. */
        int paramCount() {
            return weightCount() + nOut;
        }
    }
}

