/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater.graph;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Updater for a {@link ComputationGraph}: holds one {@link Updater} per layer (looked up by layer name)
 * and, when any layer has updater state, stores all of that state as consecutive views into a single
 * {@code [1, totalStateSize]} array.
 */
public class ComputationGraphUpdater
implements Serializable,
Cloneable {
    /** One updater per layer, in the order returned by {@code graph.getLayers()}. */
    private final Updater[] layerUpdaters;
    /** Layer name -> index into {@link #layerUpdaters}. */
    private final Map<String, Integer> layerUpdatersMap;
    /** Flattened updater state for all layers; null when no layer has updater state. */
    private INDArray viewArray;

    /**
     * Creates updaters for every layer of {@code graph} and allocates a fresh (uninitialized) state array,
     * letting each layer updater initialize its own view of it.
     *
     * @param graph the graph whose layers need updaters
     */
    public ComputationGraphUpdater(ComputationGraph graph) {
        Layer[] layers = graph.getLayers();
        this.layerUpdaters = new Updater[graph.getNumLayers()];
        this.layerUpdatersMap = new HashMap<String, Integer>();
        int updaterStateSize = 0;
        for (int i = 0; i < layers.length; ++i) {
            this.layerUpdaters[i] = UpdaterCreator.getUpdater(layers[i]);
            this.layerUpdatersMap.put(layers[i].conf().getLayer().getLayerName(), i);
            updaterStateSize += this.layerUpdaters[i].stateSizeForLayer(layers[i]);
        }
        if (updaterStateSize > 0) {
            this.viewArray = Nd4j.createUninitialized(new int[]{1, updaterStateSize}, Nd4j.order().charValue());
            // true: the array is uninitialized, so each layer updater must initialize its view
            assignStateViews(layers, true);
        }
    }

    /**
     * Creates updaters for every layer of {@code graph}, backed by an existing state array.
     *
     * @param graph        the graph whose layers need updaters
     * @param updaterState existing flattened updater state; may be null only if the total state size is 0
     * @throws IllegalStateException if {@code updaterState} has the wrong length, or is null when state is required
     */
    public ComputationGraphUpdater(ComputationGraph graph, INDArray updaterState) {
        this.layerUpdatersMap = new HashMap<String, Integer>();
        Layer[] layers = graph.getLayers();
        this.layerUpdaters = new Updater[layers.length];
        int updaterStateSize = 0;
        for (int i = 0; i < layers.length; ++i) {
            this.layerUpdaters[i] = UpdaterCreator.getUpdater(layers[i]);
            updaterStateSize += this.layerUpdaters[i].stateSizeForLayer(layers[i]);
            this.layerUpdatersMap.put(layers[i].conf().getLayer().getLayerName(), i);
        }
        if (updaterState != null) {
            if (updaterState.length() != updaterStateSize) {
                throw new IllegalStateException("Expected updater state with size " + updaterStateSize + ", got size " + updaterState.length());
            }
            this.viewArray = updaterState;
            // false: existing state is reused as-is, not re-initialized
            assignStateViews(layers, false);
        } else if (updaterStateSize != 0) {
            throw new IllegalStateException("Expected updater state with size " + updaterStateSize + ", got null input");
        }
    }

    private ComputationGraphUpdater(int size, Map<String, Integer> layerUpdatersMap) {
        this.layerUpdaters = new Updater[size];
        this.layerUpdatersMap = layerUpdatersMap;
    }

    /**
     * Copy constructor used by {@link #clone()}: clones each layer updater and copies the name map.
     * NOTE(review): {@code viewArray} is not copied, so the clone reports a null state view array.
     */
    private ComputationGraphUpdater(ComputationGraphUpdater updater) {
        this.layerUpdaters = new Updater[updater.layerUpdaters.length];
        for (int i = 0; i < this.layerUpdaters.length; ++i) {
            this.layerUpdaters[i] = updater.layerUpdaters[i].clone();
        }
        this.layerUpdatersMap = new HashMap<String, Integer>(updater.layerUpdatersMap);
    }

    /**
     * Hands each layer updater with non-zero state its consecutive slice of {@link #viewArray}.
     * The layer index advances for every layer, including zero-state ones, so updaters stay aligned with layers.
     */
    private void assignStateViews(Layer[] layers, boolean initialize) {
        int soFar = 0;
        for (int i = 0; i < layers.length; ++i) {
            int thisSize = this.layerUpdaters[i].stateSizeForLayer(layers[i]);
            if (thisSize == 0) {
                continue;
            }
            INDArray view = this.viewArray.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + thisSize)});
            this.layerUpdaters[i].setStateViewArray(layers[i], view, initialize);
            soFar += thisSize;
        }
    }

    /** Returns a copy with cloned per-layer updaters (see the copy constructor regarding the state view array). */
    @Override
    public ComputationGraphUpdater clone() {
        return new ComputationGraphUpdater(this);
    }

    /**
     * Applies each layer's updater to that layer's share of {@code gradient}, then writes the updated
     * per-layer arrays back into {@code gradient}.
     *
     * <p>Gradient keys are expected to look like {@code <layerName>_<paramName>}; the split is on the
     * last underscore, so layer names may themselves contain underscores.
     *
     * @throws IllegalStateException if a key has no underscore or names a layer with no updater
     */
    public void update(ComputationGraph graph, Gradient gradient, int iteration, int batchSize) {
        HashMap<String, Gradient> layerGradients = new HashMap<String, Gradient>();
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            int idx = key.lastIndexOf('_');
            if (idx == -1) {
                throw new IllegalStateException("Invalid key: ComputationGraph Gradient key does not have layer separator: \"" + key + "\"");
            }
            String layerName = key.substring(0, idx);
            Gradient g = layerGradients.get(layerName);
            if (g == null) {
                g = new DefaultGradient();
                layerGradients.put(layerName, g);
            }
            g.setGradientFor(key.substring(idx + 1), entry.getValue());
        }
        for (Map.Entry<String, Gradient> entry : layerGradients.entrySet()) {
            String layerName = entry.getKey();
            Gradient layerGradient = entry.getValue();
            Integer updaterIdx = this.layerUpdatersMap.get(layerName);
            if (updaterIdx == null) {
                throw new IllegalStateException("No updater found for layer \"" + layerName + "\"");
            }
            this.layerUpdaters[updaterIdx].update(graph.getLayer(layerName), layerGradient, iteration, batchSize);
            // Write back under the original "<layer>_<param>" keys
            for (Map.Entry<String, INDArray> paramEntry : layerGradient.gradientForVariable().entrySet()) {
                gradient.setGradientFor(layerName + "_" + paramEntry.getKey(), paramEntry.getValue());
            }
        }
    }

    /**
     * Copies {@code viewArray} into this updater's state array in place, so existing per-layer views see the new values.
     *
     * @throws IllegalStateException if this updater has no state array, or the lengths differ
     */
    public void setStateViewArray(INDArray viewArray) {
        if (this.viewArray == null) {
            throw new IllegalStateException("Invalid state: this updater has no state view array to assign into");
        }
        if (this.viewArray.length() != viewArray.length()) {
            throw new IllegalStateException("Invalid input: view arrays differ in length. Expected length " + this.viewArray.length() + ", got length " + viewArray.length());
        }
        this.viewArray.assign(viewArray);
    }

    /** Returns the flattened updater state array, or null if there is none. */
    public INDArray getStateViewArray() {
        return this.viewArray;
    }

    /** Equality considers only the layer-name-to-index map, not updater configuration or state. */
    @Override
    public boolean equals(Object other) {
        if (!(other instanceof ComputationGraphUpdater)) {
            return false;
        }
        return this.layerUpdatersMap.equals(((ComputationGraphUpdater)other).layerUpdatersMap);
    }

    /** Consistent with {@link #equals(Object)}: based only on the layer-name map. */
    @Override
    public int hashCode() {
        return this.layerUpdatersMap.hashCode();
    }
}

