/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn;

import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.models.featuredetectors.autoencoder.AutoEncoder;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.Persistable;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.Layer;
import org.deeplearning4j.nn.layers.OutputLayer;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.deeplearning4j.optimize.optimizers.BackPropOptimizer;
import org.deeplearning4j.optimize.optimizers.BackPropROptimizer;
import org.deeplearning4j.optimize.optimizers.MultiLayerNetworkOptimizer;
import org.deeplearning4j.util.Dl4jReflection;
import org.deeplearning4j.util.SerializationUtils;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.sampling.Sampling;
import org.nd4j.linalg.transformation.MatrixTransform;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class BaseMultiLayerNetwork
implements Serializable,
Persistable,
Classifier {
    // Class-wide SLF4J logger.
    private static Logger log = LoggerFactory.getLogger(BaseMultiLayerNetwork.class);
    private static final long serialVersionUID = -5029161847383716484L;
    // Size of each hidden layer; the constructor requires its length to equal the layer count.
    protected int[] hiddenLayerSizes;
    // Allocated with one slot more than the number of hidden layers; init() stores the output layer in the last slot.
    protected Layer[] layers;
    // Most recent input matrix handed to the network.
    protected INDArray input;
    // Labels paired with the input.
    protected INDArray labels;
    protected MultiLayerNetworkOptimizer optimizer;
    // Integer-keyed transforms; presumably keyed by layer index and consumed by applyTransforms() - confirm.
    protected Map<Integer, MatrixTransform> weightTransforms = new HashMap<Integer, MatrixTransform>();
    protected Map<Integer, MatrixTransform> hiddenBiasTransforms = new HashMap<Integer, MatrixTransform>();
    protected Map<Integer, MatrixTransform> visibleBiasTransforms = new HashMap<Integer, MatrixTransform>();
    protected boolean forceNumEpochs = false;
    // Set to true at the end of init(); initializeLayers() skips init() once it is set.
    protected boolean initCalled = false;
    protected boolean sampleFromHiddenActivations = true;
    protected boolean pretrain = true;
    // Fallback configuration; cloned once per layer when no layer-wise configurations were supplied.
    protected NeuralNetConfiguration defaultConfiguration;
    // One configuration per hidden layer plus one for the output layer.
    protected List<NeuralNetConfiguration> layerWiseConfigurations;
    // Note: initialized from a float literal, so the stored double is the widened float value.
    protected double learningRateUpdate = 0.95f;
    // One network per hidden layer (no entry for the output layer).
    protected NeuralNetwork[] neuralNets;
    // Note: initialized from a float literal, so the stored double is the widened float value.
    protected double errorTolerance = 1.0E-4f;
    // Lazily created through initMask() when still null in backPropGradient().
    protected INDArray mask;
    // When true, applyDropConnectIfNecessary() multiplies its argument by a random binomial mask.
    protected boolean useDropConnect = false;
    // Scaled by dampingUpdate() and temporarily zeroed inside reductionRatio().
    protected double dampingFactor = 10.0;
    // Selects BackPropROptimizer over BackPropOptimizer in backProp(TrainingEvaluator).
    protected boolean useGaussNewtonVectorProductBackProp = false;

    /**
     * Creates an unconfigured network; every field keeps its declared default.
     */
    protected BaseMultiLayerNetwork() {
    }

    /**
     * Creates a network without training data by delegating to the four-argument
     * constructor with {@code null} input and {@code null} labels.
     *
     * @param hiddenLayerSizes size of each hidden layer
     * @param nLayers number of hidden layers
     */
    protected BaseMultiLayerNetwork(int[] hiddenLayerSizes, int nLayers) {
        this(hiddenLayerSizes, nLayers, null, null);
    }

    /**
     * Creates a network with the given hidden layer sizes and, optionally, training data.
     *
     * <p>{@code input} and {@code labels} may each be {@code null}; the two-argument
     * constructor passes {@code null} for both. Layers are initialized here only when
     * {@code input} is supplied.
     *
     * @param hiddenLayerSizes size of each hidden layer
     * @param nLayers number of hidden layers; must equal {@code hiddenLayerSizes.length}
     * @param input training input, copied before being stored; may be {@code null}
     * @param labels training labels, copied before being stored; may be {@code null}
     * @throws IllegalArgumentException if {@code hiddenLayerSizes.length != nLayers}
     */
    protected BaseMultiLayerNetwork(int[] hiddenLayerSizes, int nLayers, INDArray input, INDArray labels) {
        this.hiddenLayerSizes = hiddenLayerSizes;
        // Copy only when present: an unconditional dup() threw NullPointerException for
        // every caller that supplied no data, although null is explicitly tolerated below.
        this.input = input == null ? null : input.dup();
        this.labels = labels == null ? null : labels.dup();
        if (hiddenLayerSizes.length != nLayers) {
            throw new IllegalArgumentException("The number of hidden layer sizes must be equivalent to the nLayers argument which is a value of " + nLayers);
        }
        this.setnLayers(nLayers);
        // One extra slot is reserved for the output layer.
        this.layers = new org.deeplearning4j.nn.layers.Layer[nLayers + 1];
        this.intializeConfigurations();
        if (input != null) {
            this.initializeLayers(input);
        }
    }

    /**
     * Fills in any structural field that is still unset: the layer-wise configuration
     * list, the layer array (hidden layers plus one), the per-layer network array and the
     * default configuration. When the configuration list is empty it is populated with
     * one clone of the default configuration per hidden layer plus one more.
     */
    protected void intializeConfigurations() {
        if (layerWiseConfigurations == null) {
            layerWiseConfigurations = new ArrayList<NeuralNetConfiguration>();
        }
        if (layers == null) {
            layers = new Layer[getnLayers() + 1];
        }
        if (neuralNets == null) {
            neuralNets = new NeuralNetwork[getnLayers()];
        }
        if (defaultConfiguration == null) {
            defaultConfiguration = new NeuralNetConfiguration.Builder().build();
        }
        // The list was made non-null above, so only emptiness decides whether to populate it.
        if (!layerWiseConfigurations.isEmpty()) {
            return;
        }
        for (int added = 0; added < hiddenLayerSizes.length + 1; added++) {
            layerWiseConfigurations.add(defaultConfiguration.clone());
        }
    }

    /**
     * Verifies that the layers and their paired networks are dimensionally consistent.
     *
     * <p>For each hidden layer this checks that the layer and its network share weight and
     * bias shapes, that the configured input/output counts match the weight matrix, and
     * that each layer's output count equals the next layer's input count.
     *
     * @throws IllegalStateException if two consecutive layers or networks do not chain
     */
    public void dimensionCheck() {
        assert (this.layers.length == this.neuralNets.length + 1) : "Missing output layer";
        for (int i = 0; i < this.getnLayers(); ++i) {
            Layer h = this.layers[i];
            NeuralNetwork network = this.neuralNets[i];
            LinAlgExceptions.assertSameShape((INDArray)network.getW(), (INDArray)h.getW());
            LinAlgExceptions.assertSameShape((INDArray)h.getB(), (INDArray)network.gethBias());
            assert (h.conf().getnIn() == h.getW().rows()) : "Number of inputs not consistent with number of rows in weight matrix";
            // The nOut checks previously reused the nIn/rows message, which misreported the failure.
            assert (h.conf().getnOut() == h.getW().columns()) : "Number of outputs not consistent with number of columns in weight matrix";
            if (i >= this.getnLayers() - 1) {
                // The last hidden layer has no hidden successor to compare against.
                continue;
            }
            Layer h1 = this.layers[i + 1];
            NeuralNetwork network1 = this.neuralNets[i + 1];
            assert (h1.conf().getnIn() == h1.getW().rows()) : "Number of inputs not consistent with number of rows in weight matrix";
            assert (h1.conf().getnOut() == h1.getW().columns()) : "Number of outputs not consistent with number of columns in weight matrix";
            assert (network1.conf().getnIn() == network1.getW().rows()) : "Number of inputs not consistent with number of rows in weight matrix";
            assert (network1.conf().getnOut() == network1.getW().columns()) : "Number of outputs not consistent with number of columns in weight matrix";
            if (h1.conf().getnIn() != h.conf().getnOut()) {
                throw new IllegalStateException("Invalid structure: hidden layer in for " + (i + 1) + " not equal to number of ins " + i);
            }
            if (network.conf().getnOut() != network1.conf().getnIn()) {
                throw new IllegalStateException("Invalid structure: network hidden for " + (i + 1) + " not equal to number of visible " + i);
            }
        }
    }

    /**
     * Transforms the data by delegating to {@code output(data)}.
     *
     * @param data the data to transform
     * @return whatever {@code output(data)} returns
     */
    @Override
    public INDArray transform(INDArray data) {
        INDArray transformed = output(data);
        return transformed;
    }

    /**
     * @return the network-wide default configuration, possibly {@code null} if never set
     */
    public NeuralNetConfiguration getDefaultConfiguration() {
        return defaultConfiguration;
    }

    /**
     * Replaces the network-wide default configuration.
     *
     * @param defaultConfiguration the configuration to store
     */
    public void setDefaultConfiguration(NeuralNetConfiguration defaultConfiguration) {
        this.defaultConfiguration = defaultConfiguration;
    }

    /**
     * @return the per-layer configuration list itself (not a copy)
     */
    public List<NeuralNetConfiguration> getLayerWiseConfigurations() {
        return layerWiseConfigurations;
    }

    /**
     * Replaces the per-layer configuration list; the list is stored by reference.
     *
     * @param layerWiseConfigurations the configurations to store
     */
    public void setLayerWiseConfigurations(List<NeuralNetConfiguration> layerWiseConfigurations) {
        this.layerWiseConfigurations = layerWiseConfigurations;
    }

    /**
     * Validates the input and hidden layer sizes, stores a copy of the input, and runs
     * {@link #init()} if it has not run yet.
     *
     * @param input the input matrix; its column count must equal the configured number of inputs
     * @throws IllegalArgumentException if {@code input} is {@code null}, its column count
     *         does not match the default configuration, or any hidden layer size is below 1
     */
    public void initializeLayers(INDArray input) {
        if (input == null) {
            throw new IllegalArgumentException("Unable to initialize neuralNets with empty input");
        }
        boolean columnsMatch = input.columns() == defaultConfiguration.getnIn();
        if (!columnsMatch) {
            throw new IllegalArgumentException(String.format("Unable to iterate on number of inputs; columns should be equal to number of inputs. Number of inputs was %d while number of columns was %d", defaultConfiguration.getnIn(), input.columns()));
        }
        if (neuralNets == null) {
            neuralNets = new NeuralNetwork[getnLayers()];
        }
        for (int size : hiddenLayerSizes) {
            if (size < 1) {
                throw new IllegalArgumentException("All hidden layer sizes must be >= 1");
            }
        }
        this.input = input.dup();
        if (initCalled) {
            return;
        }
        init();
        log.info("Initializing neuralNets with input of dims " + input.rows() + " x " + input.columns());
    }

    /**
     * Builds the hidden layers, their paired networks and the output layer, then runs the
     * dimension check, applies transforms, marks the network initialized and creates the mask.
     *
     * <p>Hidden layers are only (re)built when the network array or its first entry is missing.
     *
     * @throws IllegalStateException if fewer than one layer is configured
     */
    public void init() {
        if (layerWiseConfigurations == null || layers == null) {
            intializeConfigurations();
        }
        INDArray layerInput = input;
        if (getnLayers() < 1) {
            throw new IllegalStateException("Unable to createComplex network neuralNets; number specified is less than 1");
        }
        boolean networksMissing = neuralNets == null || neuralNets[0] == null;
        if (networksMissing) {
            neuralNets = new NeuralNetwork[getnLayers()];
            for (int i = 0; i < getnLayers(); i++) {
                boolean firstLayer = i == 0;
                int inputSize = firstLayer ? defaultConfiguration.getnIn() : hiddenLayerSizes[i - 1];
                // Later layers are fed the previous layer's activation, when there is input to propagate.
                if (!firstLayer && input != null) {
                    layerInput = activationFromPrevLayer(i - 1, layerInput);
                }
                NeuralNetConfiguration layerConf = layerWiseConfigurations.get(i);
                layerConf.setnIn(inputSize);
                layerConf.setnOut(hiddenLayerSizes[i]);
                layers[i] = createHiddenLayer(i, layerInput);
                neuralNets[i] = createLayer(layerInput, layers[i].getW(), layers[i].getB(), null, i);
            }
        }
        // The output layer takes as many inputs as the last hidden layer produces.
        int configCount = layerWiseConfigurations.size();
        NeuralNetConfiguration last = layerWiseConfigurations.get(configCount - 1);
        NeuralNetConfiguration secondToLast = layerWiseConfigurations.get(configCount - 2);
        last.setnIn(secondToLast.getnOut());
        layers[layers.length - 1] = new OutputLayer.Builder().configure(last).build();
        dimensionCheck();
        applyTransforms();
        initCalled = true;
        initMask();
    }

    /**
     * Activates the layer at index {@code getNeuralNets().length - 1}.
     *
     * @return that layer's activation
     */
    public INDArray activate() {
        Layer[] allLayers = getLayers();
        return allLayers[getNeuralNets().length - 1].activate();
    }

    /**
     * Activates a single layer.
     *
     * @param layer index into the layer array
     * @return that layer's activation
     */
    public INDArray activate(int layer) {
        Layer[] allLayers = getLayers();
        return allLayers[layer].activate();
    }

    /**
     * Activates a single layer on the supplied input.
     *
     * @param layer index into the layer array
     * @param input the input passed to that layer
     * @return that layer's activation for {@code input}
     */
    public INDArray activate(int layer, INDArray input) {
        Layer[] allLayers = getLayers();
        return allLayers[layer].activate(input);
    }

    /**
     * Fine-tunes using the labels currently stored on this network.
     */
    public void finetune() {
        finetune(labels);
    }

    /**
     * Primes the network with a data set: sets the input, runs a forward pass over the
     * features, then stores the labels on this network and on the output layer.
     *
     * @param data the data set providing features and labels
     */
    public void initialize(org.nd4j.linalg.dataset.DataSet data) {
        setInput(data.getFeatureMatrix());
        feedForward(data.getFeatureMatrix());
        labels = data.getLabels();
        getOutputLayer().setLabels(labels);
    }

    /**
     * Computes the activation of layer {@code curr} for the given input.
     *
     * <p>When {@code curr} equals the number of hidden networks the output layer's label
     * probabilities are returned; otherwise the layer's configured activation type picks
     * between the layer activation, the network's hidden activation, or a hidden sample.
     *
     * @param curr index of the layer to activate
     * @param input the input to that layer
     * @return the resulting activation
     * @throws IllegalStateException if the activation type is none of the handled values
     */
    public INDArray activationFromPrevLayer(int curr, INDArray input) {
        boolean isOutputLayer = curr == neuralNets.length;
        if (isOutputLayer) {
            return getOutputLayer().labelProbabilities(input);
        }
        switch (layers[curr].conf().getActivationType()) {
            case HIDDEN_LAYER_ACTIVATION:
                return layers[curr].activate(input);
            case NET_ACTIVATION:
                return neuralNets[curr].hiddenActivation(input);
            case SAMPLE:
                return neuralNets[curr].sampleHiddenGivenVisible(input).getSecond();
            default:
                throw new IllegalStateException("Invalid activation type");
        }
    }

    /**
     * Runs a forward pass over the stored input.
     *
     * @return the stored input followed by the activation of every layer, in order
     * @throws IllegalStateException if the input's column count differs from the
     *         configured number of inputs
     */
    public List<INDArray> feedForward() {
        INDArray current = input;
        if (input.columns() != defaultConfiguration.getnIn()) {
            throw new IllegalStateException("Illegal input length");
        }
        ArrayList<INDArray> collected = new ArrayList<INDArray>();
        collected.add(current);
        for (int layerIdx = 0; layerIdx < layers.length; layerIdx++) {
            current = activationFromPrevLayer(layerIdx, current);
            // Masks the activation in place when drop connect is enabled.
            applyDropConnectIfNecessary(current);
            collected.add(current);
        }
        return collected;
    }

    /**
     * Stores the given input (by reference) and runs a forward pass over it.
     *
     * @param input the input to propagate
     * @return the result of {@link #feedForward()}
     * @throws IllegalStateException if {@code input} is {@code null}
     */
    public List<INDArray> feedForward(INDArray input) {
        if (input != null) {
            this.input = input;
            return feedForward();
        }
        throw new IllegalStateException("Unable to perform feed forward; no input found");
    }

    /**
     * When drop connect is enabled, multiplies {@code input} in place by a binomial mask
     * sampled with probability 0.5 and, if the configured L2 value is positive, scales it
     * in place by that value. Does nothing otherwise.
     *
     * @param input the array to mask; modified in place
     */
    protected void applyDropConnectIfNecessary(INDArray input) {
        if (!useDropConnect) {
            return;
        }
        INDArray probabilities = Nd4j.valueArrayOf((int)input.rows(), (int)input.columns(), (double)0.5);
        INDArray dropMask = Sampling.binomial((INDArray)probabilities, (int)1, (RandomGenerator)defaultConfiguration.getRng());
        input.muli(dropMask);
        if (defaultConfiguration.getL2() > 0.0f) {
            input.muli((Number)Float.valueOf(defaultConfiguration.getL2()));
        }
    }

    /**
     * Computes one delta per layer for the R-operator pass, walking the layers from the
     * output back to the first.
     *
     * <p>The returned list always holds exactly one entry per layer. With the unit-norm
     * constraint enabled, a delta is divided by its norm only when its element sum is
     * positive; otherwise it is returned unscaled.
     *
     * @param v the vector handed to {@code feedForwardR}
     * @return the per-layer deltas, index-aligned with the layers
     */
    protected List<INDArray> computeDeltasR(INDArray v) {
        int i;
        ArrayList<INDArray> deltaRet = new ArrayList<INDArray>();
        INDArray[] deltas = new INDArray[this.getnLayers() + 1];
        List<INDArray> activations = this.feedForward();
        List<INDArray> rActivations = this.feedForwardR(activations, v);
        ArrayList<INDArray> weights = new ArrayList<INDArray>();
        ArrayList<INDArray> biases = new ArrayList<INDArray>();
        ArrayList<ActivationFunction> activationFunctions = new ArrayList<ActivationFunction>();
        for (int j = 0; j < this.getLayers().length; ++j) {
            weights.add(this.getLayers()[j].getW());
            biases.add(this.getLayers()[j].getB());
            activationFunctions.add(this.getLayers()[j].conf().getActivationFunction());
        }
        // In-place division: the last R-activation itself is scaled by the row count.
        INDArray rix = rActivations.get(rActivations.size() - 1).divi((Number)this.input.rows());
        LinAlgExceptions.assertValidNum((INDArray)rix);
        for (i = this.getnLayers(); i >= 0; --i) {
            deltas[i] = activations.get(i).transpose().mmul(rix);
            this.applyDropConnectIfNecessary(deltas[i]);
            if (i <= 0) continue;
            rix = rix.mmul(((INDArray)weights.get(i)).addRowVector((INDArray)biases.get(i)).transpose()).muli(((ActivationFunction)activationFunctions.get(i - 1)).applyDerivative(activations.get(i)));
        }
        for (i = 0; i < deltas.length; ++i) {
            INDArray delta = deltas[i];
            if (this.defaultConfiguration.isConstrainGradientToUnitNorm()
                    && delta.sum(Integer.MAX_VALUE).getDouble(0) > 0.0) {
                delta = delta.div(delta.norm2(Integer.MAX_VALUE));
            }
            // Append unconditionally: skipping an entry left the list shorter than the
            // index being validated and raised IndexOutOfBoundsException.
            deltaRet.add(delta);
            LinAlgExceptions.assertValidNum((INDArray)delta);
        }
        return deltaRet;
    }

    /**
     * Adjusts the damping factor from a reduction ratio: multiplied by {@code boost} when
     * {@code rho} is below 0.25 or NaN, by {@code decrease} when {@code rho} exceeds 0.75,
     * and left unchanged otherwise.
     *
     * @param rho the reduction ratio
     * @param boost multiplier applied for a low or NaN ratio
     * @param decrease multiplier applied for a high ratio
     */
    public void dampingUpdate(double rho, double boost, double decrease) {
        boolean lowOrInvalid = rho < 0.25 || Double.isNaN(rho);
        if (lowOrInvalid) {
            dampingFactor *= boost;
            return;
        }
        if (rho > 0.75) {
            dampingFactor *= decrease;
        }
    }

    /**
     * Computes the reduction ratio used to adapt the damping factor.
     *
     * <p>The damping factor is set to zero while the R-gradient is evaluated and is always
     * restored afterwards, even if that evaluation throws.
     *
     * @param p the step being evaluated
     * @param currScore the score before the step
     * @param score the score after the step
     * @param gradient the gradient at the current parameters
     * @return {@link Double#NEGATIVE_INFINITY} when {@code score} exceeds
     *         {@code currScore}; otherwise {@code (currScore - score)} divided by the
     *         first element of the computed denominator
     */
    public double reductionRatio(INDArray p, double currScore, double score, INDArray gradient) {
        double currentDamp = this.dampingFactor;
        double rho;
        this.dampingFactor = 0.0;
        try {
            INDArray denom = this.getBackPropRGradient(p);
            // NOTE(review): the result of sum(0) is discarded here; only the in-place
            // muli calls affect denom. Possibly the sum was meant to be kept - confirm.
            denom.muli((Number)0.5).muli(p.mul(denom)).sum(0);
            denom.subi(gradient.mul(p).sum(0));
            rho = (currScore - score) / (Double)denom.getScalar(0).element();
        } finally {
            // Restore unconditionally so a failure above cannot leave damping at zero.
            this.dampingFactor = currentDamp;
        }
        if (score - currScore > 0.0) {
            return Double.NEGATIVE_INFINITY;
        }
        return rho;
    }

    /**
     * Runs a forward pass and computes, for every layer, a pair holding the weight delta
     * and a second matrix built from squared activations and squared error.
     *
     * @return one (delta, squared-term matrix) pair per layer, ordered first layer to last;
     *         deltas are divided in place by their norm when the unit-norm constraint is on
     */
    protected List<Pair<INDArray, INDArray>> computeDeltas2() {
        int i;
        ArrayList<Pair<INDArray, INDArray>> deltaRet = new ArrayList<Pair<INDArray, INDArray>>();
        List<INDArray> activations = this.feedForward();
        // activations holds the input plus one entry per layer, hence size - 1 deltas.
        INDArray[] deltas = new INDArray[activations.size() - 1];
        INDArray[] preCons = new INDArray[activations.size() - 1];
        // Output error: (final activation - labels) / number of label rows.
        INDArray ix = activations.get(activations.size() - 1).sub(this.labels).div((Number)this.labels.rows());
        log.info("Ix mean " + ix.sum(Integer.MAX_VALUE));
        ArrayList<INDArray> weights = new ArrayList<INDArray>();
        // NOTE(review): biases is filled below but never read in this method.
        ArrayList<INDArray> biases = new ArrayList<INDArray>();
        ArrayList<ActivationFunction> activationFunctions = new ArrayList<ActivationFunction>();
        for (int j = 0; j < this.getLayers().length; ++j) {
            weights.add(this.getLayers()[j].getW());
            biases.add(this.getLayers()[j].getB());
            activationFunctions.add(this.getLayers()[j].conf().getActivationFunction());
        }
        // Walk from the last layer back to the first, propagating the error through the weights.
        for (i = weights.size() - 1; i >= 0; --i) {
            deltas[i] = activations.get(i).transpose().mmul(ix);
            log.info("Delta sum at " + i + " is " + deltas[i].sum(Integer.MAX_VALUE));
            preCons[i] = Transforms.pow((INDArray)activations.get(i).transpose(), (Number)2).mmul(Transforms.pow((INDArray)ix, (Number)2)).mul((Number)this.labels.rows());
            this.applyDropConnectIfNecessary(deltas[i]);
            if (i <= 0) continue;
            // muli is in place on the freshly created mmul result, not on the previous ix.
            ix = ix.mmul(((INDArray)weights.get(i)).transpose()).muli(((ActivationFunction)activationFunctions.get(i - 1)).applyDerivative(activations.get(i)));
        }
        for (i = 0; i < deltas.length; ++i) {
            if (this.defaultConfiguration.isConstrainGradientToUnitNorm()) {
                // divi normalizes deltas[i] in place.
                deltaRet.add(new Pair<INDArray, INDArray>(deltas[i].divi(deltas[i].norm2(Integer.MAX_VALUE)), preCons[i]));
                continue;
            }
            deltaRet.add(new Pair<INDArray, INDArray>(deltas[i], preCons[i]));
        }
        return deltaRet;
    }

    /**
     * Distributes a flat parameter vector across the hidden networks and the output layer.
     *
     * <p>Each hidden network consumes {@code numParams()} consecutive entries and its
     * weights and hidden bias are then mirrored onto the matching layer; the output layer
     * receives everything that remains.
     *
     * @param params the flat parameter vector
     */
    @Override
    public void setParams(INDArray params) {
        int offset = 0;
        for (int i = 0; i < neuralNets.length; i++) {
            NeuralNetwork net = getNeuralNets()[i];
            int count = net.numParams();
            NDArrayIndex[] slice = new NDArrayIndex[]{NDArrayIndex.interval((int)offset, (int)(offset + count))};
            net.setParams(params.get(slice));
            getLayers()[i].setW(net.getW());
            getLayers()[i].setB(net.gethBias());
            offset += count;
        }
        NDArrayIndex[] remainder = new NDArrayIndex[]{NDArrayIndex.interval((int)offset, (int)params.length())};
        getOutputLayer().setParams(params.get(remainder));
    }

    /**
     * Fits the network by calling {@code pretrain} with the data and a single extra
     * argument, the integer 1.
     *
     * @param data the data to fit on
     */
    @Override
    public void fit(INDArray data) {
        Object[] extraArgs = new Object[]{1};
        pretrain(data, extraArgs);
    }

    /**
     * Runs a forward pass and computes the back-propagation deltas.
     *
     * @return a list with (number of layers + 2) entries: one delta per activation index
     *         from 0 up to the number of layers, followed by the raw output error in the
     *         last slot; each is divided in place by its norm when the unit-norm
     *         constraint is on
     */
    protected List<INDArray> computeDeltas() {
        int i;
        ArrayList<INDArray> deltaRet = new ArrayList<INDArray>();
        INDArray[] deltas = new INDArray[this.getnLayers() + 2];
        List<INDArray> activations = this.feedForward();
        // Output error: (labels - output), then the output activation derivative is subtracted in place.
        INDArray ix = this.labels.sub(activations.get(activations.size() - 1)).subi(this.getOutputLayer().conf().getActivationFunction().applyDerivative(activations.get(activations.size() - 1)));
        ArrayList<INDArray> weights = new ArrayList<INDArray>();
        // NOTE(review): biases is filled below but never read in this method.
        ArrayList<INDArray> biases = new ArrayList<INDArray>();
        ArrayList<ActivationFunction> activationFunctions = new ArrayList<ActivationFunction>();
        for (int j = 0; j < this.getNeuralNets().length; ++j) {
            weights.add(this.getNeuralNets()[j].getW());
            biases.add(this.getNeuralNets()[j].gethBias());
            activationFunctions.add(this.getLayers()[j].conf().getActivationFunction());
        }
        // The output layer's parameters go last, after all hidden networks.
        weights.add(this.getOutputLayer().getW());
        biases.add(this.getOutputLayer().getB());
        activationFunctions.add(this.getOutputLayer().conf().getActivationFunction());
        for (i = this.getnLayers() + 1; i >= 0; --i) {
            INDArray delta;
            if (i >= this.getnLayers() + 1) {
                // Topmost slot just records the output error itself.
                deltas[i] = ix;
                continue;
            }
            deltas[i] = delta = activations.get(i).transpose().mmul(ix);
            this.applyDropConnectIfNecessary(deltas[i]);
            // Despite the name, this is only the transposed weight matrix; no bias is added.
            INDArray weightsPlusBias = ((INDArray)weights.get(i)).transpose();
            INDArray activation = activations.get(i);
            if (i <= 0) continue;
            // Propagate the error one layer down; muli acts in place on the new mmul result.
            ix = ix.mmul(weightsPlusBias).muli(((ActivationFunction)activationFunctions.get(i - 1)).applyDerivative(activation));
        }
        for (i = 0; i < deltas.length; ++i) {
            if (this.defaultConfiguration.isConstrainGradientToUnitNorm()) {
                // divi normalizes deltas[i] in place.
                deltaRet.add(deltas[i].divi(deltas[i].norm2(Integer.MAX_VALUE)));
                continue;
            }
            deltaRet.add(deltas[i]);
        }
        return deltaRet;
    }

    /**
     * Performs one back-propagation step: subtracts the weight and bias gradients from
     * every layer in place and pushes the updated parameters to the paired hidden network
     * where one exists.
     */
    public void backPropStep() {
        List<Pair<INDArray, INDArray>> gradients = backPropGradient();
        for (int idx = 0; idx < layers.length; idx++) {
            layers[idx].getW().subi(gradients.get(idx).getFirst());
            layers[idx].getB().subi(gradients.get(idx).getSecond());
            // The output layer has no paired network.
            if (idx < neuralNets.length) {
                neuralNets[idx].setW(layers[idx].getW());
                neuralNets[idx].sethBias(layers[idx].getB());
            }
        }
    }

    /**
     * Performs one R-operator back-propagation step for the vector {@code v}.
     *
     * <p>The output layer is always updated with the last gradient pair. Hidden networks
     * are updated only when fewer gradient pairs than hidden networks were returned.
     *
     * @param v the vector handed to {@code backPropGradientR}
     */
    public void backPropStepR(INDArray v) {
        List<Pair<INDArray, INDArray>> gradients = backPropGradientR(v);
        for (int idx = 0; idx < neuralNets.length; idx++) {
            // NOTE(review): this guard looks inverted. If backPropGradientR returns one pair
            // per hidden network plus the output layer, this body never runs and the hidden
            // networks are never updated - confirm the intended condition.
            if (gradients.size() < neuralNets.length) {
                neuralNets[idx].getW().subi(gradients.get(idx).getFirst());
                neuralNets[idx].gethBias().subi(gradients.get(idx).getSecond());
                layers[idx].setW(neuralNets[idx].getW());
                layers[idx].setB(neuralNets[idx].gethBias());
            }
        }
        int lastIdx = gradients.size() - 1;
        getOutputLayer().getW().subi(gradients.get(lastIdx).getFirst());
        getOutputLayer().getB().subi(gradients.get(lastIdx).getSecond());
    }

    /**
     * @return the layer array itself (not a copy)
     */
    public Layer[] getLayers() {
        return layers;
    }

    /**
     * Replaces the layer array; the array is stored by reference.
     *
     * @param layers the layers to store
     */
    public void setLayers(Layer[] layers) {
        this.layers = layers;
    }

    /**
     * Replaces the per-layer network array; the array is stored by reference.
     *
     * @param neuralNets the networks to store
     */
    public void setNeuralNets(NeuralNetwork[] neuralNets) {
        this.neuralNets = neuralNets;
    }

    /**
     * Computes the R-operator gradient for {@code v} and packs it into one flat vector.
     *
     * @param v the vector handed to {@code backPropGradientR}
     * @return the packed gradient
     */
    public INDArray getBackPropRGradient(INDArray v) {
        List<Pair<INDArray, INDArray>> perLayer = backPropGradientR(v);
        return pack(perLayer);
    }

    /**
     * Splits the result of {@code backPropGradient2()} into its two per-layer halves and
     * packs each half into a flat vector.
     *
     * @return a pair of (packed first halves, packed second halves)
     */
    public Pair<INDArray, INDArray> getBackPropGradient2() {
        List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> combined = backPropGradient2();
        ArrayList<Pair<INDArray, INDArray>> firstHalves = new ArrayList<Pair<INDArray, INDArray>>();
        ArrayList<Pair<INDArray, INDArray>> secondHalves = new ArrayList<Pair<INDArray, INDArray>>();
        for (Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> entry : combined) {
            firstHalves.add(entry.getFirst());
            secondHalves.add(entry.getSecond());
        }
        INDArray packedFirst = pack(firstHalves);
        INDArray packedSecond = pack(secondHalves);
        return new Pair<INDArray, INDArray>(packedFirst, packedSecond);
    }

    /**
     * Creates a new instance of this network's runtime class through its no-argument
     * constructor and copies this network's state into it via {@code update(this)}.
     *
     * @return the populated copy
     * @throws IllegalStateException if the runtime class cannot be instantiated
     *         reflectively; the reflective failure is attached as the cause
     */
    public BaseMultiLayerNetwork clone() {
        BaseMultiLayerNetwork ret;
        try {
            ret = (BaseMultiLayerNetwork)this.getClass().newInstance();
        }
        catch (InstantiationException e) {
            // Fail with the real cause instead of printing it and then hitting a
            // NullPointerException on the uninitialized copy.
            throw new IllegalStateException("Unable to instantiate " + this.getClass().getName() + " for cloning", e);
        }
        catch (IllegalAccessException e) {
            throw new IllegalStateException("Unable to access the no-argument constructor of " + this.getClass().getName() + " for cloning", e);
        }
        ret.update(this);
        return ret;
    }

    /**
     * Collects the weight matrix of every hidden network, followed by the output layer's.
     *
     * @return the weight matrices in layer order
     */
    public List<INDArray> weightMatrices() {
        ArrayList<INDArray> matrices = new ArrayList<INDArray>();
        for (NeuralNetwork net : neuralNets) {
            matrices.add(net.getW());
        }
        matrices.add(getOutputLayer().getW());
        return matrices;
    }

    /**
     * Computes the R-operator forward pass for the vector {@code v}.
     *
     * @param acts activations of a regular forward pass (input first, then one per layer)
     * @param v flat vector; unpacked into one (weight, bias) pair per layer
     * @return a list starting with a zero matrix shaped like the stored input, followed by
     *         one entry per hidden network and one for the output layer
     */
    public List<INDArray> feedForwardR(List<INDArray> acts, INDArray v) {
        ArrayList<INDArray> R = new ArrayList<INDArray>();
        // Seed entry: zeros with the same shape as the stored input.
        R.add(Nd4j.zeros((int)this.input.rows(), (int)this.input.columns()));
        List<Pair<INDArray, INDArray>> vWvB = this.unPack(v);
        List<INDArray> W = this.weightMatrices();
        for (int i = 0; i < this.neuralNets.length; ++i) {
            // Named "derivative" but holds the activation function; its derivative is applied below.
            ActivationFunction derivative = this.getNeuralNets()[i].conf().getActivationFunction();
            if (this.getNeuralNets()[i] instanceof AutoEncoder) {
                // Auto-encoders take the function from their conf field rather than conf().
                AutoEncoder a = (AutoEncoder)this.getNeuralNets()[i];
                derivative = a.conf.getActivationFunction();
            }
            // R[i+1] = (R[i] * W[i] + acts[i] * (vW[i] + vB[i] as row vector)) .* f'(acts[i+1]); addi/muli act in place on the new mmul result.
            R.add(((INDArray)R.get(i)).mmul(W.get(i)).addi(acts.get(i).mmul(vWvB.get(i).getFirst().addRowVector(vWvB.get(i).getSecond()))).muli(derivative.applyDerivative(acts.get(i + 1))));
        }
        // Same recurrence for the output layer, using the last weight matrix and the last unpacked pair.
        R.add(((INDArray)R.get(R.size() - 1)).mmul(W.get(W.size() - 1)).addi(acts.get(acts.size() - 2).mmul(vWvB.get(vWvB.size() - 1).getFirst().addRowVector(vWvB.get(vWvB.size() - 1).getSecond()))).muli(this.getOutputLayer().conf().getActivationFunction().applyDerivative(acts.get(acts.size() - 1))));
        return R;
    }

    /**
     * Runs a regular forward pass and then the R-operator pass for {@code v} on its activations.
     *
     * @param v flat vector handed to the R-operator pass
     * @return the R-operator activations
     */
    public List<INDArray> feedForwardR(INDArray v) {
        List<INDArray> forwardActivations = feedForward();
        return feedForwardR(forwardActivations, v);
    }

    /**
     * Runs back-propagation with the optimizer selected by the Gauss-Newton flag, using
     * the default configuration's learning rate and iteration count.
     *
     * @param eval the evaluator handed to the optimizer; may be {@code null}
     */
    public void backProp(TrainingEvaluator eval) {
        if (!useGaussNewtonVectorProductBackProp) {
            BackPropOptimizer standard = new BackPropOptimizer(this, defaultConfiguration.getLr(), defaultConfiguration.getNumIterations());
            standard.optimize(eval);
            return;
        }
        BackPropROptimizer gaussNewton = new BackPropROptimizer(this, defaultConfiguration.getLr(), defaultConfiguration.getNumIterations());
        gaussNewton.optimize(eval);
    }

    /**
     * Runs back-propagation without a training evaluator.
     */
    public void backProp() {
        TrainingEvaluator noEvaluator = null;
        backProp(noEvaluator);
    }

    /**
     * @return whether drop connect masking is enabled
     */
    public boolean isUseDropConnect() {
        return useDropConnect;
    }

    /**
     * Enables or disables drop connect masking.
     *
     * @param useDropConnect {@code true} to enable masking
     */
    public void setUseDropConnect(boolean useDropConnect) {
        this.useDropConnect = useDropConnect;
    }

    /**
     * Flattens all parameters into one vector: for each hidden network its weights then
     * its hidden bias, followed by the output layer's parameters.
     *
     * @return the flattened parameters
     */
    @Override
    public INDArray params() {
        ArrayList<INDArray> pieces = new ArrayList<INDArray>();
        for (int idx = 0; idx < getnLayers(); idx++) {
            NeuralNetwork net = getNeuralNets()[idx];
            pieces.add(net.getW());
            pieces.add(net.gethBias());
        }
        pieces.add(getOutputLayer().params());
        return Nd4j.toFlattened(pieces);
    }

    /**
     * Counts the parameters of the network: for each hidden network its parameter count
     * minus the length of its visible bias, plus the output layer's parameter count.
     *
     * @return the total parameter count
     */
    @Override
    public int numParams() {
        int total = 0;
        for (NeuralNetwork net : neuralNets) {
            total += net.numParams() - net.getvBias().length();
        }
        total += getOutputLayer().numParams();
        return total;
    }

    /**
     * Packs the current weights and biases of every hidden network and of the output
     * layer into one flat vector.
     *
     * @return the packed parameters
     */
    public INDArray pack() {
        ArrayList<Pair<INDArray, INDArray>> perLayer = new ArrayList<Pair<INDArray, INDArray>>();
        for (NeuralNetwork net : neuralNets) {
            perLayer.add(new Pair<INDArray, INDArray>(net.getW(), net.gethBias()));
        }
        perLayer.add(new Pair<INDArray, INDArray>(getOutputLayer().getW(), getOutputLayer().getB()));
        return pack(perLayer);
    }

    /**
     * Flattens per-layer (weight, bias) pairs into one vector, weights before bias for each layer.
     *
     * @param layers one pair per hidden network plus one for the output layer
     * @return the flattened vector
     * @throws IllegalArgumentException if the number of pairs is not the number of hidden
     *         networks plus one
     * @throws IllegalStateException if the flattened length differs from {@link #numParams()}
     */
    public INDArray pack(List<Pair<INDArray, INDArray>> layers) {
        int expectedPairs = neuralNets.length + 1;
        if (layers.size() != expectedPairs) {
            throw new IllegalArgumentException("Illegal number of neuralNets passed in. Was " + layers.size() + " when should have been " + (neuralNets.length + 1));
        }
        ArrayList<INDArray> pieces = new ArrayList<INDArray>();
        for (Pair<INDArray, INDArray> pair : layers) {
            pieces.add(pair.getFirst());
            pieces.add(pair.getSecond());
        }
        INDArray flattened = Nd4j.toFlattened(pieces);
        if (flattened.length() == numParams()) {
            return flattened;
        }
        throw new IllegalStateException("Illegal number of parameters found in the layers with a difference of " + Math.abs(flattened.length() - numParams()));
    }

    /**
     * Scores a data set by delegating to the two-argument {@code score} with the set's
     * feature matrix and labels.
     *
     * @param data the data set to score
     * @return the score
     */
    @Override
    public double score(DataSet data) {
        double result = score(data.getFeatureMatrix(), data.getLabels());
        return result;
    }

    /**
     * Computes the back-propagation gradient as one (weight gradient, bias gradient) pair
     * per layer, including the output layer.
     *
     * <p>The weight gradient is the layer's delta. The bias gradient is the delta itself
     * when it is a vector, otherwise its sum along dimension 0. The pairs are packed and
     * unpacked once so they come back reshaped to the layers' weight and bias shapes.
     *
     * <p>NOTE(review): an L2-scaled copy of the parameters used to be computed here and
     * then discarded. It has been removed as dead work; if L2 regularization was meant to
     * be applied to the gradient, that still needs to be implemented - confirm.
     *
     * @return the per-layer gradient pairs
     * @throws IllegalStateException if a delta's length differs from its layer's weight length
     */
    public List<Pair<INDArray, INDArray>> backPropGradient() {
        List<INDArray> deltas = this.computeDeltas();
        List<Pair<INDArray, INDArray>> list = new ArrayList<Pair<INDArray, INDArray>>();
        for (int l = 0; l < this.getnLayers() + 1; ++l) {
            INDArray gradientChange = deltas.get(l);
            if (gradientChange.length() != this.getLayers()[l].getW().length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            INDArray deltaColumnSums = gradientChange.isVector() ? gradientChange : gradientChange.sum(0);
            list.add(new Pair<INDArray, INDArray>(gradientChange, deltaColumnSums));
        }
        if (this.mask == null) {
            this.initMask();
        }
        // Shape sanity checks; only active when assertions are enabled.
        for (int i = 0; i < list.size(); ++i) {
            INDArray weightGradientForLayer = list.get(i).getFirst();
            INDArray biasGradientForLayer = list.get(i).getSecond();
            if (i < this.getnLayers()) {
                assert (Arrays.equals(weightGradientForLayer.shape(), this.getNeuralNets()[i].getW().shape())) : "Illegal shape for layer " + i + " weight gradient, should have been " + Arrays.toString(this.getNeuralNets()[i].getW().shape()) + " but was " + Arrays.toString(weightGradientForLayer.shape());
                assert (Arrays.equals(biasGradientForLayer.shape(), this.getNeuralNets()[i].gethBias().shape())) : "Illegal shape for layer " + i + " bias   gradient, should have been " + Arrays.toString(this.getNeuralNets()[i].gethBias().shape()) + " but was " + Arrays.toString(biasGradientForLayer.shape());
                continue;
            }
            assert (Arrays.equals(weightGradientForLayer.shape(), this.getOutputLayer().getW().shape())) : "Illegal shape for output  weight gradient, should have been " + Arrays.toString(this.getOutputLayer().getW().shape()) + " but was " + Arrays.toString(weightGradientForLayer.shape());
            assert (Arrays.equals(biasGradientForLayer.shape(), this.getOutputLayer().getB().shape())) : "Illegal shape for output layer  bias   gradient, should have been " + Arrays.toString(this.getOutputLayer().getB().shape()) + " but was " + Arrays.toString(biasGradientForLayer.shape());
        }
        // Round-trip through the flat representation to validate the total length and reshape.
        INDArray gradient = this.pack(list);
        return this.unPack(gradient);
    }

    /**
     * Splits a flat parameter vector into one (weight, bias) pair per layer, each
     * reshaped to the shape of the corresponding layer's current weights and bias.
     *
     * @param param flat vector whose length must equal {@link #numParams()}
     * @return the per-layer pairs, in layer order
     * @throws IllegalArgumentException if the vector length differs from {@link #numParams()}
     * @throws IllegalStateException if a sliced weight or bias portion has an unexpected length
     */
    public List<Pair<INDArray, INDArray>> unPack(INDArray param) {
        int expected = numParams();
        if (param.length() != expected) {
            throw new IllegalArgumentException("Parameter vector not equal of length to " + expected);
        }
        // Work on a single-row view so the interval slices below address a flat range.
        INDArray flat = param.rows() == 1 ? param : param.reshape(1, param.length());
        ArrayList<Pair<INDArray, INDArray>> unpacked = new ArrayList<Pair<INDArray, INDArray>>();
        int offset = 0;
        for (int i = 0; i < layers.length; i++) {
            int weightLength = layers[i].getW().length();
            int layerLength = weightLength + layers[i].getB().length();
            INDArray layerSlice = flat.get(new NDArrayIndex[]{NDArrayIndex.interval((int)offset, (int)(offset + layerLength))});
            INDArray weightPortion = layerSlice.get(new NDArrayIndex[]{NDArrayIndex.interval((int)0, (int)weightLength)});
            INDArray hBiasPortion = layerSlice.get(new NDArrayIndex[]{NDArrayIndex.interval((int)weightLength, (int)layerSlice.length())});
            boolean lengthsDisagree = weightPortion.length() + hBiasPortion.length() != layerLength;
            if (lengthsDisagree) {
                if (hBiasPortion.length() != layers[i].getB().length()) {
                    throw new IllegalStateException("Hidden bias on layer " + i + " was off");
                }
                if (weightPortion.length() != layers[i].getW().length()) {
                    throw new IllegalStateException("Weight portion on layer " + i + " was off");
                }
            }
            INDArray shapedWeights = weightPortion.reshape(layers[i].getW().rows(), layers[i].getW().columns());
            INDArray shapedBias = hBiasPortion.reshape(layers[i].getB().rows(), layers[i].getB().columns());
            unpacked.add(new Pair<INDArray, INDArray>(shapedWeights, shapedBias));
            offset += layerLength;
        }
        return unpacked;
    }

    /**
     * Builds, for each entry returned by {@code computeDeltas2()}, a gradient pair together
     * with a matching preconditioner pair.
     *
     * <p>The first array of each delta is used as the weight gradient and the second as the
     * preconditioner; the bias counterpart of each is its mean over dimension 0. Both sets are
     * packed into flat vectors, adjusted with the L2 value of the default configuration and
     * the damping factor, and unpacked again.
     *
     * @return one element per unpacked layer: the (weight, bias) gradient pair combined with the
     *         (weight, bias) preconditioner pair at the same index
     * @throws IllegalStateException if a gradient or bias length disagrees with the layer it is
     *         compared against
     */
    protected List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> backPropGradient2() {
        List<Pair<INDArray, INDArray>> deltas = this.computeDeltas2();
        ArrayList<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> list = new ArrayList<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>>();
        ArrayList<Pair<INDArray, INDArray>> grad = new ArrayList<Pair<INDArray, INDArray>>();
        ArrayList<Pair<INDArray, INDArray>> preCon = new ArrayList<Pair<INDArray, INDArray>>();
        for (int l = 0; l < deltas.size(); ++l) {
            INDArray gradientChange = deltas.get(l).getFirst();
            INDArray preConGradientChange = deltas.get(l).getSecond();
            // The weight gradient must have as many entries as the weights of layers[l].
            if (l < this.layers.length && gradientChange.length() != this.layers[l].getW().length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            // NOTE(review): this output-layer check keys on getNeuralNets().length while the bias
            // check further down keys on getLayers().length — confirm both are intended.
            if (l == this.getNeuralNets().length && gradientChange.length() != this.getOutputLayer().getW().length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            // Despite the "ColumnSums" names, these are means over dimension 0.
            INDArray deltaColumnSums = deltas.get(l).getFirst().mean(0);
            INDArray preConColumnSums = deltas.get(l).getSecond().mean(0);
            grad.add(new Pair<INDArray, INDArray>(gradientChange, deltaColumnSums));
            preCon.add(new Pair<INDArray, INDArray>(preConGradientChange, preConColumnSums));
            // The bias gradient must have as many entries as the bias of layers[l].
            if (l < this.layers.length && deltaColumnSums.length() != this.layers[l].getB().length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }
            // Only reached with l == getLayers().length; otherwise the iteration simply continues.
            if (l != this.getLayers().length || deltaColumnSums.length() == this.getOutputLayer().getB().length()) continue;
            throw new IllegalStateException("Bias change not equal to weight change");
        }
        INDArray g = this.pack(grad);
        INDArray con = this.pack(preCon);
        INDArray theta = this.params();
        // The mask is created on first use.
        if (this.mask == null) {
            this.initMask();
        }
        // g += (theta * l2) * mask, applied in place on the packed gradient.
        g.addi(theta.mul((Number)Float.valueOf(this.defaultConfiguration.getL2())).muli(this.mask));
        // Preconditioner term: (mask * l2 + dampingFactor) passed through Transforms.pow with 0.75.
        INDArray conAdd = Transforms.pow((INDArray)this.mask.mul((Number)Float.valueOf(this.defaultConfiguration.getL2())).add(Nd4j.valueArrayOf((int)g.rows(), (int)g.columns(), (double)this.dampingFactor)), (Number)0.75);
        con.addi(conAdd);
        List<Pair<INDArray, INDArray>> gUnpacked = this.unPack(g);
        List<Pair<INDArray, INDArray>> conUnpacked = this.unPack(con);
        // Zip gradient and preconditioner pairs by index.
        for (int i = 0; i < gUnpacked.size(); ++i) {
            list.add(new Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>(gUnpacked.get(i), conUnpacked.get(i)));
        }
        return list;
    }

    /**
     * Builds per-layer (weight, bias) gradients from {@code computeDeltasR(v)} and adds the
     * L2 and damping terms that depend on {@code v}.
     *
     * <p>The first {@code getnLayers()} deltas are validated against the matching network and
     * layer; the delta at index {@code getnLayers()} is appended without a length check. Every
     * bias gradient is the mean of its delta over dimension 0.
     *
     * @param v vector handed to {@code computeDeltasR} and used for the L2 and damping terms
     * @return the adjusted gradients, unpacked into one (weight, bias) pair per layer
     * @throws IllegalStateException if a delta's length disagrees with the weights or bias it
     *         is compared against
     */
    protected List<Pair<INDArray, INDArray>> backPropGradientR(INDArray v) {
        // The mask is created on first use.
        if (this.mask == null) {
            this.initMask();
        }
        List<INDArray> deltas = this.computeDeltasR(v);
        int hiddenCount = this.getnLayers();
        ArrayList<Pair<INDArray, INDArray>> gradients = new ArrayList<Pair<INDArray, INDArray>>();
        for (int l = 0; l < hiddenCount; ++l) {
            INDArray weightGradient = deltas.get(l);
            if (weightGradient.length() != this.getNeuralNets()[l].getW().length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            INDArray biasGradient = weightGradient.mean(0);
            if (biasGradient.length() != this.layers[l].getB().length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }
            gradients.add(new Pair<INDArray, INDArray>(weightGradient, biasGradient));
        }

        // The entry after the hidden ones is appended as-is.
        INDArray outputGradient = deltas.get(hiddenCount);
        INDArray outputBiasGradient = outputGradient.mean(0);
        gradients.add(new Pair<INDArray, INDArray>(outputGradient, outputBiasGradient));

        // packed + (mask * l2) * v + v * dampingFactor, accumulated in place.
        INDArray packed = this.pack(gradients);
        INDArray l2Term = this.mask.mul((Number)Float.valueOf(this.defaultConfiguration.getL2())).mul(v);
        INDArray regularized = packed.addi(l2Term);
        INDArray dampingTerm = v.mul((Number)this.dampingFactor);
        INDArray damped = regularized.addi(dampingTerm);
        return this.unPack(damped);
    }

    /**
     * Fine-tunes on every batch of the iterator, starting from a reset iterator.
     *
     * <p>Processing stops at the first batch whose feature matrix or labels are {@code null}.
     * A new {@link MultiLayerNetworkOptimizer} is created for each batch.
     *
     * @param iter source of batches; reset before use
     * @param lr   learning rate handed to each optimizer
     */
    public void finetune(DataSetIterator iter, double lr) {
        iter.reset();
        while (iter.hasNext()) {
            org.nd4j.linalg.dataset.DataSet batch = (org.nd4j.linalg.dataset.DataSet)iter.next();
            if (batch.getFeatureMatrix() == null || batch.getLabels() == null) {
                break;
            }
            this.setInput(batch.getFeatureMatrix());
            this.setLabels(batch.getLabels());
            this.feedForward();
            this.optimizer = new MultiLayerNetworkOptimizer(this, lr);
            this.optimizer.optimize(batch.getLabels());
        }
    }

    /**
     * Fine-tunes on every batch of the iterator with a training evaluator, starting from a
     * reset iterator.
     *
     * <p>Processing stops at the first batch whose feature matrix or labels are {@code null}.
     * A new {@link MultiLayerNetworkOptimizer} is created for each batch.
     *
     * @param iter       source of batches; reset before use
     * @param lr         learning rate handed to each optimizer
     * @param iterations not read by this method
     * @param eval       evaluator passed to the optimizer together with the stored labels
     */
    public void finetune(DataSetIterator iter, double lr, int iterations, TrainingEvaluator eval) {
        iter.reset();
        while (iter.hasNext()) {
            org.nd4j.linalg.dataset.DataSet batch = (org.nd4j.linalg.dataset.DataSet)iter.next();
            if (batch.getFeatureMatrix() == null || batch.getLabels() == null) {
                break;
            }
            this.setInput(batch.getFeatureMatrix());
            this.setLabels(batch.getLabels());
            this.optimizer = new MultiLayerNetworkOptimizer(this, lr);
            this.optimizer.optimize(this.labels, eval);
        }
    }

    /**
     * Fine-tunes the network with the learning rate of the default configuration.
     *
     * <p>When {@code labels} is {@code null} the labels already stored on the network are kept
     * and used, matching {@code finetune(INDArray, TrainingEvaluator)}. Previously the stored
     * labels were overwritten unconditionally, which made the null check a no-op and handed
     * {@code null} to the optimizer.
     *
     * @param labels labels to train against, or {@code null} to reuse the stored labels
     */
    public void finetune(INDArray labels) {
        if (labels != null) {
            this.labels = labels;
        }
        this.optimizer = new MultiLayerNetworkOptimizer(this, this.defaultConfiguration.getLr());
        this.optimizer.optimize(this.labels);
    }

    /**
     * Runs a forward pass, then fine-tunes with the learning rate of the default configuration
     * and the given evaluator.
     *
     * @param labels labels to train against; when {@code null} the stored labels are used
     * @param eval   evaluator passed to the optimizer
     */
    public void finetune(INDArray labels, TrainingEvaluator eval) {
        feedForward();
        if (labels != null) {
            this.labels = labels;
        }
        optimizer = new MultiLayerNetworkOptimizer(this, defaultConfiguration.getLr());
        optimizer.optimize(this.labels, eval);
    }

    /**
     * Predicts one index per row of {@code d}.
     *
     * <p>Each prediction is the result of the BLAS wrapper's {@code iamax} on the matching row
     * of the network output.
     *
     * @param d examples, one per row
     * @return an array with {@code d.rows()} entries
     */
    @Override
    public int[] predict(INDArray d) {
        INDArray networkOutput = output(d);
        int[] predictions = new int[d.rows()];
        int row = 0;
        while (row < predictions.length) {
            predictions[row] = Nd4j.getBlasWrapper().iamax(networkOutput.getRow(row));
            row++;
        }
        return predictions;
    }

    /**
     * Feeds the examples forward and asks the output layer for label probabilities of the
     * last activation.
     *
     * @param examples input examples
     * @return the output layer's label probabilities for the final activation
     */
    @Override
    public INDArray labelProbabilities(INDArray examples) {
        List<INDArray> activations = feedForward(examples);
        INDArray lastActivation = activations.get(activations.size() - 1);
        return getOutputLayer().labelProbabilities(lastActivation);
    }

    /**
     * Calls {@code pretrain} with the configured {@code k} and learning rate, then fine-tunes
     * on the labels.
     *
     * <p>When the {@code pretrain} flag is off, the examples are stored as the network input
     * before fine-tuning.
     *
     * @param examples training examples
     * @param labels   labels passed to {@code finetune}
     */
    @Override
    public void fit(INDArray examples, INDArray labels) {
        Object[] pretrainParams = new Object[]{this.defaultConfiguration.getK(), Float.valueOf(this.defaultConfiguration.getLr())};
        this.pretrain(examples, pretrainParams);
        if (!this.pretrain) {
            this.input = examples;
        }
        this.finetune(labels);
    }

    /**
     * Fits on the feature matrix and labels of the given data set.
     *
     * @param data data set whose feature matrix and labels are forwarded to
     *             {@code fit(INDArray, INDArray)}
     */
    @Override
    public void fit(DataSet data) {
        INDArray features = data.getFeatureMatrix();
        INDArray targets = data.getLabels();
        fit(features, targets);
    }

    /**
     * Delegates fitting with integer labels to the output layer.
     *
     * @param examples training examples
     * @param labels   integer labels, passed through unchanged
     */
    @Override
    public void fit(INDArray examples, int[] labels) {
        getOutputLayer().fit(examples, labels);
    }

    /**
     * Feeds {@code x} forward and returns the last activation.
     *
     * @param x input to feed forward
     * @return the final element of the activation list
     */
    public INDArray output(INDArray x) {
        List<INDArray> activations = feedForward(x);
        return activations.get(activations.size() - 1);
    }

    /**
     * Feeds {@code x} forward and returns the activation at index {@code layerNum - 1}.
     *
     * @param x        input to feed forward
     * @param layerNum one-based position in the activation list
     * @return the activation stored at {@code layerNum - 1}
     */
    public INDArray reconstruct(INDArray x, int layerNum) {
        List<INDArray> activations = feedForward(x);
        return activations.get(layerNum - 1);
    }

    /**
     * Logs every layer-wise configuration at info level as a single message, each one
     * prefixed with its position.
     */
    public void printConfiguration() {
        StringBuilder message = new StringBuilder();
        int position = 0;
        for (NeuralNetConfiguration layerConf : getLayerWiseConfigurations()) {
            message.append(" Layer ").append(position).append(" conf ").append(layerConf);
            position++;
        }
        log.info(message.toString());
    }

    /**
     * Writes this network to the stream via {@code SerializationUtils.writeObject}.
     *
     * @param os destination stream
     */
    @Override
    public void write(OutputStream os) {
        SerializationUtils.writeObject(this, os);
    }

    /**
     * Reads a network from the stream via {@code SerializationUtils.readObject} and copies its
     * state into this instance through {@code update}.
     *
     * @param is stream holding a serialized {@link BaseMultiLayerNetwork}
     */
    @Override
    public void load(InputStream is) {
        update((BaseMultiLayerNetwork)SerializationUtils.readObject(is));
    }

    /**
     * Copies the state of another network into this one.
     *
     * <p>When the other network has at least one layer, this network's layer array is rebuilt
     * and filled with clones of the other network's layers. All remaining fields handled here
     * are assigned by reference, not copied.
     *
     * <p>Each layer is cloned exactly once. The earlier version cloned every layer a second
     * time after the field assignments, producing the same end state at twice the cost.
     *
     * @param network network to copy from
     * @throws IllegalStateException if one of the other network's layers is {@code null}
     */
    public void update(BaseMultiLayerNetwork network) {
        if (network.neuralNets != null && network.getnLayers() > 0) {
            NeuralNetwork[] source = network.getNeuralNets();
            this.setnLayers(source.length);
            this.neuralNets = new NeuralNetwork[source.length];
            for (int i = 0; i < source.length; ++i) {
                if (source[i] == null) {
                    throw new IllegalStateException("Will not clone uninitialized network, layer " + i + " of network was null");
                }
                this.neuralNets[i] = source[i].clone();
            }
        }
        this.hiddenLayerSizes = network.hiddenLayerSizes;
        this.defaultConfiguration = network.defaultConfiguration;
        this.errorTolerance = network.errorTolerance;
        this.forceNumEpochs = network.forceNumEpochs;
        this.input = network.input;
        this.labels = network.labels;
        this.learningRateUpdate = network.learningRateUpdate;
        this.weightTransforms = network.weightTransforms;
        this.visibleBiasTransforms = network.visibleBiasTransforms;
        this.hiddenBiasTransforms = network.hiddenBiasTransforms;
        this.useDropConnect = network.useDropConnect;
        this.useGaussNewtonVectorProductBackProp = network.useGaussNewtonVectorProductBackProp;
    }

    /**
     * Scores the network on the given examples as the F1 value of an {@link Evaluation}.
     *
     * <p>The input is fed forward and the labels are stored on the network before the
     * label probabilities for {@code input} are evaluated against {@code labels}.
     *
     * @param input  examples to evaluate
     * @param labels expected labels
     * @return {@code Evaluation.f1()} after evaluating labels against predicted probabilities
     */
    @Override
    public double score(INDArray input, INDArray labels) {
        feedForward(input);
        setLabels(labels);
        Evaluation evaluation = new Evaluation();
        INDArray probabilities = labelProbabilities(input);
        evaluation.eval(labels, probabilities);
        return evaluation.f1();
    }

    /**
     * Returns the column count of the stored labels.
     *
     * @return {@code labels.columns()}
     */
    @Override
    public int numLabels() {
        return labels.columns();
    }

    /**
     * Feeds the data set's features forward, stores its labels and returns {@code score()}.
     *
     * @param data data set providing the feature matrix and labels
     * @return the value of the no-argument {@code score()}
     */
    public double score(org.nd4j.linalg.dataset.DataSet data) {
        feedForward(data.getFeatureMatrix());
        setLabels(data.getLabels());
        return score();
    }

    /**
     * Runs a forward pass and returns the output layer's score.
     *
     * @return {@code getOutputLayer().score()} after {@code feedForward()}
     */
    @Override
    public double score() {
        feedForward();
        return getOutputLayer().score();
    }

    /**
     * Scores the network as if {@code param} were its parameters, including an L2 penalty.
     *
     * <p>The current parameters are saved, {@code param} is installed, the score is computed,
     * and the saved parameters are put back. The penalty is
     * {@code 0.5 * l2 * sum((mask * param)^2)}.
     *
     * <p>Two fixes over the earlier version: the mask is created on first use (as the
     * backprop gradient methods already do) instead of failing when it is still {@code null},
     * and the saved parameters are restored even when scoring throws.
     *
     * @param param candidate parameter vector
     * @return the network score with {@code param} installed plus the L2 penalty
     */
    public double score(INDArray param) {
        if (this.mask == null) {
            this.initMask();
        }
        INDArray params = this.params();
        try {
            this.setParameters(param);
            double ret = this.score();
            double regCost = (double)(0.5f * this.defaultConfiguration.getL2()) * (Double)Transforms.pow((INDArray)this.mask.mul(param), (Number)2).sum(Integer.MAX_VALUE).element();
            return ret + regCost;
        } finally {
            // Always put the original parameters back, even when scoring failed.
            this.setParameters(params);
        }
    }

    /**
     * Applies the registered per-layer transforms to the weights, hidden bias and visible
     * bias of each network, replacing each array with the transform's result.
     *
     * @throws IllegalStateException if there are no networks to transform
     */
    protected void applyTransforms() {
        if (this.neuralNets == null || this.neuralNets.length < 1) {
            throw new IllegalStateException("Layers not initialized");
        }
        for (int i = 0; i < this.neuralNets.length; ++i) {
            NeuralNetwork net = this.neuralNets[i];
            if (this.weightTransforms.containsKey(i)) {
                net.setW((INDArray)this.weightTransforms.get(i).apply((Object)net.getW()));
            }
            if (this.hiddenBiasTransforms.containsKey(i)) {
                net.sethBias((INDArray)this.getHiddenBiasTransforms().get(i).apply((Object)net.gethBias()));
            }
            if (this.visibleBiasTransforms.containsKey(i)) {
                net.setvBias((INDArray)this.getVisibleBiasTransforms().get(i).apply((Object)net.getvBias()));
            }
        }
    }

    /**
     * Subclass hook that supplies a {@link NeuralNetwork}; by its name, for a single layer.
     *
     * <p>NOTE(review): the parameter names were lost ({@code var1}..{@code var5}); the role of
     * each argument cannot be read from this declaration — confirm against an implementation.
     */
    public abstract NeuralNetwork createLayer(INDArray var1, INDArray var2, INDArray var3, INDArray var4, int var5);

    /**
     * Subclass hook for pretraining from a data set iterator.
     *
     * <p>NOTE(review): parameter names were lost; {@code var2} is presumably the extra
     * training parameters — confirm against an implementation.
     */
    public abstract void pretrain(DataSetIterator var1, Object[] var2);

    /**
     * Subclass hook for pretraining from an input array.
     *
     * <p>{@code fit(INDArray, INDArray)} calls this with the examples and an array holding the
     * default configuration's {@code k} and learning rate.
     */
    public abstract void pretrain(INDArray var1, Object[] var2);

    /**
     * Subclass hook that supplies the network array; {@code setnLayers} stores the result of
     * calling it with the requested layer count.
     */
    public abstract NeuralNetwork[] createNetworkLayers(int var1);

    /**
     * Builds a layer from the given input and the layer-wise configuration at {@code index}.
     *
     * @param index      position in the layer-wise configuration list
     * @param layerInput input handed to the layer builder
     * @return the layer produced by the builder
     */
    public Layer createHiddenLayer(int index, INDArray layerInput) {
        return new Layer.Builder()
                .withInput(layerInput)
                .conf(layerWiseConfigurations.get(index))
                .build();
    }

    /**
     * Merges another network of the same depth into this one.
     *
     * <p>Every network is merged with its counterpart, after which the layer at the same
     * index receives the merged hidden bias and weights. The output layers are merged last.
     *
     * @param network   network to merge in
     * @param batchSize batch size passed through to each merge call
     * @throws IllegalArgumentException if the two networks differ in number of layers
     */
    public void merge(BaseMultiLayerNetwork network, int batchSize) {
        int layerCount = getnLayers();
        if (network.getnLayers() != layerCount) {
            throw new IllegalArgumentException("Unable to merge networks that are not of equal length");
        }
        for (int i = 0; i < layerCount; ++i) {
            NeuralNetwork mine = neuralNets[i];
            mine.merge(network.neuralNets[i], batchSize);
            // Keep the layer view in step with the merged network.
            getLayers()[i].setB(mine.gethBias());
            getLayers()[i].setW(mine.getW());
        }
        getOutputLayer().merge(network.getOutputLayer(), batchSize);
    }

    /**
     * @return the labels currently stored on the network
     */
    public INDArray getLabels() {
        return labels;
    }

    /**
     * Stores the input, first initializing the layers from it when the input is non-null and
     * no networks exist yet.
     *
     * @param input new input; may be {@code null}
     */
    public void setInput(INDArray input) {
        boolean needsInitialization = input != null && this.neuralNets == null;
        if (needsInitialization) {
            this.initializeLayers(input);
        }
        this.input = input;
    }

    /**
     * Creates the mask: a packed vector of ones in which every bias portion has been
     * replaced with zeros.
     */
    private void initMask() {
        // Start from all ones, sized like the packed parameters.
        setMask(Nd4j.ones((int)1, (int)pack().length()));
        List<Pair<INDArray, INDArray>> perLayer = unPack(getMask());
        for (Pair<INDArray, INDArray> layerMask : perLayer) {
            INDArray biasPart = layerMask.getSecond();
            layerMask.setSecond(Nd4j.zeros((int)biasPart.rows(), (int)biasPart.columns()));
        }
        setMask(pack(perLayer));
    }

    /**
     * @return the input currently stored on the network
     */
    public INDArray getInput() {
        return input;
    }

    /**
     * @return the internal network array itself, not a copy
     */
    public synchronized NeuralNetwork[] getNeuralNets() {
        return neuralNets;
    }

    /**
     * @return the value of the {@code forceNumEpochs} flag
     */
    public boolean forceNumIterations() {
        return forceNumEpochs;
    }

    /**
     * @return the internal hidden-layer-size array itself, not a copy
     */
    public int[] getHiddenLayerSizes() {
        return hiddenLayerSizes;
    }

    /**
     * Stores the given array as the hidden layer sizes; the array is not copied.
     *
     * @param hiddenLayerSizes sizes to store
     */
    public void setHiddenLayerSizes(int[] hiddenLayerSizes) {
        this.hiddenLayerSizes = hiddenLayerSizes;
    }

    /**
     * @return the internal weight-transform map, keyed by layer index (not a copy)
     */
    public Map<Integer, MatrixTransform> getWeightTransforms() {
        return weightTransforms;
    }

    /**
     * Stores the given labels on the network.
     *
     * @param labels labels to store; may be {@code null}
     */
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    /**
     * Sets the {@code forceNumEpochs} flag, which {@code forceNumIterations()} reports.
     *
     * @param forceNumEpochs new flag value
     */
    public void setForceNumEpochs(boolean forceNumEpochs) {
        this.forceNumEpochs = forceNumEpochs;
    }

    /**
     * @return the value of the {@code sampleFromHiddenActivations} flag
     */
    public boolean isSampleFromHiddenActivations() {
        return sampleFromHiddenActivations;
    }

    /**
     * Sets the {@code sampleFromHiddenActivations} flag.
     *
     * @param sampleFromHiddenActivations new flag value
     */
    public void setSampleFromHiddenActivations(boolean sampleFromHiddenActivations) {
        this.sampleFromHiddenActivations = sampleFromHiddenActivations;
    }

    /**
     * @return the internal hidden-bias-transform map, keyed by layer index (not a copy)
     */
    public Map<Integer, MatrixTransform> getHiddenBiasTransforms() {
        return hiddenBiasTransforms;
    }

    /**
     * @return the internal visible-bias-transform map, keyed by layer index (not a copy)
     */
    public Map<Integer, MatrixTransform> getVisibleBiasTransforms() {
        return visibleBiasTransforms;
    }

    /**
     * @return the length of the network array
     */
    public int getnLayers() {
        return neuralNets.length;
    }

    /**
     * Replaces the network array with the result of {@code createNetworkLayers(nLayers)}.
     *
     * @param nLayers layer count passed to {@code createNetworkLayers}
     */
    public void setnLayers(int nLayers) {
        neuralNets = createNetworkLayers(nLayers);
    }

    /**
     * Stores the given array as the network array ({@code neuralNets}); the array is not
     * copied.
     *
     * @param layers networks to store
     */
    public void setLayers(NeuralNetwork[] layers) {
        neuralNets = layers;
    }

    /**
     * @return the value of the {@code useGaussNewtonVectorProductBackProp} flag
     */
    public boolean isUseGaussNewtonVectorProductBackProp() {
        return useGaussNewtonVectorProductBackProp;
    }

    /**
     * Sets the {@code useGaussNewtonVectorProductBackProp} flag.
     *
     * @param useGaussNewtonVectorProductBackProp new flag value
     */
    public void setUseGaussNewtonVectorProductBackProp(boolean useGaussNewtonVectorProductBackProp) {
        this.useGaussNewtonVectorProductBackProp = useGaussNewtonVectorProductBackProp;
    }

    /**
     * @return the current mask, or {@code null} when none has been set or created yet
     */
    public INDArray getMask() {
        return mask;
    }

    /**
     * Stores the given mask.
     *
     * @param mask mask to store
     */
    public void setMask(INDArray mask) {
        this.mask = mask;
    }

    /**
     * Drops the stored input and clears the input of every network and of the layer at the
     * same index.
     *
     * <p>NOTE(review): only as many {@code layers} entries as there are networks are
     * cleared; if {@code layers} holds an extra output layer its input is left untouched —
     * confirm that is intended.
     */
    public void clearInput() {
        input = null;
        int index = 0;
        while (index < neuralNets.length) {
            neuralNets[index].clearInput();
            layers[index].setInput(null);
            index++;
        }
    }

    /**
     * @return the first element of {@code getLayers()}
     */
    public Layer getInputLayer() {
        return getLayers()[0];
    }

    /**
     * @return the last element of {@code getLayers()}, cast to {@link OutputLayer}
     */
    public OutputLayer getOutputLayer() {
        int lastIndex = getLayers().length - 1;
        return (OutputLayer)getLayers()[lastIndex];
    }

    /**
     * Installs a flattened parameter vector into the networks and the output layer.
     *
     * <p>For each network, and finally for the output layer, the weight and bias slices
     * given by {@code startIndexForLayer} are cut from {@code params} and reshaped to the
     * rows/columns of the arrays they replace.
     *
     * @param params flattened parameters
     */
    public void setParameters(INDArray params) {
        NeuralNetwork[] nets = this.getNeuralNets();
        for (int i = 0; i < nets.length; ++i) {
            ParamRange range = this.startIndexForLayer(i);
            INDArray weights = params.get(new NDArrayIndex[]{NDArrayIndex.interval((int)range.getwStart(), (int)range.getwEnd())});
            INDArray bias = params.get(new NDArrayIndex[]{NDArrayIndex.interval((int)range.getBiasStart(), (int)range.getBiasEnd())});
            NeuralNetwork net = nets[i];
            net.setW(weights.reshape(net.getW().rows(), net.getW().columns()));
            net.sethBias(bias.reshape(net.gethBias().rows(), net.gethBias().columns()));
        }

        // The output layer's range is the one that follows the last network.
        ParamRange outputRange = this.startIndexForLayer(nets.length);
        INDArray outputWeights = params.get(new NDArrayIndex[]{NDArrayIndex.interval((int)outputRange.getwStart(), (int)outputRange.getwEnd())});
        INDArray outputBias = params.get(new NDArrayIndex[]{NDArrayIndex.interval((int)outputRange.getBiasStart(), (int)outputRange.getBiasEnd())});
        OutputLayer outputLayer = this.getOutputLayer();
        outputLayer.setW(outputWeights.reshape(outputLayer.getW().rows(), outputLayer.getW().columns()));
        outputLayer.setB(outputBias.reshape(outputLayer.getB().rows(), outputLayer.getB().columns()));
    }

    /**
     * Computes where a layer's weights and bias sit inside the flattened parameter vector.
     *
     * <p>The start offset is the total weight and hidden-bias length of all networks before
     * {@code layer}. An index below the network count describes that network; any other
     * index describes the output layer.
     *
     * @param layer index of the layer
     * @return a range in which the bias starts exactly where the weights end
     */
    public ParamRange startIndexForLayer(int layer) {
        NeuralNetwork[] nets = getNeuralNets();
        int offset = 0;
        for (int i = 0; i < layer; ++i) {
            offset += nets[i].getW().length();
            offset += nets[i].gethBias().length();
        }

        int weightLength;
        int biasLength;
        if (layer < nets.length) {
            weightLength = nets[layer].getW().length();
            biasLength = nets[layer].gethBias().length();
        } else {
            weightLength = getOutputLayer().getW().length();
            biasLength = getOutputLayer().getB().length();
        }
        int weightEnd = offset + weightLength;
        return new ParamRange(offset, weightEnd, weightEnd, weightEnd + biasLength);
    }

    /**
     * No-op in this class: the body is empty, so both arguments are ignored.
     */
    @Override
    public void iterate(INDArray input, Object[] params) {
    }

    /**
     * No-op in this class: the body is empty, so all arguments are ignored.
     */
    @Override
    public void fit(INDArray examples, INDArray labels, Object[] params) {
    }

    /**
     * No-op in this class: the body is empty, so both arguments are ignored.
     */
    @Override
    public void fit(DataSet data, Object[] params) {
    }

    /**
     * No-op in this class: the body is empty, so all arguments are ignored.
     */
    @Override
    public void fit(INDArray examples, int[] labels, Object[] params) {
    }

    /**
     * No-op in this class: the body is empty, so all arguments are ignored.
     */
    @Override
    public void iterate(INDArray examples, int[] labels, Object[] params) {
    }

    /**
     * Fluent builder that reflectively instantiates a {@link BaseMultiLayerNetwork} subclass
     * through its empty constructor and copies the collected settings onto it.
     *
     * @param <E> network type returned by {@link #build()} and {@link #buildEmpty()}
     */
    public static class Builder<E extends BaseMultiLayerNetwork> {
        protected Class<? extends BaseMultiLayerNetwork> clazz;
        private int[] hiddenLayerSizes;
        private int nLayers;
        private INDArray input;
        private INDArray labels;
        protected Map<Integer, MatrixTransform> weightTransforms = new HashMap<Integer, MatrixTransform>();
        protected boolean backProp = true;
        protected boolean shouldForceEpochs = false;
        private Map<Integer, MatrixTransform> hiddenBiasTransforms = new HashMap<Integer, MatrixTransform>();
        private Map<Integer, MatrixTransform> visibleBiasTransforms = new HashMap<Integer, MatrixTransform>();
        private boolean useDropConnect;
        private boolean useGaussNewtonVectorProductBackProp;
        protected NeuralNetConfiguration conf;
        protected List<NeuralNetConfiguration> layerWiseConfiguration;
        protected boolean pretrain = true;

        /** Sets the pretrain flag copied onto the built network (defaults to {@code true}). */
        public Builder<E> pretrain(boolean pretrain) {
            this.pretrain = pretrain;
            return this;
        }

        /** Sets the per-layer configurations; the first one doubles as default when none is configured. */
        public Builder<E> layerWiseConfiguration(List<NeuralNetConfiguration> layerWiseConfiguration) {
            this.layerWiseConfiguration = layerWiseConfiguration;
            return this;
        }

        /** Sets the default configuration of the built network. */
        public Builder<E> configure(NeuralNetConfiguration conf) {
            this.conf = conf;
            return this;
        }

        /** Sets the {@code useGaussNewtonVectorProductBackProp} flag of the built network. */
        public Builder<E> useGaussNewtonVectorProductBackProp(boolean useGaussNewtonVectorProductBackProp) {
            this.useGaussNewtonVectorProductBackProp = useGaussNewtonVectorProductBackProp;
            return this;
        }

        /** Sets the drop-connect flag of the built network. */
        public Builder<E> useDropConnection(boolean useDropConnect) {
            this.useDropConnect = useDropConnect;
            return this;
        }

        /** Replaces the visible-bias transform map held by this builder. */
        public Builder<E> withVisibleBiasTransforms(Map<Integer, MatrixTransform> visibleBiasTransforms) {
            this.visibleBiasTransforms = visibleBiasTransforms;
            return this;
        }

        /** Replaces the hidden-bias transform map held by this builder. */
        public Builder<E> withHiddenBiasTransforms(Map<Integer, MatrixTransform> hiddenBiasTransforms) {
            this.hiddenBiasTransforms = hiddenBiasTransforms;
            return this;
        }

        /** Turns on the force-epochs flag of the built network. */
        public Builder<E> forceIterations() {
            shouldForceEpochs = true;
            return this;
        }

        /** Clears this builder's {@code backProp} flag. */
        public Builder<E> disableBackProp() {
            backProp = false;
            return this;
        }

        /** Registers a weight transform for one layer index. */
        public Builder<E> transformWeightsAt(int layer, MatrixTransform transform) {
            weightTransforms.put(layer, transform);
            return this;
        }

        /** Registers every weight transform of the given map. */
        public Builder<E> transformWeightsAt(Map<Integer, MatrixTransform> transforms) {
            weightTransforms.putAll(transforms);
            return this;
        }

        /** Sets the hidden layer sizes from boxed values; the layer count follows their number. */
        public Builder<E> hiddenLayerSizes(Integer ... hiddenLayerSizes) {
            int count = hiddenLayerSizes.length;
            this.hiddenLayerSizes = new int[count];
            this.nLayers = count;
            for (int i = 0; i < count; ++i) {
                this.hiddenLayerSizes[i] = hiddenLayerSizes[i];
            }
            return this;
        }

        /** Sets the hidden layer sizes (array stored as-is); the layer count follows its length. */
        public Builder<E> hiddenLayerSizes(int ... hiddenLayerSizes) {
            this.hiddenLayerSizes = hiddenLayerSizes;
            this.nLayers = hiddenLayerSizes.length;
            return this;
        }

        /** Sets the input handed to the built network. */
        public Builder<E> withInput(INDArray input) {
            this.input = input;
            return this;
        }

        /** Sets the labels handed to the built network. */
        public Builder<E> withLabels(INDArray labels) {
            this.labels = labels;
            return this;
        }

        /** Sets the concrete network class to instantiate. */
        public Builder<E> withClazz(Class<? extends BaseMultiLayerNetwork> clazz) {
            this.clazz = clazz;
            return this;
        }

        /**
         * Instantiates the configured class without applying any builder settings.
         *
         * @return a fresh instance created through the empty constructor
         * @throws RuntimeException wrapping any failure during instantiation
         */
        public E buildEmpty() {
            try {
                return (E)instantiate();
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Instantiates the configured class and copies every builder setting onto it.
         *
         * @return the configured network
         * @throws RuntimeException wrapping any failure, including the
         *         {@link IllegalStateException} raised when no hidden layer sizes were set
         */
        public E build() {
            try {
                BaseMultiLayerNetwork net = instantiate();
                net.setDefaultConfiguration(conf);
                net.useGaussNewtonVectorProductBackProp = useGaussNewtonVectorProductBackProp;
                net.setUseDropConnect(useDropConnect);
                net.setInput(input);
                net.setLabels(labels);
                net.setHiddenLayerSizes(hiddenLayerSizes);
                net.setnLayers(nLayers);
                net.setLayerWiseConfigurations(layerWiseConfiguration);
                net.neuralNets = new NeuralNetwork[nLayers];
                // Input and labels are applied a second time once the network array exists.
                net.setInput(input);
                net.setLabels(labels);
                net.setForceNumEpochs(shouldForceEpochs);
                net.getWeightTransforms().putAll(weightTransforms);
                net.getVisibleBiasTransforms().putAll(visibleBiasTransforms);
                net.getHiddenBiasTransforms().putAll(hiddenBiasTransforms);
                net.layerWiseConfigurations = layerWiseConfiguration;
                net.pretrain = pretrain;
                if (net.defaultConfiguration == null) {
                    net.defaultConfiguration = layerWiseConfiguration.get(0);
                }
                if (hiddenLayerSizes == null) {
                    throw new IllegalStateException("Unable to build network, no hidden layer sizes defined");
                }
                return (E)net;
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /** Creates an instance of {@code clazz} through its (possibly non-public) empty constructor. */
        private BaseMultiLayerNetwork instantiate() throws Exception {
            Constructor<?> emptyConstructor = Dl4jReflection.getEmptyConstructor(clazz);
            emptyConstructor.setAccessible(true);
            return (BaseMultiLayerNetwork)emptyConstructor.newInstance(new Object[0]);
        }
    }

    /**
     * Index range of one layer inside the flattened parameter vector: where its weights
     * start and end and where its bias starts and ends.
     *
     * <p>Instances are created by the enclosing class only (private constructor); all four
     * values can be changed afterwards through the setters.
     */
    public static class ParamRange
    implements Serializable {
        private int wStart;
        private int wEnd;
        private int biasStart;
        private int biasEnd;

        private ParamRange(int wStart, int wEnd, int biasStart, int biasEnd) {
            this.wStart = wStart;
            this.wEnd = wEnd;
            this.biasStart = biasStart;
            this.biasEnd = biasEnd;
        }

        /** @return index at which the weights start */
        public int getwStart() {
            return wStart;
        }

        /** @param wStart new start index of the weights */
        public void setwStart(int wStart) {
            this.wStart = wStart;
        }

        /** @return index at which the weights end */
        public int getwEnd() {
            return wEnd;
        }

        /** @param wEnd new end index of the weights */
        public void setwEnd(int wEnd) {
            this.wEnd = wEnd;
        }

        /** @return index at which the bias starts */
        public int getBiasStart() {
            return biasStart;
        }

        /** @param biasStart new start index of the bias */
        public void setBiasStart(int biasStart) {
            this.biasStart = biasStart;
        }

        /** @return index at which the bias ends */
        public int getBiasEnd() {
            return biasEnd;
        }

        /** @param biasEnd new end index of the bias */
        public void setBiasEnd(int biasEnd) {
            this.biasEnd = biasEnd;
        }
    }
}

