/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.util;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.MultiLayerUpdater;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.schedule.ISchedule;

public class NetworkUtils {
    private NetworkUtils() {
    }

    public static ComputationGraph toComputationGraph(MultiLayerNetwork net) {
        ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder().graphBuilder();
        MultiLayerConfiguration origConf = net.getLayerWiseConfigurations().clone();
        int layerIdx = 0;
        String lastLayer = "in";
        b.addInputs("in");
        for (NeuralNetConfiguration c : origConf.getConfs()) {
            String currLayer = String.valueOf(layerIdx);
            InputPreProcessor preproc = origConf.getInputPreProcess(layerIdx);
            b.addLayer(currLayer, c.getLayer(), preproc, lastLayer);
            lastLayer = currLayer;
            ++layerIdx;
        }
        b.setOutputs(lastLayer);
        ComputationGraphConfiguration conf = b.build();
        ComputationGraph cg = new ComputationGraph(conf);
        cg.init();
        cg.setParams(net.params());
        INDArray updaterState = net.getUpdater().getStateViewArray();
        if (updaterState != null) {
            cg.getUpdater().getUpdaterStateViewArray().assign(updaterState);
        }
        return cg;
    }

    public static void setLearningRate(MultiLayerNetwork net, double newLr) {
        NetworkUtils.setLearningRate(net, newLr, null);
    }

    public static void setLearningRate(MultiLayerNetwork net, ISchedule newLrSchedule) {
        NetworkUtils.setLearningRate(net, Double.NaN, newLrSchedule);
    }

    private static void setLearningRate(MultiLayerNetwork net, double newLr, ISchedule lrSchedule) {
        int nLayers = net.getnLayers();
        for (int i = 0; i < nLayers; ++i) {
            NetworkUtils.setLearningRate(net, i, newLr, lrSchedule, false);
        }
        NetworkUtils.refreshUpdater(net);
    }

    public static void setLearningRate(MultiLayerNetwork net, int layerNumber, double newLr) {
        NetworkUtils.setLearningRate(net, layerNumber, newLr, null, true);
    }

    public static void setLearningRate(MultiLayerNetwork net, int layerNumber, ISchedule lrSchedule) {
        NetworkUtils.setLearningRate(net, layerNumber, Double.NaN, lrSchedule, true);
    }

    private static void setLearningRate(MultiLayerNetwork net, int layerNumber, double newLr, ISchedule newLrSchedule, boolean refreshUpdater) {
        org.deeplearning4j.nn.conf.layers.Layer l = net.getLayer(layerNumber).conf().getLayer();
        if (l instanceof BaseLayer) {
            BaseLayer bl = (BaseLayer)l;
            IUpdater u = bl.getIUpdater();
            if (u != null && u.hasLearningRate()) {
                if (newLrSchedule != null) {
                    u.setLrAndSchedule(Double.NaN, newLrSchedule);
                } else {
                    u.setLrAndSchedule(newLr, null);
                }
            }
            if (refreshUpdater) {
                NetworkUtils.refreshUpdater(net);
            }
        }
    }

    private static void refreshUpdater(MultiLayerNetwork net) {
        INDArray origUpdaterState = net.getUpdater().getStateViewArray();
        net.setUpdater(null);
        MultiLayerUpdater u = (MultiLayerUpdater)net.getUpdater();
        u.setStateViewArray(origUpdaterState);
    }

    public static void setLearningRate(ComputationGraph net, double newLr) {
        NetworkUtils.setLearningRate(net, newLr, null);
    }

    public static void setLearningRate(ComputationGraph net, ISchedule newLrSchedule) {
        NetworkUtils.setLearningRate(net, Double.NaN, newLrSchedule);
    }

    private static void setLearningRate(ComputationGraph net, double newLr, ISchedule lrSchedule) {
        Layer[] layers = net.getLayers();
        for (int i = 0; i < layers.length; ++i) {
            NetworkUtils.setLearningRate(net, layers[i].conf().getLayer().getLayerName(), newLr, lrSchedule, false);
        }
        NetworkUtils.refreshUpdater(net);
    }

    public static void setLearningRate(ComputationGraph net, String layerName, double newLr) {
        NetworkUtils.setLearningRate(net, layerName, newLr, null, true);
    }

    public static void setLearningRate(ComputationGraph net, String layerName, ISchedule lrSchedule) {
        NetworkUtils.setLearningRate(net, layerName, Double.NaN, lrSchedule, true);
    }

    private static void setLearningRate(ComputationGraph net, String layerName, double newLr, ISchedule newLrSchedule, boolean refreshUpdater) {
        org.deeplearning4j.nn.conf.layers.Layer l = net.getLayer(layerName).conf().getLayer();
        if (l instanceof BaseLayer) {
            BaseLayer bl = (BaseLayer)l;
            IUpdater u = bl.getIUpdater();
            if (u != null && u.hasLearningRate()) {
                if (newLrSchedule != null) {
                    u.setLrAndSchedule(Double.NaN, newLrSchedule);
                } else {
                    u.setLrAndSchedule(newLr, null);
                }
            }
            if (refreshUpdater) {
                NetworkUtils.refreshUpdater(net);
            }
        }
    }

    private static void refreshUpdater(ComputationGraph net) {
        INDArray origUpdaterState = net.getUpdater().getStateViewArray();
        net.setUpdater(null);
        ComputationGraphUpdater u = net.getUpdater();
        u.setStateViewArray(origUpdaterState);
    }
}

