/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater;

import java.util.Arrays;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * {@link Updater} for a whole {@link MultiLayerNetwork}. It holds one updater per layer and
 * dispatches each slice of a network-level gradient to the updater of the layer it belongs to.
 *
 * <p>Network-level gradient keys are expected to have the form {@code "<layerIndex>_<paramKey>"},
 * for example {@code "0_W"}. The prefix before the first underscore selects the layer. The
 * remainder is the key handed to that layer's updater.
 */
public class MultiLayerUpdater
implements Updater {
    /** One updater per network layer, indexed by layer position. */
    private final Updater[] layerUpdaters;

    /**
     * Creates a per-layer updater, via {@link UpdaterCreator#getUpdater}, for every layer of
     * {@code network}.
     *
     * @param network the network whose layers will be updated
     */
    public MultiLayerUpdater(MultiLayerNetwork network) {
        Layer[] layers = network.getLayers();
        this.layerUpdaters = new Updater[layers.length];
        for (int i = 0; i < layers.length; ++i) {
            this.layerUpdaters[i] = UpdaterCreator.getUpdater(layers[i]);
        }
    }

    /**
     * Creates an instance with {@code size} empty updater slots. The aggregator uses it to
     * rebuild an updater and fills the slots itself.
     */
    private MultiLayerUpdater(int size) {
        this.layerUpdaters = new Updater[size];
    }

    /**
     * Splits the network gradient by layer and lets each layer's updater modify its slice.
     * The resulting values are written back into {@code gradient} under the original
     * {@code "<layerIndex>_<paramKey>"} keys.
     *
     * @param layer     must be a {@link MultiLayerNetwork} with one layer per updater slot
     * @param gradient  network-level gradient, updated in place
     * @param iteration current iteration, forwarded to each layer updater
     * @param batchSize mini-batch size, forwarded to each layer updater
     * @throws IllegalStateException if a gradient key has no layer separator, a non-integer
     *                               layer prefix, or a layer index outside the updater range
     */
    @Override
    public void update(Layer layer, Gradient gradient, int iteration, int batchSize) {
        MultiLayerNetwork mln = (MultiLayerNetwork)layer;
        Gradient[] layerGradients = new Gradient[this.layerUpdaters.length];
        for (int i = 0; i < layerGradients.length; ++i) {
            layerGradients[i] = new DefaultGradient();
        }
        // Route every "<layerIndex>_<paramKey>" entry to the per-layer gradient, stripping the prefix.
        for (Map.Entry<String, INDArray> gradientPair : gradient.gradientForVariable().entrySet()) {
            String key = gradientPair.getKey();
            int idx = key.indexOf("_");
            if (idx == -1) {
                throw new IllegalStateException("Invalid key: MultiLayerNetwork Gradient key does not have layer separator: \"" + key + "\"");
            }
            int layerIdx = parseLayerIndex(key, idx, layerGradients.length);
            String newKey = key.substring(idx + 1);
            layerGradients[layerIdx].gradientForVariable().put(newKey, gradientPair.getValue());
        }
        for (int i = 0; i < this.layerUpdaters.length; ++i) {
            this.layerUpdaters[i].update(mln.getLayer(i), layerGradients[i], iteration, batchSize);
            // Copy the layer updater's results back into the network gradient under the prefixed key.
            for (Map.Entry<String, INDArray> entry : layerGradients[i].gradientForVariable().entrySet()) {
                gradient.setGradientFor(i + "_" + entry.getKey(), entry.getValue());
            }
        }
    }

    /**
     * Parses the layer-index prefix {@code key[0, separatorIdx)} and checks that it lies in
     * {@code [0, numLayers)}.
     */
    private static int parseLayerIndex(String key, int separatorIdx, int numLayers) {
        int layerIdx;
        try {
            layerIdx = Integer.parseInt(key.substring(0, separatorIdx));
        } catch (NumberFormatException e) {
            throw new IllegalStateException("Invalid key: MultiLayerNetwork Gradient key does not start with an integer layer index: \"" + key + "\"", e);
        }
        if (layerIdx < 0 || layerIdx >= numLayers) {
            throw new IllegalStateException("Invalid key: layer index " + layerIdx + " in MultiLayerNetwork Gradient key \"" + key + "\" is out of range for " + numLayers + " layer updater(s)");
        }
        return layerIdx;
    }

    /**
     * Returns a new aggregator for combining several {@code MultiLayerUpdater}s.
     *
     * @param addThis if {@code true}, this updater is aggregated into the result immediately
     * @return a fresh {@link UpdaterAggregator}
     */
    @Override
    public UpdaterAggregator getAggregator(boolean addThis) {
        MultiLayerUpdaterAggregator ag = new MultiLayerUpdaterAggregator();
        if (addThis) {
            ag.aggregate(this);
        }
        return ag;
    }

    /** Two instances are equal when their per-layer updater arrays are deeply equal. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MultiLayerUpdater)) {
            return false;
        }
        MultiLayerUpdater other = (MultiLayerUpdater)o;
        if (!other.canEqual(this)) {
            return false;
        }
        return Arrays.deepEquals(this.layerUpdaters, other.layerUpdaters);
    }

    /** Symmetry hook for subclasses, in the style of Lombok's {@code canEqual}. */
    protected boolean canEqual(Object other) {
        return other instanceof MultiLayerUpdater;
    }

    /** Hash consistent with {@link #equals(Object)}: {@code 59 + deepHashCode(layerUpdaters)}. */
    @Override
    public int hashCode() {
        return 59 + Arrays.deepHashCode(this.layerUpdaters);
    }

    /**
     * Aggregates several {@link MultiLayerUpdater}s by keeping one per-layer
     * {@link UpdaterAggregator} and delegating to it layer by layer.
     */
    protected static class MultiLayerUpdaterAggregator
    implements UpdaterAggregator {
        /** Per-layer aggregators; {@code null} until the first updater is aggregated or merged. */
        private UpdaterAggregator[] aggregators;

        protected MultiLayerUpdaterAggregator() {
        }

        /**
         * Adds {@code updater}, which must be a {@link MultiLayerUpdater}, to this aggregate.
         * The first call creates one aggregator per layer from the updater. Later calls aggregate
         * layer by layer and require the same layer count.
         *
         * @throws IllegalStateException if the layer count differs from earlier updaters
         */
        @Override
        public void aggregate(Updater updater) {
            MultiLayerUpdater mlu = (MultiLayerUpdater)updater;
            if (this.aggregators == null) {
                this.aggregators = new UpdaterAggregator[mlu.layerUpdaters.length];
                for (int i = 0; i < this.aggregators.length; ++i) {
                    this.aggregators[i] = mlu.layerUpdaters[i].getAggregator(true);
                }
            } else {
                checkLayerCount(mlu.layerUpdaters.length);
                for (int i = 0; i < this.aggregators.length; ++i) {
                    this.aggregators[i].aggregate(mlu.layerUpdaters[i]);
                }
            }
        }

        /**
         * Merges another {@code MultiLayerUpdaterAggregator} into this one, layer by layer.
         * If this aggregator is still empty, it adopts the other's aggregators. An empty
         * {@code aggregator} is ignored.
         *
         * @throws IllegalStateException if both are non-empty and their layer counts differ
         */
        @Override
        public void merge(UpdaterAggregator aggregator) {
            MultiLayerUpdaterAggregator mlua = (MultiLayerUpdaterAggregator)aggregator;
            if (this.aggregators == null) {
                // NOTE(review): this shares the other aggregator's array and element objects
                // rather than copying them, so later merges into this instance also mutate
                // mlua's per-layer aggregators. Confirm callers discard mlua after merging.
                this.aggregators = mlua.aggregators;
            } else {
                if (mlua.aggregators == null) {
                    return;
                }
                checkLayerCount(mlua.aggregators.length);
                for (int i = 0; i < this.aggregators.length; ++i) {
                    this.aggregators[i].merge(mlua.aggregators[i]);
                }
            }
        }

        /**
         * Builds a {@link MultiLayerUpdater} whose per-layer updaters come from the per-layer
         * aggregators.
         *
         * @throws IllegalStateException if nothing has been aggregated or merged yet
         */
        @Override
        public Updater getUpdater() {
            if (this.aggregators == null) {
                throw new IllegalStateException("Cannot create MultiLayerUpdater: no updaters have been aggregated");
            }
            MultiLayerUpdater multiLayerUpdater = new MultiLayerUpdater(this.aggregators.length);
            for (int i = 0; i < this.aggregators.length; ++i) {
                multiLayerUpdater.layerUpdaters[i] = this.aggregators[i].getUpdater();
            }
            return multiLayerUpdater;
        }

        /** Throws if {@code count} differs from the number of per-layer aggregators held. */
        private void checkLayerCount(int count) {
            if (count != this.aggregators.length) {
                throw new IllegalStateException("Layer count mismatch: aggregator has " + this.aggregators.length + " layer(s) but got " + count);
            }
        }
    }
}

