/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.transferlearning;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TransferLearning {
    private static final Logger log = LoggerFactory.getLogger(TransferLearning.class);

    public static class GraphBuilder {
        private ComputationGraph origGraph;
        private ComputationGraphConfiguration origConfig;
        private FineTuneConfiguration fineTuneConfiguration;
        private ComputationGraphConfiguration.GraphBuilder editedConfigBuilder;
        private String[] frozenOutputAt;
        private boolean hasFrozen = false;
        private Set<String> editedVertices = new HashSet<String>();
        private WorkspaceMode workspaceMode;

        public GraphBuilder(ComputationGraph origGraph) {
            this.origGraph = origGraph;
            this.origConfig = origGraph.getConfiguration().clone();
        }

        public GraphBuilder fineTuneConfiguration(FineTuneConfiguration fineTuneConfiguration) {
            this.fineTuneConfiguration = fineTuneConfiguration;
            this.editedConfigBuilder = new ComputationGraphConfiguration.GraphBuilder(this.origConfig, fineTuneConfiguration.appliedNeuralNetConfigurationBuilder());
            Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> vertices = this.editedConfigBuilder.getVertices();
            for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> gv : vertices.entrySet()) {
                if (!(gv.getValue() instanceof LayerVertex)) continue;
                LayerVertex lv = (LayerVertex)gv.getValue();
                NeuralNetConfiguration nnc = lv.getLayerConf().clone();
                fineTuneConfiguration.applyToNeuralNetConfiguration(nnc);
                vertices.put(gv.getKey(), new LayerVertex(nnc, lv.getPreProcessor()));
                nnc.getLayer().setLayerName(gv.getKey());
            }
            return this;
        }

        public GraphBuilder setFeatureExtractor(String ... layerName) {
            this.hasFrozen = true;
            this.frozenOutputAt = layerName;
            return this;
        }

        public GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme) {
            return this.nOutReplace(layerName, nOut, scheme, scheme, null, null);
        }

        public GraphBuilder nOutReplace(String layerName, int nOut, Distribution dist) {
            return this.nOutReplace(layerName, nOut, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, dist, dist);
        }

        public GraphBuilder nOutReplace(String layerName, int nOut, Distribution dist, Distribution distNext) {
            return this.nOutReplace(layerName, nOut, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, dist, distNext);
        }

        public GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme, Distribution dist) {
            return this.nOutReplace(layerName, nOut, scheme, WeightInit.DISTRIBUTION, null, dist);
        }

        public GraphBuilder nOutReplace(String layerName, int nOut, Distribution dist, WeightInit scheme) {
            return this.nOutReplace(layerName, nOut, WeightInit.DISTRIBUTION, scheme, dist, null);
        }

        public GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme, WeightInit schemeNext) {
            return this.nOutReplace(layerName, nOut, scheme, schemeNext, null, null);
        }

        private GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme, WeightInit schemeNext, Distribution dist, Distribution distNext) {
            this.initBuilderIfReq();
            if (this.origGraph.getVertex(layerName).hasLayer()) {
                NeuralNetConfiguration layerConf = this.origGraph.getLayer(layerName).conf();
                Layer layerImpl = layerConf.getLayer().clone();
                layerImpl.resetLayerDefaultConfig();
                FeedForwardLayer layerImplF = (FeedForwardLayer)layerImpl;
                layerImplF.setWeightInit(scheme);
                layerImplF.setDist(dist);
                layerImplF.setNOut(nOut);
                this.editedConfigBuilder.removeVertex(layerName, false);
                LayerVertex lv = (LayerVertex)this.origConfig.getVertices().get(layerName);
                String[] lvInputs = this.origConfig.getVertexInputs().get(layerName).toArray(new String[0]);
                this.editedConfigBuilder.addLayer(layerName, layerImpl, lv.getPreProcessor(), lvInputs);
                this.editedVertices.add(layerName);
                ArrayList<String> fanoutVertices = new ArrayList<String>();
                for (Map.Entry<String, List<String>> entry : this.origConfig.getVertexInputs().entrySet()) {
                    String currentVertex = entry.getKey();
                    if (currentVertex.equals(layerName) || !entry.getValue().contains(layerName)) continue;
                    fanoutVertices.add(currentVertex);
                }
                for (String fanoutVertexName : fanoutVertices) {
                    if (!this.origGraph.getVertex(fanoutVertexName).hasLayer()) {
                        throw new UnsupportedOperationException("Cannot modify nOut of a layer vertex that feeds non-layer vertices. Use removeVertexKeepConnections followed by addVertex instead");
                    }
                    layerConf = this.origGraph.getLayer(fanoutVertexName).conf();
                    layerImpl = layerConf.getLayer().clone();
                    layerImplF = (FeedForwardLayer)layerImpl;
                    layerImplF.setWeightInit(schemeNext);
                    layerImplF.setDist(distNext);
                    layerImplF.setNIn(nOut);
                    this.editedConfigBuilder.removeVertex(fanoutVertexName, false);
                    lv = (LayerVertex)this.origConfig.getVertices().get(fanoutVertexName);
                    lvInputs = this.origConfig.getVertexInputs().get(fanoutVertexName).toArray(new String[0]);
                    this.editedConfigBuilder.addLayer(fanoutVertexName, layerImpl, lv.getPreProcessor(), lvInputs);
                    this.editedVertices.add(fanoutVertexName);
                }
            } else {
                throw new IllegalArgumentException("noutReplace can only be applied to layer vertices. " + layerName + " is not a layer vertex");
            }
            return this;
        }

        public GraphBuilder removeVertexKeepConnections(String outputName) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.removeVertex(outputName, false);
            return this;
        }

        public GraphBuilder removeVertexAndConnections(String vertexName) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.removeVertex(vertexName, true);
            return this;
        }

        public GraphBuilder addLayer(String layerName, Layer layer, String ... layerInputs) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.addLayer(layerName, layer, (InputPreProcessor)null, layerInputs);
            this.editedVertices.add(layerName);
            return this;
        }

        public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor, String ... layerInputs) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.addLayer(layerName, layer, preProcessor, layerInputs);
            this.editedVertices.add(layerName);
            return this;
        }

        public GraphBuilder addVertex(String vertexName, org.deeplearning4j.nn.conf.graph.GraphVertex vertex, String ... vertexInputs) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.addVertex(vertexName, vertex, vertexInputs);
            this.editedVertices.add(vertexName);
            return this;
        }

        public GraphBuilder setOutputs(String ... outputNames) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.setOutputs(outputNames);
            return this;
        }

        private void initBuilderIfReq() {
            if (this.editedConfigBuilder == null) {
                this.fineTuneConfiguration(new FineTuneConfiguration.Builder().seed(this.origConfig.getDefaultConfiguration().getSeed()).build());
            }
        }

        public GraphBuilder setInputs(String ... inputs) {
            this.editedConfigBuilder.setNetworkInputs(Arrays.asList(inputs));
            return this;
        }

        public GraphBuilder setInputTypes(InputType ... inputTypes) {
            this.editedConfigBuilder.setInputTypes(inputTypes);
            return this;
        }

        public GraphBuilder addInputs(String ... inputNames) {
            this.editedConfigBuilder.addInputs(inputNames);
            return this;
        }

        public GraphBuilder setWorkspaceMode(WorkspaceMode workspaceMode) {
            this.workspaceMode = workspaceMode;
            return this;
        }

        public ComputationGraph build() {
            this.initBuilderIfReq();
            ComputationGraphConfiguration newConfig = this.editedConfigBuilder.build();
            if (this.workspaceMode != null) {
                newConfig.setTrainingWorkspaceMode(this.workspaceMode);
            }
            ComputationGraph newGraph = new ComputationGraph(newConfig);
            newGraph.init();
            int[] topologicalOrder = newGraph.topologicalSortOrder();
            GraphVertex[] vertices = newGraph.getVertices();
            if (!this.editedVertices.isEmpty()) {
                for (int i = 0; i < topologicalOrder.length; ++i) {
                    if (!vertices[topologicalOrder[i]].hasLayer()) continue;
                    org.deeplearning4j.nn.api.Layer layer = vertices[topologicalOrder[i]].getLayer();
                    String layerName = vertices[topologicalOrder[i]].getVertexName();
                    int range = layer.numParams();
                    if (range <= 0 || this.editedVertices.contains(layerName)) continue;
                    layer.setParams(this.origGraph.getLayer(layerName).params().dup());
                }
            } else {
                newGraph.setParams(this.origGraph.params());
            }
            if (this.hasFrozen) {
                HashSet<String> allFrozen = new HashSet<String>();
                Collections.addAll(allFrozen, this.frozenOutputAt);
                for (int i = topologicalOrder.length - 1; i >= 0; --i) {
                    VertexIndices[] inputs;
                    GraphVertex gv = vertices[topologicalOrder[i]];
                    if (!allFrozen.contains(gv.getVertexName())) continue;
                    if (gv.hasLayer()) {
                        org.deeplearning4j.nn.api.Layer l = gv.getLayer();
                        gv.setLayerAsFrozen();
                        String layerName = gv.getVertexName();
                        LayerVertex currLayerVertex = (LayerVertex)newConfig.getVertices().get(layerName);
                        Layer origLayerConf = currLayerVertex.getLayerConf().getLayer();
                        FrozenLayer newLayerConf = new FrozenLayer(origLayerConf);
                        ((Layer)newLayerConf).setLayerName(origLayerConf.getLayerName());
                        NeuralNetConfiguration newNNC = currLayerVertex.getLayerConf().clone();
                        currLayerVertex.setLayerConf(newNNC);
                        currLayerVertex.getLayerConf().setLayer(newLayerConf);
                        List<String> vars = currLayerVertex.getLayerConf().variables(true);
                        currLayerVertex.getLayerConf().clearVariables();
                        for (String s : vars) {
                            newNNC.variables(false).add(s);
                            newNNC.getL1ByParam().put(s, 0.0);
                            newNNC.getL2ByParam().put(s, 0.0);
                        }
                        org.deeplearning4j.nn.api.Layer[] layers = newGraph.getLayers();
                        for (int j = 0; j < layers.length; ++j) {
                            if (layers[j] != l) continue;
                            layers[j] = gv.getLayer();
                            break;
                        }
                    }
                    if ((inputs = gv.getInputVertices()) == null || inputs.length <= 0) continue;
                    for (int j = 0; j < inputs.length; ++j) {
                        int inputVertexIdx = inputs[j].getVertexIndex();
                        String alsoFreeze = vertices[inputVertexIdx].getVertexName();
                        allFrozen.add(alsoFreeze);
                    }
                }
                newGraph.initGradientsView();
            }
            return newGraph;
        }
    }

    public static class Builder {
        private MultiLayerConfiguration origConf;
        private MultiLayerNetwork origModel;
        private MultiLayerNetwork editedModel;
        private FineTuneConfiguration finetuneConfiguration;
        private int frozenTill = -1;
        private int popN = 0;
        private boolean prepDone = false;
        private Set<Integer> editedLayers = new HashSet<Integer>();
        private Map<Integer, Triple<Integer, Pair<WeightInit, Distribution>, Pair<WeightInit, Distribution>>> editedLayersMap = new HashMap<Integer, Triple<Integer, Pair<WeightInit, Distribution>, Pair<WeightInit, Distribution>>>();
        private List<INDArray> editedParams = new ArrayList<INDArray>();
        private List<NeuralNetConfiguration> editedConfs = new ArrayList<NeuralNetConfiguration>();
        private List<INDArray> appendParams = new ArrayList<INDArray>();
        private List<NeuralNetConfiguration> appendConfs = new ArrayList<NeuralNetConfiguration>();
        private Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<Integer, InputPreProcessor>();
        private InputType inputType;

        public Builder(MultiLayerNetwork origModel) {
            this.origModel = origModel;
            this.origConf = origModel.getLayerWiseConfigurations().clone();
            this.inputPreProcessors = this.origConf.getInputPreProcessors();
        }

        public Builder fineTuneConfiguration(FineTuneConfiguration finetuneConfiguration) {
            this.finetuneConfiguration = finetuneConfiguration;
            return this;
        }

        public Builder setFeatureExtractor(int layerNum) {
            this.frozenTill = layerNum;
            return this;
        }

        public Builder nOutReplace(int layerNum, int nOut, WeightInit scheme) {
            return this.nOutReplace(layerNum, nOut, scheme, scheme, null, null);
        }

        public Builder nOutReplace(int layerNum, int nOut, Distribution dist) {
            return this.nOutReplace(layerNum, nOut, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, dist, dist);
        }

        public Builder nOutReplace(int layerNum, int nOut, WeightInit scheme, WeightInit schemeNext) {
            return this.nOutReplace(layerNum, nOut, scheme, schemeNext, null, null);
        }

        public Builder nOutReplace(int layerNum, int nOut, Distribution dist, Distribution distNext) {
            return this.nOutReplace(layerNum, nOut, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, dist, distNext);
        }

        public Builder nOutReplace(int layerNum, int nOut, WeightInit scheme, Distribution distNext) {
            return this.nOutReplace(layerNum, nOut, scheme, WeightInit.DISTRIBUTION, null, distNext);
        }

        public Builder nOutReplace(int layerNum, int nOut, Distribution dist, WeightInit schemeNext) {
            return this.nOutReplace(layerNum, nOut, WeightInit.DISTRIBUTION, schemeNext, dist, null);
        }

        private Builder nOutReplace(int layerNum, int nOut, WeightInit scheme, WeightInit schemeNext, Distribution dist, Distribution distNext) {
            this.editedLayers.add(layerNum);
            Triple t = new Triple((Object)nOut, (Object)new Pair((Object)scheme, (Object)dist), (Object)new Pair((Object)schemeNext, (Object)distNext));
            this.editedLayersMap.put(layerNum, (Triple<Integer, Pair<WeightInit, Distribution>, Pair<WeightInit, Distribution>>)t);
            return this;
        }

        public Builder removeOutputLayer() {
            this.popN = 1;
            return this;
        }

        public Builder removeLayersFromOutput(int layerNum) {
            if (this.popN != 0) {
                throw new IllegalArgumentException("Remove layers from can only be called once");
            }
            this.popN = layerNum;
            return this;
        }

        public Builder addLayer(Layer layer) {
            if (!this.prepDone) {
                this.doPrep();
            }
            NeuralNetConfiguration layerConf = this.finetuneConfiguration.appliedNeuralNetConfigurationBuilder().layer(layer).build();
            int numParams = layer.initializer().numParams(layerConf);
            if (numParams > 0) {
                INDArray params = Nd4j.create((int)1, (int)numParams);
                org.deeplearning4j.nn.api.Layer someLayer = layer.instantiate(layerConf, null, 0, params, true);
                this.appendParams.add(someLayer.params());
                this.appendConfs.add(someLayer.conf());
            } else {
                this.appendConfs.add(layerConf);
            }
            return this;
        }

        public Builder setInputPreProcessor(int layer, InputPreProcessor processor) {
            this.inputPreProcessors.put(layer, processor);
            return this;
        }

        public MultiLayerNetwork build() {
            if (!this.prepDone) {
                this.doPrep();
            }
            this.editedModel = new MultiLayerNetwork(this.constructConf(), this.constructParams());
            if (this.frozenTill != -1) {
                org.deeplearning4j.nn.api.Layer[] layers = this.editedModel.getLayers();
                for (int i = this.frozenTill; i >= 0; --i) {
                    NeuralNetConfiguration origNNC = this.editedModel.getLayerWiseConfigurations().getConf(i);
                    NeuralNetConfiguration layerNNC = origNNC.clone();
                    this.editedModel.getLayerWiseConfigurations().getConf(i).resetVariables();
                    layers[i].setConf(layerNNC);
                    layers[i] = new org.deeplearning4j.nn.layers.FrozenLayer(layers[i]);
                    if (origNNC.getVariables() != null) {
                        List<String> vars = origNNC.variables(true);
                        origNNC.clearVariables();
                        layerNNC.clearVariables();
                        for (String s : vars) {
                            origNNC.variables(false).add(s);
                            origNNC.getL1ByParam().put(s, 0.0);
                            origNNC.getL2ByParam().put(s, 0.0);
                            layerNNC.variables(false).add(s);
                            layerNNC.getL1ByParam().put(s, 0.0);
                            layerNNC.getL2ByParam().put(s, 0.0);
                        }
                    }
                    Layer origLayerConf = this.editedModel.getLayerWiseConfigurations().getConf(i).getLayer();
                    FrozenLayer newLayerConf = new FrozenLayer(origLayerConf);
                    ((Layer)newLayerConf).setLayerName(origLayerConf.getLayerName());
                    this.editedModel.getLayerWiseConfigurations().getConf(i).setLayer(newLayerConf);
                }
                this.editedModel.setLayers(layers);
            }
            return this.editedModel;
        }

        private void doPrep() {
            int i;
            this.fineTuneConfigurationBuild();
            for (i = 0; i < this.origModel.getnLayers(); ++i) {
                if (this.origModel.getLayer(i).numParams() > 0) {
                    this.editedParams.add(this.origModel.getLayer(i).params().dup());
                    continue;
                }
                this.editedParams.add(this.origModel.getLayer(i).params());
            }
            if (!this.editedLayers.isEmpty()) {
                Object[] editedLayersSorted = this.editedLayers.toArray(new Integer[this.editedLayers.size()]);
                Arrays.sort(editedLayersSorted);
                for (int i2 = 0; i2 < editedLayersSorted.length; ++i2) {
                    int layerNum = (Integer)editedLayersSorted[i2];
                    this.nOutReplaceBuild(layerNum, (Integer)this.editedLayersMap.get(layerNum).getLeft(), (Pair<WeightInit, Distribution>)((Pair)this.editedLayersMap.get(layerNum).getMiddle()), (Pair<WeightInit, Distribution>)((Pair)this.editedLayersMap.get(layerNum).getRight()));
                }
            }
            for (i = 0; i < this.popN; ++i) {
                Integer layerNum = this.origModel.getnLayers() - i;
                if (this.inputPreProcessors.containsKey(layerNum)) {
                    this.inputPreProcessors.remove(layerNum);
                }
                this.editedConfs.remove(this.editedConfs.size() - 1);
                this.editedParams.remove(this.editedParams.size() - 1);
            }
            this.prepDone = true;
        }

        private void fineTuneConfigurationBuild() {
            for (int i = 0; i < this.origConf.getConfs().size(); ++i) {
                NeuralNetConfiguration layerConf;
                if (this.finetuneConfiguration != null) {
                    NeuralNetConfiguration nnc = this.origConf.getConf(i).clone();
                    this.finetuneConfiguration.applyToNeuralNetConfiguration(nnc);
                    layerConf = nnc;
                } else {
                    layerConf = this.origConf.getConf(i).clone();
                }
                this.editedConfs.add(layerConf);
            }
        }

        private void nOutReplaceBuild(int layerNum, int nOut, Pair<WeightInit, Distribution> schemedist, Pair<WeightInit, Distribution> schemedistNext) {
            NeuralNetConfiguration layerConf = this.editedConfs.get(layerNum);
            Layer layerImpl = layerConf.getLayer();
            FeedForwardLayer layerImplF = (FeedForwardLayer)layerImpl;
            layerImplF.setWeightInit((WeightInit)((Object)schemedist.getLeft()));
            layerImplF.setDist((Distribution)schemedist.getRight());
            layerImplF.setNOut(nOut);
            int numParams = layerImpl.initializer().numParams(layerConf);
            INDArray params = Nd4j.create((int)1, (int)numParams);
            org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf, null, 0, params, true);
            this.editedParams.set(layerNum, someLayer.params());
            if (layerNum + 1 < this.editedConfs.size()) {
                layerConf = this.editedConfs.get(layerNum + 1);
                layerImpl = layerConf.getLayer();
                layerImplF = (FeedForwardLayer)layerImpl;
                layerImplF.setWeightInit((WeightInit)((Object)schemedistNext.getLeft()));
                layerImplF.setDist((Distribution)schemedistNext.getRight());
                layerImplF.setNIn(nOut);
                numParams = layerImpl.initializer().numParams(layerConf);
                if (numParams > 0) {
                    params = Nd4j.create((int)1, (int)numParams);
                    someLayer = layerImpl.instantiate(layerConf, null, 0, params, true);
                    this.editedParams.set(layerNum + 1, someLayer.params());
                }
            }
        }

        private INDArray constructParams() {
            INDArray keepView = null;
            for (INDArray aParam : this.editedParams) {
                if (aParam == null) continue;
                if (keepView == null) {
                    keepView = aParam;
                    continue;
                }
                keepView = Nd4j.hstack((INDArray[])new INDArray[]{keepView, aParam});
            }
            if (!this.appendParams.isEmpty()) {
                INDArray appendView = Nd4j.hstack(this.appendParams);
                return Nd4j.hstack((INDArray[])new INDArray[]{keepView, appendView});
            }
            return keepView;
        }

        private MultiLayerConfiguration constructConf() {
            ArrayList<NeuralNetConfiguration> allConfs = new ArrayList<NeuralNetConfiguration>();
            allConfs.addAll(this.editedConfs);
            allConfs.addAll(this.appendConfs);
            for (int i = 0; i < allConfs.size(); ++i) {
                if (((NeuralNetConfiguration)allConfs.get(i)).getLayer().getLayerName() != null) continue;
                ((NeuralNetConfiguration)allConfs.get(i)).getLayer().setLayerName("layer" + i);
            }
            MultiLayerConfiguration conf = new MultiLayerConfiguration.Builder().inputPreProcessors(this.inputPreProcessors).setInputType(this.inputType).confs(allConfs).build();
            if (this.finetuneConfiguration != null) {
                this.finetuneConfiguration.applyToMultiLayerConfiguration(conf);
            }
            return conf;
        }
    }
}

