/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater;

import java.util.Arrays;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * {@link Updater} for a whole {@link MultiLayerNetwork} that delegates to one updater per layer.
 *
 * <p>The updater state of all layers is stored in a single row vector (see
 * {@link #getStateViewArray()}). Each layer updater with non-zero state size is handed a
 * contiguous view of that vector; views are laid out in layer order.
 */
public class MultiLayerUpdater implements Updater {
    /** One updater per network layer, index-aligned with {@code network.getLayers()}. */
    private final Updater[] layerUpdaters;
    /**
     * Flattened updater state of all layers. {@code null} when no layer has updater state and no
     * state array was supplied.
     */
    private INDArray viewArray;

    /**
     * Creates per-layer updaters for {@code network} and allocates a fresh state array for them.
     * Each layer updater receives its view with {@code initialize == true}.
     *
     * @param network network whose layers receive updaters
     */
    public MultiLayerUpdater(MultiLayerNetwork network) {
        Layer[] layers = network.getLayers();
        this.layerUpdaters = new Updater[layers.length];
        int updaterStateSize = this.createLayerUpdaters(layers);
        if (updaterStateSize > 0) {
            // Contents start uninitialized; the layer updaters are passed initialize=true below.
            this.viewArray = Nd4j.createUninitialized(new int[]{1, updaterStateSize}, Nd4j.order().charValue());
            this.assignStateViews(layers, true);
        }
    }

    /**
     * Creates per-layer updaters for {@code network} that operate on an existing state array
     * (e.g. one restored from a saved model). Views are handed out with
     * {@code initialize == false}, so the supplied contents are used as-is.
     *
     * @param network      network whose layers receive updaters
     * @param updaterState existing state; may be {@code null} only if no layer has updater state
     * @throws IllegalStateException if {@code updaterState} is {@code null} while state is
     *                               required, or its length differs from the required size
     */
    public MultiLayerUpdater(MultiLayerNetwork network, INDArray updaterState) {
        Layer[] layers = network.getLayers();
        this.layerUpdaters = new Updater[layers.length];
        int updaterStateSize = this.createLayerUpdaters(layers);
        if (updaterState != null) {
            if (updaterState.length() != updaterStateSize) {
                throw new IllegalStateException("Expected updater state with size " + updaterStateSize + ", got size " + updaterState.length());
            }
            this.viewArray = updaterState;
            this.assignStateViews(layers, false);
        } else if (updaterStateSize != 0) {
            throw new IllegalStateException("Expected updater state with size " + updaterStateSize + ", got null input");
        }
    }

    /**
     * Fills {@link #layerUpdaters} with one updater per layer and sums their state sizes.
     *
     * @param layers the network's layers, in order
     * @return total number of state elements required by all layer updaters
     */
    private int createLayerUpdaters(Layer[] layers) {
        int updaterStateSize = 0;
        for (int i = 0; i < layers.length; ++i) {
            this.layerUpdaters[i] = UpdaterCreator.getUpdater(layers[i]);
            updaterStateSize += this.layerUpdaters[i].stateSizeForLayer(layers[i]);
        }
        return updaterStateSize;
    }

    /**
     * Slices {@link #viewArray} into consecutive per-layer views and hands each one to its layer
     * updater. Layers with zero state size are skipped.
     *
     * @param layers     the network's layers, in order
     * @param initialize value forwarded to each layer updater's {@code setStateViewArray}
     */
    private void assignStateViews(Layer[] layers, boolean initialize) {
        int soFar = 0;
        for (int i = 0; i < layers.length; ++i) {
            int thisSize = this.layerUpdaters[i].stateSizeForLayer(layers[i]);
            if (thisSize == 0) {
                continue;
            }
            INDArray view = this.viewArray.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + thisSize)});
            this.layerUpdaters[i].setStateViewArray(layers[i], view, initialize);
            soFar += thisSize;
        }
    }

    /**
     * Copies {@code viewArray} into this updater's state array. The layer views keep pointing at
     * the same storage, so they see the new values. {@code layer} and {@code initialize} are
     * ignored.
     *
     * @throws IllegalStateException if the lengths of the two arrays differ (a {@code null} array
     *                               counts as length 0)
     */
    @Override
    public void setStateViewArray(Layer layer, INDArray viewArray, boolean initialize) {
        int expected = this.viewArray == null ? 0 : this.viewArray.length();
        int actual = viewArray == null ? 0 : viewArray.length();
        if (expected != actual) {
            throw new IllegalStateException("Invalid input: view arrays differ in length. Expected length " + expected + ", got length " + actual);
        }
        if (expected > 0) {
            this.viewArray.assign(viewArray);
        }
    }

    /**
     * @return the flattened state array of all layer updaters; may be {@code null} if no layer
     *         has updater state
     */
    @Override
    public INDArray getStateViewArray() {
        return this.viewArray;
    }

    /**
     * @param layer must be a {@link MultiLayerNetwork}
     * @return total updater state size for the network (0 if there is no state array)
     * @throws IllegalArgumentException if {@code layer} is not a {@link MultiLayerNetwork}
     */
    @Override
    public int stateSizeForLayer(Layer layer) {
        if (!(layer instanceof MultiLayerNetwork)) {
            throw new IllegalArgumentException("Expected MultiLayerNetwork");
        }
        return this.viewArray == null ? 0 : this.viewArray.length();
    }

    /**
     * Splits {@code gradient} by layer and runs each layer updater on its share.
     *
     * <p>Gradient keys must have the form {@code "<layerIndex>_<name>"}. The prefix selects the
     * layer, and the remainder after the first {@code '_'} becomes the key in that layer's
     * gradient.
     *
     * @throws IllegalArgumentException if {@code layer} is not a {@link MultiLayerNetwork}
     * @throws IllegalStateException    if a gradient key lacks a separator or a valid layer index
     */
    @Override
    public void update(Layer layer, Gradient gradient, int iteration, int batchSize) {
        if (!(layer instanceof MultiLayerNetwork)) {
            throw new IllegalArgumentException("Expected MultiLayerNetwork");
        }
        MultiLayerNetwork mln = (MultiLayerNetwork) layer;
        Gradient[] layerGradients = new Gradient[this.layerUpdaters.length];
        for (int i = 0; i < layerGradients.length; ++i) {
            layerGradients[i] = new DefaultGradient();
        }
        for (Map.Entry<String, INDArray> gradientPair : gradient.gradientForVariable().entrySet()) {
            String key = gradientPair.getKey();
            int idx = key.indexOf('_');
            if (idx == -1) {
                throw new IllegalStateException("Invalid key: MultiLayerNetwork Gradient key does not have layer separator: \"" + key + "\"");
            }
            int layerIdx = this.parseLayerIndex(key, idx);
            layerGradients[layerIdx].gradientForVariable().put(key.substring(idx + 1), gradientPair.getValue());
        }
        for (int i = 0; i < this.layerUpdaters.length; ++i) {
            this.layerUpdaters[i].update(mln.getLayer(i), layerGradients[i], iteration, batchSize);
        }
    }

    /**
     * Parses the layer index prefix of a gradient key and checks it against the number of
     * layers.
     *
     * @param key          full gradient key
     * @param separatorIdx position of the first {@code '_'} in {@code key}
     * @return the layer index, in {@code [0, layerUpdaters.length)}
     * @throws IllegalStateException if the prefix is not an integer or is out of range
     */
    private int parseLayerIndex(String key, int separatorIdx) {
        int layerIdx;
        try {
            layerIdx = Integer.parseInt(key.substring(0, separatorIdx));
        } catch (NumberFormatException e) {
            throw new IllegalStateException("Invalid key: MultiLayerNetwork Gradient key does not start with a layer index: \"" + key + "\"", e);
        }
        if (layerIdx < 0 || layerIdx >= this.layerUpdaters.length) {
            throw new IllegalStateException("Invalid key: layer index " + layerIdx + " out of range [0, " + this.layerUpdaters.length + ") for Gradient key \"" + key + "\"");
        }
        return layerIdx;
    }

    /** Not supported yet. */
    @Override
    public Updater clone() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Two instances are equal when their per-layer updaters are pairwise equal. The state array is
     * not compared.
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof MultiLayerUpdater)) {
            return false;
        }
        return Arrays.equals(this.layerUpdaters, ((MultiLayerUpdater) other).layerUpdaters);
    }

    /** Consistent with {@link #equals(Object)}: derived only from the per-layer updaters. */
    @Override
    public int hashCode() {
        return Arrays.hashCode(this.layerUpdaters);
    }
}

