/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.HiddenLayer;
import org.deeplearning4j.nn.LogisticRegression;
import org.deeplearning4j.nn.NeuralNetwork;
import org.deeplearning4j.nn.Persistable;
import org.deeplearning4j.nn.activation.ActivationFunction;
import org.deeplearning4j.nn.activation.Sigmoid;
import org.deeplearning4j.optimize.MultiLayerNetworkOptimizer;
import org.deeplearning4j.transformation.MatrixTransform;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class BaseMultiLayerNetwork
implements Serializable,
Persistable {
    // Class-wide SLF4J logger.
    private static Logger log = LoggerFactory.getLogger(BaseMultiLayerNetwork.class);
    private static final long serialVersionUID = -5029161847383716484L;
    // Expected number of input columns; initializeLayers rejects input whose column count differs.
    private int nIns;
    // Size of each hidden layer, indexed by layer.
    private int[] hiddenLayerSizes;
    // Number of outputs passed to the logistic regression output layer.
    private int nOuts;
    // Number of hidden layers.
    private int nLayers;
    // Hidden layers used for the forward pass (activate / sample_h_given_v).
    private HiddenLayer[] sigmoidLayers;
    // Output layer sitting on top of the last hidden layer.
    private LogisticRegression logLayer;
    // Random generator handed to every layer that is created.
    private RandomGenerator rng;
    // Distribution; only stored and copied between networks in this class.
    private RealDistribution dist;
    // Momentum; only stored and exposed through accessors in this class.
    private double momentum = 0.1;
    // Training input and labels.
    private DoubleMatrix input;
    private DoubleMatrix labels;
    // Optimizer created by finetune(DoubleMatrix, double, int).
    private MultiLayerNetworkOptimizer optimizer;
    // Activation function installed on every hidden layer by initializeLayers.
    private ActivationFunction activation = new Sigmoid();
    // Flag only stored and exposed through accessors in this class.
    private boolean toDecode;
    // L2 coefficient: scales the weight penalty and is forwarded to the output layer.
    private double l2 = 0.01;
    // Flag only stored here; asDecoder sets it to false.
    private boolean shouldInit = true;
    // Explicit fan-in; a negative value means "unset" and fanIn() then falls back to 1 / nIns.
    private double fanIn = -1.0;
    // Forwarded to layers by initializeNetwork via setRenderEpochs.
    private int renderWeightsEveryNEpochs = -1;
    // Whether the l2 coefficient is applied.
    private boolean useRegularization = true;
    // Optional weight transforms keyed by layer index, applied by applyTransforms.
    private Map<Integer, MatrixTransform> weightTransforms = new HashMap<Integer, MatrixTransform>();
    // Flag only stored and exposed through accessors in this class.
    private boolean shouldBackProp = true;
    // When true, backProp runs exactly the requested number of epochs instead of stopping early.
    private boolean forceNumEpochs = false;
    // Sparsity; only stored and exposed through accessors in this class.
    private double sparsity = 0.0;
    // Optional per-column statistics; when non-null, predict and reconstruct use them to
    // divide by the sums, subtract the means and divide by the standard deviations.
    private DoubleMatrix columnSums;
    private DoubleMatrix columnMeans;
    private DoubleMatrix columnStds;
    // Public tuning values; only stored and copied in this class.
    public double learningRateUpdate = 0.95;
    // One NeuralNetwork per hidden layer; shares W and hidden bias with sigmoidLayers.
    public NeuralNetwork[] layers;
    public double errorTolerance = 1.0E-4;

    /**
     * Creates an unconfigured network. All configuration is expected to be
     * supplied afterwards through the setters (this is what {@link Builder} does).
     */
    public BaseMultiLayerNetwork() {
    }

    /**
     * Creates a network without training data; layers are initialized later,
     * once input is supplied.
     *
     * @param n_ins number of input columns
     * @param hidden_layer_sizes size of each hidden layer
     * @param n_outs number of outputs
     * @param n_layers number of hidden layers; must equal {@code hidden_layer_sizes.length}
     * @param rng random generator, or {@code null} for a default seeded generator
     */
    public BaseMultiLayerNetwork(int n_ins, int[] hidden_layer_sizes, int n_outs, int n_layers, RandomGenerator rng) {
        this(n_ins, hidden_layer_sizes, n_outs, n_layers, rng, null, null);
    }

    /**
     * Creates a network and, when {@code input} is given, initializes its layers.
     *
     * @param nIn number of input columns
     * @param hiddenLayerSizes size of each hidden layer
     * @param nOuts number of outputs
     * @param nLayers number of hidden layers; must equal {@code hiddenLayerSizes.length}
     * @param rng random generator, or {@code null} for {@code new MersenneTwister(123)}
     * @param input training input, may be {@code null}; copied when present
     * @param labels training labels, may be {@code null}; copied when present
     * @throws IllegalArgumentException if {@code hiddenLayerSizes.length != nLayers}
     */
    public BaseMultiLayerNetwork(int nIn, int[] hiddenLayerSizes, int nOuts, int nLayers, RandomGenerator rng, DoubleMatrix input, DoubleMatrix labels) {
        if (hiddenLayerSizes.length != nLayers) {
            throw new IllegalArgumentException("The number of hidden layer sizes must be equivalent to the nLayers argument which is a value of " + nLayers);
        }
        this.nIns = nIn;
        this.hiddenLayerSizes = hiddenLayerSizes;
        // Both matrices are optional: the five-argument constructor passes null for each,
        // so they must not be dereferenced unconditionally.
        this.input = input == null ? null : input.dup();
        this.labels = labels == null ? null : labels.dup();
        this.nOuts = nOuts;
        this.nLayers = nLayers;
        this.sigmoidLayers = new HiddenLayer[nLayers];
        this.layers = this.createNetworkLayers(nLayers);
        this.rng = rng == null ? new MersenneTwister(123) : rng;
        if (input != null) {
            this.initializeLayers(input);
        }
    }

    /**
     * Returns the effective fan-in.
     *
     * @return {@code 1 / nIns} while the configured fan-in is negative (unset),
     *         otherwise the configured value
     */
    public double fanIn() {
        return this.fanIn < 0.0 ? 1.0 / (double)this.nIns : this.fanIn;
    }

    /**
     * Verifies that the hidden layers, the paired network layers and the output
     * layer agree on their dimensions.
     *
     * @throws IllegalStateException if consecutive layers do not line up, or the
     *         last hidden layer's output size differs from the output layer's input size
     */
    private void dimensionCheck() {
        final int lastIndex = this.nLayers - 1;
        for (int idx = 0; idx < this.nLayers; ++idx) {
            HiddenLayer hidden = this.sigmoidLayers[idx];
            NeuralNetwork net = this.layers[idx];
            // Each hidden layer must mirror the shape of its paired network layer.
            hidden.getW().assertSameSize(net.getW());
            hidden.getB().assertSameSize(net.gethBias());
            if (idx < lastIndex) {
                HiddenLayer nextHidden = this.sigmoidLayers[idx + 1];
                NeuralNetwork nextNet = this.layers[idx + 1];
                if (nextHidden.getnIn() != hidden.getnOut()) {
                    throw new IllegalStateException("Invalid structure: hidden layer in for " + (idx + 1) + " not equal to number of ins " + idx);
                }
                if (net.getnHidden() != nextNet.getnVisible()) {
                    throw new IllegalStateException("Invalid structure: network hidden for " + (idx + 1) + " not equal to number of visible " + idx);
                }
            }
        }
        HiddenLayer finalHidden = this.sigmoidLayers[this.sigmoidLayers.length - 1];
        if (finalHidden.getnOut() != this.logLayer.getnIn()) {
            throw new IllegalStateException("Number of outputs for final hidden layer not equal to the number of logistic input units for output layer");
        }
    }

    /**
     * Configures this network from {@code network} with the layer order reversed
     * and the input/output sizes swapped.
     *
     * <p>Layers are cloned from the source in reverse order and given the source
     * layer's random generator; {@code shouldInit} is set to {@code false}.
     *
     * @param network the network to take layers and settings from
     */
    public void asDecoder(BaseMultiLayerNetwork network) {
        // NOTE(review): the returned array is discarded; the call only matters if the
        // subclass implementation has side effects - confirm.
        this.createNetworkLayers(network.nLayers + 1);
        this.layers = new NeuralNetwork[network.nLayers];
        // NOTE(review): sigmoidLayers is allocated but never filled here, and logLayer
        // is not assigned - presumably set up elsewhere before use; verify.
        this.sigmoidLayers = new HiddenLayer[network.nLayers];
        this.hiddenLayerSizes = new int[network.nLayers];
        // Inputs and outputs trade places.
        this.nIns = network.nOuts;
        this.nOuts = network.nIns;
        this.nLayers = network.nLayers;
        this.dist = network.dist;
        int count = 0;
        // Walk the source from its last layer to its first, filling this network front to back.
        for (int i = network.nLayers - 1; i >= 0; --i) {
            this.layers[count] = network.layers[i].clone();
            this.layers[count].setRng(network.layers[i].getRng());
            this.hiddenLayerSizes[count] = network.hiddenLayerSizes[i];
            ++count;
        }
        this.rng = network.rng;
        this.shouldInit = false;
    }

    /**
     * Builds every hidden layer, its paired network layer and the output layer
     * for the given input, then validates dimensions and applies weight transforms.
     *
     * <p>The first hidden layer receives {@code input}; each later layer receives
     * a sample ({@code sample_h_given_v}) from the previous hidden layer. A copy of
     * {@code input} is stored as this network's input.
     *
     * @param input the training input; its column count must equal the number of inputs
     * @throws IllegalArgumentException if {@code input} is {@code null}, has the wrong
     *         number of columns, or any hidden layer size is below 1
     * @throws IllegalStateException if the resulting layer dimensions are inconsistent
     */
    public void initializeLayers(DoubleMatrix input) {
        if (input == null) {
            throw new IllegalArgumentException("Unable to initialize layers with empty input");
        }
        if (input.columns != this.nIns) {
            throw new IllegalArgumentException(String.format("Unable to train on number of inputs; columns should be equal to number of inputs. Number of inputs was %d while number of columns was %d", this.nIns, input.columns));
        }
        if (this.layers == null) {
            this.layers = new NeuralNetwork[this.nLayers];
        }
        // Allocate the hidden-layer array on demand as well, so a network configured
        // only through setters does not fail on the first assignment below.
        if (this.sigmoidLayers == null) {
            this.sigmoidLayers = new HiddenLayer[this.nLayers];
        }
        for (int size : this.hiddenLayerSizes) {
            if (size < 1) {
                throw new IllegalArgumentException("All hidden layer sizes must be >= 1");
            }
        }
        this.input = input.dup();
        DoubleMatrix layerInput = input;
        for (int i = 0; i < this.nLayers; ++i) {
            int inputSize = i == 0 ? this.nIns : this.hiddenLayerSizes[i - 1];
            if (i > 0) {
                // Deeper layers are fed a sample drawn from the layer below.
                layerInput = this.sigmoidLayers[i - 1].sample_h_given_v();
            }
            this.sigmoidLayers[i] = new HiddenLayer(inputSize, this.hiddenLayerSizes[i], null, null, this.rng, layerInput);
            this.sigmoidLayers[i].setActivationFunction(this.activation);
            // The network layer is created with the hidden layer's W and bias.
            this.layers[i] = this.createLayer(layerInput, inputSize, this.hiddenLayerSizes[i], this.sigmoidLayers[i].getW(), this.sigmoidLayers[i].getB(), null, this.rng, i);
        }
        this.logLayer = new LogisticRegression(layerInput, this.hiddenLayerSizes[this.nLayers - 1], this.nOuts);
        this.logLayer.setUseRegularization(this.isUseRegularization());
        this.logLayer.setL2(this.getL2());
        this.dimensionCheck();
        this.applyTransforms();
    }

    /** @return the expected number of input columns */
    public synchronized int getnIns() {
        return this.nIns;
    }

    /** @param nIns the expected number of input columns */
    public synchronized void setnIns(int nIns) {
        this.nIns = nIns;
    }

    /** @return the number of outputs */
    public synchronized int getnOuts() {
        return this.nOuts;
    }

    /** @param nOuts the number of outputs */
    public synchronized void setnOuts(int nOuts) {
        this.nOuts = nOuts;
    }

    /** @return the number of hidden layers */
    public synchronized int getnLayers() {
        return this.nLayers;
    }

    /** @param nLayers the number of hidden layers */
    public synchronized void setnLayers(int nLayers) {
        this.nLayers = nLayers;
    }

    /** @return the stored momentum value */
    public synchronized double getMomentum() {
        return this.momentum;
    }

    /** @param momentum the momentum value to store */
    public synchronized void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    /** @return the L2 coefficient */
    public synchronized double getL2() {
        return this.l2;
    }

    /** @param l2 the L2 coefficient */
    public synchronized void setL2(double l2) {
        this.l2 = l2;
    }

    /** @return whether the L2 coefficient is applied */
    public synchronized boolean isUseRegularization() {
        return this.useRegularization;
    }

    /** @param useRegularization whether the L2 coefficient is applied */
    public synchronized void setUseRegularization(boolean useRegularization) {
        this.useRegularization = useRegularization;
    }

    /** @param sigmoidLayers the hidden-layer array to use; stored by reference */
    public synchronized void setSigmoidLayers(HiddenLayer[] sigmoidLayers) {
        this.sigmoidLayers = sigmoidLayers;
    }

    /** @param logLayer the output layer to use; stored by reference */
    public synchronized void setLogLayer(LogisticRegression logLayer) {
        this.logLayer = logLayer;
    }

    /** @param shouldBackProp the back-propagation flag to store */
    public synchronized void setShouldBackProp(boolean shouldBackProp) {
        this.shouldBackProp = shouldBackProp;
    }

    /** @param layers the network-layer array to use; stored by reference */
    public synchronized void setLayers(NeuralNetwork[] layers) {
        this.layers = layers;
    }

    /**
     * Copies this network's fan-in and weight-rendering interval onto a layer.
     *
     * @param network the layer to configure
     */
    protected void initializeNetwork(NeuralNetwork network) {
        // Passes the raw field, which may still be the negative "unset" value,
        // rather than the fallback computed by fanIn().
        network.setFanIn(this.fanIn);
        network.setRenderEpochs(this.renderWeightsEveryNEpochs);
    }

    /**
     * Fine-tunes the network against the labels already stored on it.
     *
     * @param lr learning rate
     * @param epochs number of epochs passed to the optimizer
     */
    public void finetune(double lr, int epochs) {
        this.finetune(this.labels, lr, epochs);
    }

    /** @return the stored labels (the internal reference, not a copy) */
    public synchronized DoubleMatrix getLabels() {
        return this.labels;
    }

    /** @return the output layer */
    public synchronized LogisticRegression getLogLayer() {
        return this.logLayer;
    }

    /** @param input the input matrix to store; kept by reference */
    public synchronized void setInput(DoubleMatrix input) {
        this.input = input;
    }

    /** @return the stored input (the internal reference, not a copy) */
    public synchronized DoubleMatrix getInput() {
        return this.input;
    }

    /** @return the hidden-layer array (the internal array, not a copy) */
    public synchronized HiddenLayer[] getSigmoidLayers() {
        return this.sigmoidLayers;
    }

    /** @return the network-layer array (the internal array, not a copy) */
    public synchronized NeuralNetwork[] getLayers() {
        return this.layers;
    }

    /**
     * Runs a forward pass and collects the value seen at every stage.
     *
     * <p>The returned list holds, in order: the given input, the activation of
     * each hidden layer, and finally the output layer's prediction. Each network
     * layer also has its input set to the value fed into that stage.
     *
     * @param input the matrix to propagate
     * @return the per-stage values, {@code nLayers + 2} entries
     * @throws IllegalStateException if this network has no stored input
     */
    public synchronized List<DoubleMatrix> feedForward(DoubleMatrix input) {
        // The guard looks at the stored input field, not at the argument.
        if (this.input == null) {
            throw new IllegalStateException("Unable to perform feed forward; no input found");
        }
        List<DoubleMatrix> stages = new ArrayList<DoubleMatrix>();
        stages.add(input);
        DoubleMatrix signal = input;
        int layer = 0;
        while (layer < this.nLayers) {
            this.getLayers()[layer].setInput(signal);
            signal = this.getSigmoidLayers()[layer].activate(signal);
            stages.add(signal);
            ++layer;
        }
        stages.add(this.getLogLayer().predict(signal));
        return stages;
    }

    /**
     * Computes a gradient and a delta for every stage of the network and appends
     * them, as (gradient, delta) pairs ordered from stage 0 upwards, to {@code deltaRet}.
     *
     * <p>Stages are indexed like the list returned by {@code feedForward}: index 0
     * is the input, indices 1..nLayers are the hidden activations and index
     * nLayers + 1 is the output.
     *
     * @param deltaRet list that receives {@code nLayers + 2} pairs
     */
    private synchronized void computeDeltas(List<Pair<DoubleMatrix, DoubleMatrix>> deltaRet) {
        int i;
        DoubleMatrix[] gradients = new DoubleMatrix[this.nLayers + 2];
        DoubleMatrix[] deltas = new DoubleMatrix[this.nLayers + 2];
        // The first hidden layer's activation function supplies the derivative for every stage.
        ActivationFunction derivative = this.sigmoidLayers[0].getActivationFunction();
        DoubleMatrix delta = null;
        List<DoubleMatrix> activations = this.feedForward(this.getInput());
        // Weight matrices of all layers, followed by the output layer's weights.
        ArrayList<DoubleMatrix> weights = new ArrayList<DoubleMatrix>();
        for (int j = 0; j < this.getLayers().length; ++j) {
            weights.add(this.getLayers()[j].getW());
        }
        weights.add(this.getLogLayer().getW());
        // NOTE(review): this local shadows the labels field and holds the network's own
        // prediction, so the output delta below compares a prediction with the output
        // activation rather than with the training labels - confirm this is intended.
        DoubleMatrix labels = this.predict(this.getInput());
        // Walk backwards from the output stage to the input stage.
        for (i = this.nLayers + 1; i >= 0; --i) {
            DoubleMatrix error;
            if (i >= this.nLayers + 1) {
                // Output stage: negated (labels - output), scaled by the derivative.
                DoubleMatrix initialDelta;
                DoubleMatrix z = activations.get(i);
                delta = labels.sub(z).neg();
                deltas[i] = initialDelta = delta.mul(derivative.applyDerivative(z));
                continue;
            }
            // Earlier stages: push the next stage's delta back through the transposed weights.
            delta = deltas[i + 1];
            DoubleMatrix w = ((DoubleMatrix)weights.get(i)).transpose();
            // z and a both refer to the same stored value for this stage.
            DoubleMatrix z = activations.get(i);
            DoubleMatrix a = activations.get(i);
            deltas[i] = error = delta.mmul(w);
            deltas[i] = error = error.mul(derivative.applyDerivative(z));
            DoubleMatrix lastLayerDelta = deltas[i + 1].transpose();
            DoubleMatrix newGradient = lastLayerDelta.mmul(a);
            // Averaged over the number of rows in the stored input.
            gradients[i] = newGradient.div((double)this.getInput().rows);
        }
        // gradients[nLayers + 1] is never assigned above, so the last pair carries a null gradient.
        for (i = 0; i < gradients.length; ++i) {
            deltaRet.add(new Pair<DoubleMatrix, DoubleMatrix>(gradients[i], deltas[i]));
        }
    }

    /**
     * Creates a new instance of this network's concrete class and copies this
     * network's state into it via {@code update}.
     *
     * @return the copy
     * @throws RuntimeException if the concrete class cannot be instantiated reflectively
     */
    @Override
    protected BaseMultiLayerNetwork clone() {
        // A typed builder lets the result be returned without an Object round-trip.
        BaseMultiLayerNetwork copy = new Builder<BaseMultiLayerNetwork>().withClazz(this.getClass()).buildEmpty();
        copy.update(this);
        return copy;
    }

    /**
     * Runs back-propagation steps.
     *
     * <p>With {@code forceNumEpochs} set, exactly {@code epochs} steps are run and
     * each step's stop signal is ignored. Otherwise steps continue until
     * {@code backPropStep} returns {@code false}; {@code epochs} is not consulted.
     *
     * @param lr learning rate
     * @param epochs number of steps when epochs are forced
     */
    public void backProp(double lr, int epochs) {
        BaseMultiLayerNetwork snapshot = this.clone();
        Double previousScore = null;
        if (!this.forceNumEpochs) {
            // Open-ended mode: keep stepping while the step reports progress.
            for (int step = 0; this.backPropStep(previousScore, snapshot, lr, step); ++step) {
                previousScore = this.negativeLogLikelihood();
            }
            return;
        }
        for (int epoch = 0; epoch < epochs; ++epoch) {
            this.backPropStep(previousScore, snapshot, lr, epoch);
            previousScore = this.negativeLogLikelihood();
        }
    }

    /**
     * Performs one back-propagation step, or decides that training should stop.
     *
     * <p>When {@code lastEntropy} is given, the current score is compared with it:
     * an unchanged score stops training; a worse, NaN or infinite score restores
     * this network from {@code revert} and stops; a better score refreshes
     * {@code revert} with the current state before stepping.
     *
     * @param lastEntropy score after the previous step, or {@code null} on the first step
     * @param revert snapshot to fall back to; updated in place when the score improves
     * @param lr learning rate
     * @param epoch step counter, used for logging only
     * @return {@code true} if a step was taken, {@code false} if training should stop
     */
    protected boolean backPropStep(Double lastEntropy, BaseMultiLayerNetwork revert, double lr, int epoch) {
        double error = this.negativeLogLikelihood();
        if (lastEntropy != null) {
            double previous = lastEntropy;
            if (error == previous) {
                log.info("Converged; no more stepping appears to do anything");
                return false;
            }
            if (error > previous || Double.isNaN(error) || Double.isInfinite(error)) {
                log.info("Error greater than previous; found global minima; converging");
                this.update(revert);
                return false;
            }
            if (error < previous) {
                // Refresh the caller's snapshot in place. Reassigning the parameter would
                // only change the local reference, leaving the caller with the state
                // captured before training started.
                if (revert != null) {
                    revert.update(this);
                }
                log.info("Found better error on epoch " + epoch + " " + error);
            }
        }
        List<Pair<DoubleMatrix, DoubleMatrix>> deltas = new ArrayList<Pair<DoubleMatrix, DoubleMatrix>>();
        this.computeDeltas(deltas);
        for (int l = 0; l < this.nLayers; ++l) {
            // NOTE(review): the gradient is divided by the row count twice and scaled by lr
            // twice, and regularization multiplies (rather than adds) the scaled weights.
            // Kept exactly as before; confirm against the intended update rule.
            DoubleMatrix add = deltas.get(l).getFirst().div((double)this.input.rows).mul(lr);
            add.divi((double)this.input.rows);
            if (this.useRegularization) {
                add.muli(this.layers[l].getW().mul(this.l2));
            }
            this.layers[l].setW(this.layers[l].getW().add(add.mul(lr)));
            // Keep the hidden layer's weights identical to the network layer's.
            this.sigmoidLayers[l].setW(this.layers[l].getW());
            DoubleMatrix deltaColumnSums = deltas.get(l + 1).getSecond().columnSums();
            deltaColumnSums.divi((double)this.input.rows);
            this.layers[l].gethBias().addi(deltaColumnSums.mul(lr));
            this.sigmoidLayers[l].setB(this.getLayers()[l].gethBias());
        }
        // NOTE(review): the output layer update is not scaled by lr - confirm.
        this.logLayer.getW().addi(deltas.get(this.nLayers).getFirst());
        return true;
    }

    /**
     * Fine-tunes the network with a freshly created optimizer.
     *
     * @param labels labels to train against; when {@code null} the labels already
     *        stored on this network are used, otherwise they replace the stored ones
     * @param lr learning rate
     * @param epochs number of epochs passed to the optimizer
     */
    public void finetune(DoubleMatrix labels, double lr, int epochs) {
        if (labels != null) {
            this.labels = labels;
        }
        // A new optimizer replaces any previously stored one on every call.
        this.optimizer = new MultiLayerNetworkOptimizer(this, lr);
        this.optimizer.optimize(this.labels, lr, epochs);
    }

    /**
     * Predicts outputs for {@code x}.
     *
     * <p>The input is first normalized with whichever column statistics are set,
     * then passed through every hidden layer and the output layer. If this network
     * has no stored input yet, its layers are initialized from the normalized input.
     * The caller's matrix is left untouched.
     *
     * @param x the examples to predict, one per row
     * @return the output layer's prediction
     */
    public DoubleMatrix predict(DoubleMatrix x) {
        DoubleMatrix normalized = this.normalizeInput(x);
        if (this.input == null) {
            this.initializeLayers(normalized);
        }
        DoubleMatrix signal = normalized;
        for (int i = 0; i < this.nLayers; ++i) {
            signal = this.sigmoidLayers[i].activate(signal);
        }
        return this.logLayer.predict(signal);
    }

    /**
     * Propagates {@code x} through the first {@code layerNum} hidden layers.
     *
     * <p>The input is first normalized with whichever column statistics are set;
     * the caller's matrix is left untouched.
     *
     * @param x the examples to propagate, one per row
     * @param layerNum number of hidden layers to apply, from 0 to the layer count
     * @return the activation after {@code layerNum} hidden layers
     * @throws IllegalArgumentException if {@code layerNum} is negative or exceeds the layer count
     */
    public DoubleMatrix reconstruct(DoubleMatrix x, int layerNum) {
        if (layerNum > this.nLayers || layerNum < 0) {
            throw new IllegalArgumentException("Layer number " + layerNum + " does not exist");
        }
        DoubleMatrix signal = this.normalizeInput(x);
        for (int i = 0; i < layerNum; ++i) {
            signal = this.sigmoidLayers[i].activate(signal);
        }
        return signal;
    }

    /**
     * Propagates {@code x} through every hidden layer.
     *
     * @param x the examples to propagate, one per row
     * @return the activation of the last hidden layer
     */
    public DoubleMatrix reconstruct(DoubleMatrix x) {
        return this.reconstruct(x, this.sigmoidLayers.length);
    }

    /**
     * Applies the configured column statistics to a copy of {@code x}: divide by the
     * column sums, subtract the column means, divide by the column standard deviations,
     * each only when set. Returns {@code x} itself when none is configured.
     */
    private DoubleMatrix normalizeInput(DoubleMatrix x) {
        if (this.columnSums == null && this.columnMeans == null && this.columnStds == null) {
            return x;
        }
        // Work on a copy so repeated calls with the same matrix do not compound the scaling.
        DoubleMatrix normalized = x.dup();
        for (int i = 0; i < normalized.columns; ++i) {
            DoubleMatrix col = normalized.getColumn(i);
            if (this.columnSums != null) {
                col = col.div(this.columnSums.get(0, i));
            }
            if (this.columnMeans != null) {
                col = col.sub(this.columnMeans.get(0, i));
            }
            if (this.columnStds != null) {
                col = col.div(this.columnStds.get(0, i));
            }
            normalized.putColumn(i, col);
        }
        return normalized;
    }

    /**
     * Serializes this network to {@code os} using Java object serialization.
     *
     * <p>The stream is flushed but not closed; closing it remains the caller's job.
     *
     * @param os destination stream
     * @throws RuntimeException wrapping any {@link IOException} raised while writing
     */
    @Override
    public void write(OutputStream os) {
        try {
            ObjectOutputStream oos = new ObjectOutputStream(os);
            oos.writeObject(this);
            // Push everything through to the underlying stream; the wrapper is never
            // closed here, so nothing else would flush it.
            oos.flush();
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads a serialized network from {@code is} and copies its state into this
     * instance via {@code update}.
     *
     * <p>The stream is not closed here.
     *
     * @param is source stream containing a Java-serialized network
     * @throws RuntimeException wrapping any exception raised while reading or copying
     */
    @Override
    public void load(InputStream is) {
        try {
            // NOTE(review): native Java deserialization; only safe for trusted streams.
            ObjectInputStream ois = new ObjectInputStream(is);
            BaseMultiLayerNetwork loaded = (BaseMultiLayerNetwork)ois.readObject();
            this.update(loaded);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads a Java-serialized network from {@code is} and returns it.
     *
     * <p>The stream is not closed here. Uses native Java deserialization, so the
     * stream must come from a trusted source.
     *
     * @param is source stream
     * @return the deserialized network
     * @throws RuntimeException wrapping any exception raised while reading
     */
    public static BaseMultiLayerNetwork loadFromFile(InputStream is) {
        try {
            ObjectInputStream objectIn = new ObjectInputStream(is);
            log.info("Loading network model...");
            return (BaseMultiLayerNetwork)objectIn.readObject();
        }
        catch (Exception cause) {
            throw new RuntimeException(cause);
        }
    }

    /**
     * Overwrites this network's state with the state of {@code network}.
     *
     * <p>Network layers, hidden layers and the output layer are cloned; every
     * other copied field (matrices, arrays, maps, generator) is shared by reference.
     * The optimizer, {@code shouldInit} and {@code renderWeightsEveryNEpochs} are
     * not copied.
     *
     * @param network the network to copy from; its layers and output layer must be set
     */
    protected void update(BaseMultiLayerNetwork network) {
        int i;
        // Clone each layer so later training on one network does not alter the other.
        this.layers = new NeuralNetwork[network.layers.length];
        for (i = 0; i < this.layers.length; ++i) {
            this.layers[i] = network.layers[i].clone();
        }
        this.hiddenLayerSizes = network.hiddenLayerSizes;
        this.logLayer = network.logLayer.clone();
        this.nIns = network.nIns;
        this.nLayers = network.nLayers;
        this.nOuts = network.nOuts;
        this.rng = network.rng;
        this.dist = network.dist;
        this.activation = network.activation;
        this.useRegularization = network.useRegularization;
        this.columnMeans = network.columnMeans;
        this.columnStds = network.columnStds;
        this.columnSums = network.columnSums;
        this.errorTolerance = network.errorTolerance;
        this.forceNumEpochs = network.forceNumEpochs;
        this.input = network.input;
        this.l2 = network.l2;
        this.fanIn = network.fanIn;
        this.labels = network.labels;
        this.momentum = network.momentum;
        this.learningRateUpdate = network.learningRateUpdate;
        this.shouldBackProp = network.shouldBackProp;
        this.weightTransforms = network.weightTransforms;
        this.sparsity = network.sparsity;
        this.toDecode = network.toDecode;
        this.sigmoidLayers = new HiddenLayer[network.sigmoidLayers.length];
        for (i = 0; i < this.sigmoidLayers.length; ++i) {
            this.sigmoidLayers[i] = network.sigmoidLayers[i].clone();
        }
    }

    /**
     * Computes the score used by back-propagation to judge progress.
     *
     * <p>As written, the value is the sum over all layers (hidden layers plus the
     * output layer) of half the sum of squared weights, each term multiplied by the
     * L2 coefficient when regularization is enabled.
     *
     * @return the accumulated score
     */
    public synchronized double negativeLogLikelihood() {
        double total = 0.0;
        for (int i = 0; i < this.nLayers; ++i) {
            double halfSquares = MatrixFunctions.pow((DoubleMatrix)this.layers[i].getW(), (double)2.0).sum() / 2.0;
            total += this.useRegularization ? halfSquares * this.l2 : halfSquares;
        }
        double outputHalfSquares = MatrixFunctions.pow((DoubleMatrix)this.logLayer.getW(), (double)2.0).sum() / 2.0;
        if (this.useRegularization) {
            total += outputHalfSquares * this.l2;
        } else {
            total += outputHalfSquares;
        }
        return total;
    }

    /**
     * Trains the network; implemented by concrete subclasses.
     *
     * <p>NOTE(review): parameter names were lost in decompilation - presumably
     * (input, labels, extra training parameters); confirm against the subclasses.
     */
    public abstract void trainNetwork(DoubleMatrix var1, DoubleMatrix var2, Object[] var3);

    /**
     * Applies each registered weight transform to the weights of the layer at the
     * matching index, replacing that layer's weights with the result.
     *
     * @throws IllegalStateException if there are no layers
     */
    protected void applyTransforms() {
        if (this.layers == null || this.layers.length < 1) {
            throw new IllegalStateException("Layers not initialized");
        }
        for (int idx = 0; idx < this.layers.length; ++idx) {
            if (this.weightTransforms.containsKey(idx)) {
                NeuralNetwork layer = this.layers[idx];
                MatrixTransform transform = this.weightTransforms.get(idx);
                layer.setW((DoubleMatrix)transform.apply(layer.getW()));
            }
        }
    }

    /** @return the stored back-propagation flag */
    public boolean isShouldBackProp() {
        return this.shouldBackProp;
    }

    /**
     * Creates the network layer for one position; implemented by concrete subclasses.
     *
     * <p>As called from {@code initializeLayers}, the arguments are: the layer's
     * input, its input size, its hidden size, the hidden layer's weights, the hidden
     * layer's bias, {@code null}, the random generator and the layer index.
     */
    public abstract NeuralNetwork createLayer(DoubleMatrix var1, int var2, int var3, DoubleMatrix var4, DoubleMatrix var5, DoubleMatrix var6, RandomGenerator var7, int var8);

    /**
     * Creates the array that holds the network layers; implemented by concrete
     * subclasses. Callers in this class pass a layer count.
     */
    public abstract NeuralNetwork[] createNetworkLayers(int var1);

    /**
     * Merges another network of the same depth into this one, layer by layer,
     * and then merges the output layers.
     *
     * <p>After each layer merge the matching hidden layer is given the merged
     * layer's bias and weights.
     *
     * @param network the network to merge in
     * @param batchSize value forwarded to each layer's merge
     * @throws IllegalArgumentException if the two networks differ in layer count
     */
    public void merge(BaseMultiLayerNetwork network, int batchSize) {
        if (network.nLayers != this.nLayers) {
            throw new IllegalArgumentException("Unable to merge networks that are not of equal length");
        }
        for (int idx = 0; idx < this.nLayers; ++idx) {
            NeuralNetwork mine = this.layers[idx];
            NeuralNetwork theirs = network.layers[idx];
            mine.merge(theirs, batchSize);
            HiddenLayer hidden = this.getSigmoidLayers()[idx];
            hidden.setB(mine.gethBias());
            hidden.setW(mine.getW());
        }
        LogisticRegression output = this.getLogLayer();
        output.merge(network.logLayer, batchSize);
    }

    /**
     * Fills this network with clones of {@code network}'s layers taken in reverse
     * order, and creates a new output layer sized to the source network's input columns.
     *
     * @param network the network to take layers from; its input must be set
     */
    public void encode(BaseMultiLayerNetwork network) {
        // NOTE(review): the returned array is discarded - confirm the call is needed.
        this.createNetworkLayers(network.nLayers);
        this.layers = new NeuralNetwork[network.nLayers];
        // NOTE(review): sizing and the loop below use this.nLayers, while the layers
        // array uses network.nLayers; these only agree when both networks have the
        // same depth - confirm.
        this.hiddenLayerSizes = new int[this.nLayers];
        int count = 0;
        // NOTE(review): the loop stops before index 0, so the source's first layer is
        // never copied, and sigmoidLayers is assumed to be allocated already - verify.
        for (int i = this.nLayers - 1; i > 0; --i) {
            NeuralNetwork n = network.layers[i].clone();
            HiddenLayer l = network.sigmoidLayers[i].clone();
            this.layers[count] = n;
            this.sigmoidLayers[count] = l;
            this.hiddenLayerSizes[count] = network.hiddenLayerSizes[i];
            ++count;
        }
        this.logLayer = new LogisticRegression(this.hiddenLayerSizes[this.nLayers - 1], network.input.columns);
    }

    /** @return whether backProp runs exactly the requested number of epochs */
    public boolean isForceNumEpochs() {
        return this.forceNumEpochs;
    }

    /** @return the column sums used to scale input columns, or {@code null} if unset */
    public DoubleMatrix getColumnSums() {
        return this.columnSums;
    }

    /** @param columnSums per-column divisors applied to input; {@code null} disables the step */
    public void setColumnSums(DoubleMatrix columnSums) {
        this.columnSums = columnSums;
    }

    /** @return the hidden layer sizes (the internal array, not a copy) */
    public synchronized int[] getHiddenLayerSizes() {
        return this.hiddenLayerSizes;
    }

    /** @param hiddenLayerSizes the hidden layer sizes; stored by reference */
    public synchronized void setHiddenLayerSizes(int[] hiddenLayerSizes) {
        this.hiddenLayerSizes = hiddenLayerSizes;
    }

    /** @return the random generator */
    public synchronized RandomGenerator getRng() {
        return this.rng;
    }

    /** @param rng the random generator handed to newly created layers */
    public synchronized void setRng(RandomGenerator rng) {
        this.rng = rng;
    }

    /** @return the stored distribution */
    public synchronized RealDistribution getDist() {
        return this.dist;
    }

    /** @param dist the distribution to store */
    public synchronized void setDist(RealDistribution dist) {
        this.dist = dist;
    }

    /** @return the optimizer created by the last finetune call, or the one set explicitly */
    public synchronized MultiLayerNetworkOptimizer getOptimizer() {
        return this.optimizer;
    }

    /** @param optimizer the optimizer to store */
    public synchronized void setOptimizer(MultiLayerNetworkOptimizer optimizer) {
        this.optimizer = optimizer;
    }

    /** @return the activation function installed on hidden layers at initialization */
    public synchronized ActivationFunction getActivation() {
        return this.activation;
    }

    /** @param activation the activation function to install on hidden layers at initialization */
    public synchronized void setActivation(ActivationFunction activation) {
        this.activation = activation;
    }

    /** @return the stored decode flag */
    public synchronized boolean isToDecode() {
        return this.toDecode;
    }

    /** @param toDecode the decode flag to store */
    public synchronized void setToDecode(boolean toDecode) {
        this.toDecode = toDecode;
    }

    /** @return the stored should-initialize flag */
    public synchronized boolean isShouldInit() {
        return this.shouldInit;
    }

    /** @param shouldInit the should-initialize flag to store */
    public synchronized void setShouldInit(boolean shouldInit) {
        this.shouldInit = shouldInit;
    }

    /** @return the raw configured fan-in; negative while unset (see {@code fanIn()} for the effective value) */
    public synchronized double getFanIn() {
        return this.fanIn;
    }

    /** @param fanIn the fan-in; a negative value means "unset" */
    public synchronized void setFanIn(double fanIn) {
        this.fanIn = fanIn;
    }

    /** @return the weight-rendering interval forwarded to layers */
    public synchronized int getRenderWeightsEveryNEpochs() {
        return this.renderWeightsEveryNEpochs;
    }

    /** @param renderWeightsEveryNEpochs the weight-rendering interval forwarded to layers */
    public synchronized void setRenderWeightsEveryNEpochs(int renderWeightsEveryNEpochs) {
        this.renderWeightsEveryNEpochs = renderWeightsEveryNEpochs;
    }

    /** @return the weight transforms keyed by layer index (the internal map, not a copy) */
    public synchronized Map<Integer, MatrixTransform> getWeightTransforms() {
        return this.weightTransforms;
    }

    /** @param weightTransforms the weight transforms keyed by layer index; stored by reference */
    public synchronized void setWeightTransforms(Map<Integer, MatrixTransform> weightTransforms) {
        this.weightTransforms = weightTransforms;
    }

    /** @return the stored sparsity value */
    public synchronized double getSparsity() {
        return this.sparsity;
    }

    /** @param sparsity the sparsity value to store */
    public synchronized void setSparsity(double sparsity) {
        this.sparsity = sparsity;
    }

    /** @return the stored learning-rate update factor */
    public synchronized double getLearningRateUpdate() {
        return this.learningRateUpdate;
    }

    /** @param learningRateUpdate the learning-rate update factor to store */
    public synchronized void setLearningRateUpdate(double learningRateUpdate) {
        this.learningRateUpdate = learningRateUpdate;
    }

    /** @return the stored error tolerance */
    public synchronized double getErrorTolerance() {
        return this.errorTolerance;
    }

    /** @param errorTolerance the error tolerance to store */
    public synchronized void setErrorTolerance(double errorTolerance) {
        this.errorTolerance = errorTolerance;
    }

    /** @param labels the labels to store; kept by reference */
    public synchronized void setLabels(DoubleMatrix labels) {
        this.labels = labels;
    }

    /** @param forceNumEpochs whether backProp runs exactly the requested number of epochs */
    public synchronized void setForceNumEpochs(boolean forceNumEpochs) {
        this.forceNumEpochs = forceNumEpochs;
    }

    /** @return the column means subtracted from input columns, or {@code null} if unset */
    public DoubleMatrix getColumnMeans() {
        return this.columnMeans;
    }

    /** @param columnMeans per-column values subtracted from input; {@code null} disables the step */
    public void setColumnMeans(DoubleMatrix columnMeans) {
        this.columnMeans = columnMeans;
    }

    /** @return the column standard deviations used to scale input columns, or {@code null} if unset */
    public DoubleMatrix getColumnStds() {
        return this.columnStds;
    }

    /** @param columnStds per-column divisors applied to input; {@code null} disables the step */
    public void setColumnStds(DoubleMatrix columnStds) {
        this.columnStds = columnStds;
    }

    /**
     * Fluent builder that instantiates a concrete network class reflectively
     * (through its no-argument constructor) and configures it via setters.
     *
     * <p>{@link #withClazz(Class)} must be called before {@link #build()} or
     * {@link #buildEmpty()}.
     *
     * @param <E> the network type produced
     */
    public static class Builder<E extends BaseMultiLayerNetwork> {
        protected Class<? extends BaseMultiLayerNetwork> clazz;
        // Last network produced by build().
        private E ret;
        private int nIns;
        private int[] hiddenLayerSizes;
        private int nOuts;
        private int nLayers;
        private RandomGenerator rng = new MersenneTwister(1234);
        private DoubleMatrix input;
        private DoubleMatrix labels;
        // When null, the network keeps its own default activation.
        private ActivationFunction activation;
        private boolean decode = false;
        private double fanIn = -1.0;
        private int renderWeithsEveryNEpochs = -1;
        private double l2 = 0.01;
        private boolean useRegularization = true;
        // NOTE(review): defaults to 0.0 here while the network field defaults to 0.1;
        // build() always pushes this value, so an unset builder yields 0.0 - confirm intent.
        private double momentum;
        // When null, the network's distribution is left untouched.
        private RealDistribution dist;
        protected Map<Integer, MatrixTransform> weightTransforms = new HashMap<Integer, MatrixTransform>();
        protected boolean backProp = true;
        protected boolean shouldForceEpochs = false;
        private double sparsity = 0.0;

        /** Sets the sparsity value. */
        public Builder<E> withSparsity(double sparsity) {
            this.sparsity = sparsity;
            return this;
        }

        /** Makes back-propagation run exactly the requested number of epochs. */
        public Builder<E> forceEpochs() {
            this.shouldForceEpochs = true;
            return this;
        }

        /** Clears the back-propagation flag on the built network. */
        public Builder<E> disableBackProp() {
            this.backProp = false;
            return this;
        }

        /** Registers a weight transform for the layer at the given index. */
        public Builder<E> transformWeightsAt(int layer, MatrixTransform transform) {
            this.weightTransforms.put(layer, transform);
            return this;
        }

        /** Registers weight transforms keyed by layer index. */
        public Builder<E> transformWeightsAt(Map<Integer, MatrixTransform> transforms) {
            this.weightTransforms.putAll(transforms);
            return this;
        }

        /** Sets the distribution. */
        public Builder<E> withDist(RealDistribution dist) {
            this.dist = dist;
            return this;
        }

        /** Sets the momentum. */
        public Builder<E> withMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        /** Sets whether the L2 coefficient is applied. */
        public Builder<E> useRegularization(boolean useRegularization) {
            this.useRegularization = useRegularization;
            return this;
        }

        /** Sets the L2 coefficient. */
        public Builder<E> withL2(double l2) {
            this.l2 = l2;
            return this;
        }

        /** Sets the weight-rendering interval, in epochs. */
        public Builder<E> renderWeights(int everyN) {
            this.renderWeithsEveryNEpochs = everyN;
            return this;
        }

        /** Sets the fan-in; the argument must not be {@code null} (it is unboxed). */
        public Builder<E> withFanIn(Double fanIn) {
            this.fanIn = fanIn;
            return this;
        }

        /** Sets the activation function for hidden layers. */
        public Builder<E> withActivation(ActivationFunction activation) {
            this.activation = activation;
            return this;
        }

        /** Sets the number of input columns. */
        public Builder<E> numberOfInputs(int nIns) {
            this.nIns = nIns;
            return this;
        }

        /** Sets the decode flag. */
        public Builder<E> decodeNetwork(boolean decode) {
            this.decode = decode;
            return this;
        }

        /** Sets the hidden layer sizes; the layer count is taken from the array length. */
        public Builder<E> hiddenLayerSizes(int[] hiddenLayerSizes) {
            this.hiddenLayerSizes = hiddenLayerSizes;
            this.nLayers = hiddenLayerSizes.length;
            return this;
        }

        /** Sets the number of outputs. */
        public Builder<E> numberOfOutPuts(int nOuts) {
            this.nOuts = nOuts;
            return this;
        }

        /** Sets the random generator. */
        public Builder<E> withRng(RandomGenerator gen) {
            this.rng = gen;
            return this;
        }

        /** Sets the training input; stored by reference. */
        public Builder<E> withInput(DoubleMatrix input) {
            this.input = input;
            return this;
        }

        /** Sets the training labels; stored by reference. */
        public Builder<E> withLabels(DoubleMatrix labels) {
            this.labels = labels;
            return this;
        }

        /** Sets the concrete network class to instantiate. */
        public Builder<E> withClazz(Class<? extends BaseMultiLayerNetwork> clazz) {
            this.clazz = clazz;
            return this;
        }

        /**
         * Instantiates the configured class without applying any builder settings.
         *
         * @return the new, unconfigured network
         * @throws RuntimeException wrapping any failure to instantiate the class
         */
        public E buildEmpty() {
            try {
                // Callers are trusted to pair E with the class given to withClazz.
                @SuppressWarnings("unchecked")
                E empty = (E)this.clazz.newInstance();
                return empty;
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Instantiates the configured class and applies every builder setting to it.
         *
         * <p>The hidden-layer array is allocated to the configured layer count; the
         * layers themselves are not initialized here.
         *
         * @return the configured network
         * @throws RuntimeException wrapping a reflective instantiation failure
         */
        public E build() {
            try {
                // Callers are trusted to pair E with the class given to withClazz.
                @SuppressWarnings("unchecked")
                E network = (E)this.clazz.newInstance();
                this.ret = network;
                network.setInput(this.input);
                network.setnOuts(this.nOuts);
                network.setnIns(this.nIns);
                network.setLabels(this.labels);
                network.setHiddenLayerSizes(this.hiddenLayerSizes);
                network.setnLayers(this.nLayers);
                network.setRng(this.rng);
                network.setShouldBackProp(this.backProp);
                network.setSigmoidLayers(new HiddenLayer[network.getnLayers()]);
                network.setToDecode(this.decode);
                network.setMomentum(this.momentum);
                network.setFanIn(this.fanIn);
                network.setSparsity(this.sparsity);
                network.setRenderWeightsEveryNEpochs(this.renderWeithsEveryNEpochs);
                network.setL2(this.l2);
                network.setForceNumEpochs(this.shouldForceEpochs);
                network.setUseRegularization(this.useRegularization);
                // Optional settings: leave the network's defaults in place when unset.
                if (this.activation != null) {
                    network.setActivation(this.activation);
                }
                if (this.dist != null) {
                    network.setDist(this.dist);
                }
                network.getWeightTransforms().putAll(this.weightTransforms);
                return this.ret;
            }
            catch (IllegalAccessException | InstantiationException e) {
                throw new RuntimeException(e);
            }
        }
    }
}

