/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater.graph;

import java.util.Arrays;
import java.util.HashMap;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.updater.BaseMultiLayerUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Updater for a {@link ComputationGraph}.
 *
 * <p>The trainable units handled by this updater are the graph's vertices that have at least one
 * parameter, visited in the order returned by {@link ComputationGraph#topologicalSortOrder()}.
 */
public class ComputationGraphUpdater
extends BaseMultiLayerUpdater<ComputationGraph> {
    /**
     * Cached result of {@link #getOrderedLayers()}; {@code null} until first computed.
     * No synchronization is performed around this lazy initialization.
     */
    protected Trainable[] orderedLayers;

    /**
     * Creates an updater for {@code graph}, passing a {@code null} updater state to the
     * base class.
     *
     * @param graph the computation graph whose parameters this updater manages
     */
    public ComputationGraphUpdater(ComputationGraph graph) {
        this(graph, (INDArray) null);
    }

    /**
     * Creates an updater for {@code graph} with the given updater state, then indexes every
     * trainable vertex by its configured layer name in {@code layersByName}.
     *
     * @param graph        the computation graph whose parameters this updater manages
     * @param updaterState updater state forwarded unchanged to the base-class constructor;
     *                     may be {@code null}, as the single-argument constructor passes
     */
    public ComputationGraphUpdater(ComputationGraph graph, INDArray updaterState) {
        super(graph, updaterState);
        this.layersByName = new HashMap<>();
        for (Trainable layer : this.getOrderedLayers()) {
            this.layersByName.put(layer.getConfig().getLayerName(), layer);
        }
    }

    /**
     * Returns the graph's vertices that have parameters, in topological order.
     *
     * <p>Vertices whose {@code numParams()} is zero are skipped. The result is computed on the
     * first call and cached in {@link #orderedLayers}; later calls return the same array.
     *
     * @return the parameterised vertices in topological sort order
     */
    @Override
    protected Trainable[] getOrderedLayers() {
        if (this.orderedLayers != null) {
            return this.orderedLayers;
        }
        ComputationGraph graph = (ComputationGraph) this.network;
        GraphVertex[] vertices = graph.getVertices();
        int[] topologicalOrdering = graph.topologicalSortOrder();
        // Sized for the case where every vertex has parameters; trimmed to the actual count below.
        Trainable[] out = new Trainable[vertices.length];
        int count = 0;
        for (int vertexIndex : topologicalOrdering) {
            GraphVertex vertex = vertices[vertexIndex];
            // Vertices without parameters have nothing for the updater to act on.
            if (vertex.numParams() == 0L) {
                continue;
            }
            out[count++] = vertex;
        }
        if (count != out.length) {
            out = Arrays.copyOf(out, count);
        }
        this.orderedLayers = out;
        return this.orderedLayers;
    }

    /**
     * Returns the graph's flattened gradients view.
     *
     * <p>If the graph has no gradients view yet, it is initialised via
     * {@code initGradientsView()} first.
     *
     * @return the graph's flattened gradients array
     */
    @Override
    public INDArray getFlattenedGradientsView() {
        ComputationGraph graph = (ComputationGraph) this.network;
        if (graph.getFlattenedGradients() == null) {
            graph.initGradientsView();
        }
        return graph.getFlattenedGradients();
    }

    /**
     * Returns the graph's parameters as reported by {@link ComputationGraph#params()}.
     *
     * @return the graph's parameter array
     */
    @Override
    protected INDArray getParams() {
        return ((ComputationGraph) this.network).params();
    }

    /**
     * Reports whether the graph's configuration is flagged as mini-batch.
     *
     * @return the value of {@code conf().isMiniBatch()} for the managed graph
     */
    @Override
    protected boolean isMiniBatch() {
        return ((ComputationGraph) this.network).conf().isMiniBatch();
    }
}

