/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.lang.reflect.Constructor;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.nn.NeuralNetwork;
import org.deeplearning4j.nn.Persistable;
import org.deeplearning4j.optimize.NeuralNetworkOptimizer;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/**
 * Shared state and behaviour for the single-layer networks in this package.
 *
 * <p>Holds the weight matrix {@code W} (created as {@code nVisible x nHidden}), the hidden and
 * visible bias vectors, the random generator and weight-initialisation distribution, and the
 * training hyperparameters. It implements the common bookkeeping: initialisation, cloning,
 * transposition, merging, persistence and reconstruction losses. Subclasses implement
 * {@link #reconstruct(DoubleMatrix)}, {@link #lossFunction(Object[])} and
 * {@link #train(DoubleMatrix, double, Object[])}.
 *
 * <p>All state lives in public mutable fields, and the class performs no synchronization.
 */
public abstract class BaseNeuralNetwork
implements NeuralNetwork,
Persistable {
    private static final long serialVersionUID = -7074102204433996574L;
    /** Seed for the generator created when none is supplied. */
    private static final int DEFAULT_SEED = 1234;
    public int nVisible;
    public int nHidden;
    public DoubleMatrix W;
    public DoubleMatrix hBias;
    public DoubleMatrix vBias;
    public RandomGenerator rng;
    public DoubleMatrix input;
    public double sparsity = 0.01;
    public double momentum = 0.1;
    /*
     * Weight-initialisation distribution. It is intentionally NOT built in a field initializer:
     * at that point rng is still null, and a NormalDistribution over a null generator throws
     * NullPointerException from sample(). initWeights() creates it on demand. That also covers
     * the null left behind by deserialization, because this field is transient.
     */
    public transient RealDistribution dist;
    public double l2 = 0.1;
    public transient NeuralNetworkOptimizer optimizer;
    public int renderWeightsEveryNumEpochs = -1;
    public double fanIn = -1.0;
    public boolean useRegularization = true;

    /**
     * No-arg constructor used for reflective instantiation (see {@link #clone()},
     * {@link #transpose()} and {@link Builder#buildEmpty()}). It leaves the weights, biases,
     * generator and distribution unset.
     */
    public BaseNeuralNetwork() {
    }

    /**
     * Creates a network with no input bound. Equivalent to the full constructor with a
     * {@code null} input.
     */
    public BaseNeuralNetwork(int nVisible, int nHidden, DoubleMatrix W, DoubleMatrix hbias, DoubleMatrix vbias, RandomGenerator rng, double fanIn, RealDistribution dist) {
        this(null, nVisible, nHidden, W, hbias, vbias, rng, fanIn, dist);
    }

    /**
     * Creates a network and fills in any missing parameters via {@link #initWeights()}.
     *
     * @param input   training input; may be null
     * @param nVisible number of visible units; must be at least 1
     * @param nHidden number of hidden units; must be at least 1
     * @param W       weight matrix, or null to sample one from {@code dist}
     * @param hbias   hidden bias, or null for zeros
     * @param vbias   visible bias, or null for zeros
     * @param rng     random generator, or null for a {@code MersenneTwister} seeded with 1234
     * @param fanIn   fan-in value; negative means "derive from nVisible" (see {@link #fanIn()})
     * @param dist    weight distribution, or null for N(0, 0.01) drawn from {@code rng}
     * @throws IllegalStateException if nVisible or nHidden is less than 1
     */
    public BaseNeuralNetwork(DoubleMatrix input, int nVisible, int nHidden, DoubleMatrix W, DoubleMatrix hbias, DoubleMatrix vbias, RandomGenerator rng, double fanIn, RealDistribution dist) {
        this.nVisible = nVisible;
        this.nHidden = nHidden;
        this.fanIn = fanIn;
        this.input = input;
        // Resolve the generator first so the default distribution is never built on a null rng.
        this.rng = rng == null ? new MersenneTwister(DEFAULT_SEED) : rng;
        this.dist = dist != null ? dist : defaultDistribution(this.rng);
        this.W = W;
        this.vBias = vbias;
        this.hBias = hbias;
        // NOTE(review): this calls an overridable method from a constructor, so subclass fields
        // are not yet initialised if a subclass overrides initWeights().
        this.initWeights();
    }

    /** Returns {@code l2 * sum(W^2) / 2}. */
    @Override
    public double l2RegularizedCoefficient() {
        return MatrixFunctions.pow((DoubleMatrix)this.getW(), (double)2.0).sum() / 2.0 * this.l2;
    }

    /**
     * Validates the layer sizes and fills in any missing parameters.
     *
     * <p>It resolves the generator and distribution if absent. It then samples {@code W} row by
     * row from {@code dist} if absent, and zero-fills the hidden and visible biases if absent.
     * Parameters that are already set are left untouched.
     *
     * @throws IllegalStateException if nVisible or nHidden is less than 1
     */
    protected void initWeights() {
        if (this.nVisible < 1) {
            throw new IllegalStateException("Number of visible can not be less than 1");
        }
        if (this.nHidden < 1) {
            throw new IllegalStateException("Number of hidden can not be less than 1");
        }
        this.ensureRandomSources();
        if (this.W == null) {
            this.W = this.sampleWeightMatrix();
        }
        if (this.hBias == null) {
            this.hBias = DoubleMatrix.zeros(this.nHidden);
        }
        if (this.vBias == null) {
            this.vBias = DoubleMatrix.zeros(this.nVisible);
        }
    }

    /** Creates {@code rng} and {@code dist} when they are null (default construction, deserialization). */
    private void ensureRandomSources() {
        if (this.rng == null) {
            this.rng = new MersenneTwister(DEFAULT_SEED);
        }
        if (this.dist == null) {
            this.dist = defaultDistribution(this.rng);
        }
    }

    /** Returns the default weight distribution: normal, mean 0, standard deviation 0.01, drawing from {@code rng}. */
    private static RealDistribution defaultDistribution(RandomGenerator rng) {
        return new NormalDistribution(rng, 0.0, 0.01, 1.0E-9);
    }

    /** Samples a fresh {@code nVisible x nHidden} matrix from {@code dist}, one row at a time. */
    private DoubleMatrix sampleWeightMatrix() {
        DoubleMatrix sampled = DoubleMatrix.zeros(this.nVisible, this.nHidden);
        for (int row = 0; row < sampled.rows; ++row) {
            sampled.putRow(row, new DoubleMatrix(this.dist.sample(sampled.columns)));
        }
        return sampled;
    }

    @Override
    public void setRenderEpochs(int renderEpochs) {
        this.renderWeightsEveryNumEpochs = renderEpochs;
    }

    @Override
    public int getRenderEpochs() {
        return this.renderWeightsEveryNumEpochs;
    }

    /**
     * Returns the configured fan-in, or {@code 1 / nVisible} when it is negative (unset).
     * Floating-point division is used, so the result is not truncated to 0.
     */
    @Override
    public double fanIn() {
        return this.fanIn < 0.0 ? 1.0 / this.nVisible : this.fanIn;
    }

    @Override
    public void setFanIn(double fanIn) {
        this.fanIn = fanIn;
    }

    /**
     * Perturbs the weights by adding a freshly sampled {@code nVisible x nHidden} matrix drawn
     * from {@code dist}. If {@code W} is unset, the sample becomes {@code W}.
     */
    public void jostleWeighMatrix() {
        this.ensureRandomSources();
        DoubleMatrix noise = this.sampleWeightMatrix();
        if (this.W == null) {
            this.W = noise;
        } else {
            this.W.addi(noise);
        }
    }

    /**
     * Returns a new instance of this class with the visible and hidden layers swapped.
     *
     * <p>The new network gets {@code W} transposed, swapped unit counts and swapped biases, so
     * its hidden bias has the length of the new hidden layer. It shares this network's
     * generator and distribution and copies its hyperparameters.
     *
     * @throws RuntimeException wrapping any reflective instantiation failure
     */
    @Override
    public NeuralNetwork transpose() {
        try {
            BaseNeuralNetwork ret = this.getClass().newInstance();
            ret.sethBias(this.vBias.dup());
            ret.setvBias(this.hBias.dup());
            ret.setnHidden(this.getnVisible());
            ret.setnVisible(this.getnHidden());
            ret.setW(this.W.transpose());
            ret.setRng(this.getRng());
            ret.setDist(this.getDist());
            this.copyHyperParametersTo(ret);
            return ret;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns a new instance of this class with copies of the weights and biases, the same
     * generator and distribution, and the same hyperparameters (including fan-in).
     *
     * @throws RuntimeException wrapping any reflective instantiation failure
     */
    @Override
    public NeuralNetwork clone() {
        try {
            BaseNeuralNetwork ret = this.getClass().newInstance();
            ret.sethBias(this.hBias.dup());
            ret.setvBias(this.vBias.dup());
            ret.setnHidden(this.getnHidden());
            ret.setnVisible(this.getnVisible());
            ret.setW(this.W.dup());
            ret.setRng(this.getRng());
            ret.setDist(this.getDist());
            this.copyHyperParametersTo(ret);
            ret.fanIn = this.fanIn;
            return ret;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Copies the layer-size-independent training settings onto {@code target}. */
    private void copyHyperParametersTo(BaseNeuralNetwork target) {
        target.sparsity = this.sparsity;
        target.momentum = this.momentum;
        target.l2 = this.l2;
        target.useRegularization = this.useRegularization;
        target.renderWeightsEveryNumEpochs = this.renderWeightsEveryNumEpochs;
    }

    @Override
    public RealDistribution getDist() {
        return this.dist;
    }

    @Override
    public void setDist(RealDistribution dist) {
        this.dist = dist;
    }

    /**
     * Moves this network's parameters toward {@code network}'s by {@code 1/batchSize} of the
     * difference: {@code p += (other - p) / batchSize} for W, hBias and vBias.
     *
     * <p>{@code network} is not modified: the differences are computed into new matrices.
     */
    @Override
    public void merge(NeuralNetwork network, int batchSize) {
        this.W.addi(network.getW().sub(this.W).divi((double)batchSize));
        this.hBias.addi(network.gethBias().sub(this.hBias).divi((double)batchSize));
        this.vBias.addi(network.getvBias().sub(this.vBias).divi((double)batchSize));
    }

    /**
     * Copies the parameters, sizes, generator and hyperparameters of {@code n} into this network.
     * The matrices are shared by reference, not copied. {@code dist} and {@code input} are not
     * copied.
     */
    public void update(BaseNeuralNetwork n) {
        this.W = n.W;
        this.hBias = n.hBias;
        this.vBias = n.vBias;
        this.l2 = n.l2;
        this.useRegularization = n.useRegularization;
        this.momentum = n.momentum;
        this.nHidden = n.nHidden;
        this.nVisible = n.nVisible;
        this.rng = n.rng;
        this.sparsity = n.sparsity;
        this.fanIn = n.fanIn;
        this.renderWeightsEveryNumEpochs = n.renderWeightsEveryNumEpochs;
    }

    /**
     * Reads a serialized {@code BaseNeuralNetwork} from {@code is} and copies its state into
     * this instance via {@link #update(BaseNeuralNetwork)}. The stream is left open.
     *
     * <p>SECURITY: this uses native Java deserialization. Only load streams from trusted sources.
     *
     * @throws RuntimeException wrapping any I/O or class-resolution failure
     */
    @Override
    public void load(InputStream is) {
        try {
            ObjectInputStream ois = new ObjectInputStream(is);
            BaseNeuralNetwork loaded = (BaseNeuralNetwork)ois.readObject();
            this.update(loaded);
        }
        catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Computes the cross-entropy between {@code input} and its sigmoid reconstruction.
     *
     * <p>The hidden activation is {@code sigmoid(input*W + hBias)} and the reconstruction is
     * {@code sigmoid(hidden*W^T + vBias)}. The per-row cross-entropy sums are averaged over rows
     * and negated. When {@code useRegularization} is set, the result is additionally divided by
     * {@code (number of input elements + l2RegularizedCoefficient())}.
     */
    @Override
    public double getReConstructionCrossEntropy() {
        DoubleMatrix hidden = MatrixUtil.sigmoid(this.input.mmul(this.W).addRowVector(this.hBias));
        DoubleMatrix reconstructed = MatrixUtil.sigmoid(hidden.mmul(this.W.transpose()).addRowVector(this.vBias));
        DoubleMatrix inner = this.input.mul(MatrixUtil.log(reconstructed)).add(MatrixUtil.oneMinus(this.input).mul(MatrixUtil.log(MatrixUtil.oneMinus(reconstructed))));
        double crossEntropy = -inner.rowSums().mean();
        if (this.useRegularization) {
            double normalizer = (double)inner.length + this.l2RegularizedCoefficient();
            return crossEntropy / normalizer;
        }
        return crossEntropy;
    }

    @Override
    public int getnVisible() {
        return this.nVisible;
    }

    @Override
    public void setnVisible(int nVisible) {
        this.nVisible = nVisible;
    }

    @Override
    public int getnHidden() {
        return this.nHidden;
    }

    @Override
    public void setnHidden(int nHidden) {
        this.nHidden = nHidden;
    }

    @Override
    public DoubleMatrix getW() {
        return this.W;
    }

    @Override
    public void setW(DoubleMatrix w) {
        this.W = w;
    }

    @Override
    public DoubleMatrix gethBias() {
        return this.hBias;
    }

    @Override
    public void sethBias(DoubleMatrix hBias) {
        this.hBias = hBias;
    }

    @Override
    public DoubleMatrix getvBias() {
        return this.vBias;
    }

    @Override
    public void setvBias(DoubleMatrix vBias) {
        this.vBias = vBias;
    }

    @Override
    public RandomGenerator getRng() {
        return this.rng;
    }

    @Override
    public void setRng(RandomGenerator rng) {
        this.rng = rng;
    }

    @Override
    public DoubleMatrix getInput() {
        return this.input;
    }

    @Override
    public void setInput(DoubleMatrix input) {
        this.input = input;
    }

    @Override
    public double getSparsity() {
        return this.sparsity;
    }

    @Override
    public void setSparsity(double sparsity) {
        this.sparsity = sparsity;
    }

    @Override
    public double getMomentum() {
        return this.momentum;
    }

    @Override
    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    @Override
    public double getL2() {
        return this.l2;
    }

    @Override
    public void setL2(double l2) {
        this.l2 = l2;
    }

    /**
     * Serializes this network to {@code os} with native Java serialization. The stream is
     * flushed but not closed.
     *
     * @throws RuntimeException wrapping any I/O failure
     */
    @Override
    public void write(OutputStream os) {
        try {
            ObjectOutputStream out = new ObjectOutputStream(os);
            out.writeObject(this);
            // ObjectOutputStream buffers block data internally. Without this flush, the tail of
            // the serialized form may never reach os.
            out.flush();
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the network's reconstruction of the given input; implemented by subclasses. */
    public abstract DoubleMatrix reconstruct(DoubleMatrix var1);

    /** Subclass-defined loss; the meaning of the extra parameters is subclass-specific. */
    public abstract double lossFunction(Object[] var1);

    /** Evaluates {@link #lossFunction(Object[])} with a {@code null} parameter array. */
    public double lossFunction() {
        return this.lossFunction(null);
    }

    /** Subclass-defined training step; the meaning of the scalar and extra parameters is subclass-specific. */
    @Override
    public abstract void train(DoubleMatrix var1, double var2, Object[] var4);

    /**
     * Returns the negated squared reconstruction error of {@code input}.
     *
     * <p>The error is summed over all elements and divided by the number of input rows. When
     * {@code useRegularization} is set, {@code 0.5 * l2 * sum(W^2)} is added before negation.
     */
    @Override
    public double squaredLoss() {
        DoubleMatrix reconstructed = this.reconstruct(this.input);
        double loss = MatrixFunctions.powi((DoubleMatrix)reconstructed.sub(this.input), (double)2.0).sum() / (double)this.input.rows;
        if (this.useRegularization) {
            loss += 0.5 * this.l2 * MatrixFunctions.pow((DoubleMatrix)this.W, (double)2.0).sum();
        }
        return -loss;
    }

    /**
     * Fluent builder that creates a concrete network reflectively.
     *
     * <p>{@link #build()} invokes the target class's 9-argument constructor
     * {@code (DoubleMatrix input, int nVisible, int nHidden, DoubleMatrix W, DoubleMatrix hBias,
     * DoubleMatrix vBias, RandomGenerator rng, double fanIn, RealDistribution dist)} and then
     * applies the configured hyperparameters.
     */
    public static class Builder<E extends BaseNeuralNetwork> {
        /** Number of parameters of the constructor invoked by {@link #buildWithInput()}. */
        private static final int FULL_CONSTRUCTOR_ARITY = 9;
        private E ret = null;
        private DoubleMatrix W;
        protected Class<? extends NeuralNetwork> clazz;
        private DoubleMatrix vBias;
        private DoubleMatrix hBias;
        private int numVisible;
        private int numHidden;
        private RandomGenerator gen = new MersenneTwister(123);
        private DoubleMatrix input;
        private double sparsity = 0.01;
        private double l2 = 0.01;
        private double momentum = 0.1;
        private int renderWeightsEveryNumEpochs = -1;
        private double fanIn = 0.1;
        private boolean useRegularization = true;
        private RealDistribution dist;

        public Builder<E> withDistribution(RealDistribution dist) {
            this.dist = dist;
            return this;
        }

        public Builder<E> useRegularization(boolean useRegularization) {
            this.useRegularization = useRegularization;
            return this;
        }

        public Builder<E> fanIn(double fanIn) {
            this.fanIn = fanIn;
            return this;
        }

        public Builder<E> withL2(double l2) {
            this.l2 = l2;
            return this;
        }

        public Builder<E> renderWeights(int numEpochs) {
            this.renderWeightsEveryNumEpochs = numEpochs;
            return this;
        }

        /**
         * Instantiates the target class via its no-arg constructor, without applying any
         * builder settings.
         *
         * @throws RuntimeException wrapping reflective instantiation failures
         */
        @SuppressWarnings("unchecked") // clazz is set only from Class<? extends BaseNeuralNetwork> / Class<E>
        public E buildEmpty() {
            try {
                return (E)((BaseNeuralNetwork)this.clazz.newInstance());
            }
            catch (IllegalAccessException | InstantiationException e) {
                throw new RuntimeException(e);
            }
        }

        public Builder<E> withClazz(Class<? extends BaseNeuralNetwork> clazz) {
            this.clazz = clazz;
            return this;
        }

        public Builder<E> withSparsity(double sparsity) {
            this.sparsity = sparsity;
            return this;
        }

        public Builder<E> withMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        public Builder<E> withInput(DoubleMatrix input) {
            this.input = input;
            return this;
        }

        public Builder<E> asType(Class<E> clazz) {
            this.clazz = clazz;
            return this;
        }

        public Builder<E> withWeights(DoubleMatrix W) {
            this.W = W;
            return this;
        }

        public Builder<E> withVisibleBias(DoubleMatrix vBias) {
            this.vBias = vBias;
            return this;
        }

        public Builder<E> withHBias(DoubleMatrix hBias) {
            this.hBias = hBias;
            return this;
        }

        public Builder<E> numberOfVisible(int numVisible) {
            this.numVisible = numVisible;
            return this;
        }

        public Builder<E> numHidden(int numHidden) {
            this.numHidden = numHidden;
            return this;
        }

        public Builder<E> withRandom(RandomGenerator gen) {
            this.gen = gen;
            return this;
        }

        /** Builds the network; see {@link #buildWithInput()}. */
        public E build() {
            return this.buildWithInput();
        }

        /**
         * Invokes the first declared 9-parameter constructor whose first parameter accepts a
         * {@code DoubleMatrix}, then applies sparsity, render interval, l2, momentum and
         * regularization.
         *
         * @return the built network, or the previously built one (initially null) if no
         *         matching constructor exists
         * @throws RuntimeException wrapping any reflective invocation failure
         */
        @SuppressWarnings("unchecked") // the target class is a BaseNeuralNetwork subtype supplied as E
        private E buildWithInput() {
            for (Constructor<?> candidate : this.clazz.getDeclaredConstructors()) {
                Class<?>[] params = candidate.getParameterTypes();
                // Only the full constructor is invoked below. Any other arity would fail with
                // IllegalArgumentException, so skip it instead of picking it by accident.
                if (params.length != FULL_CONSTRUCTOR_ARITY || !params[0].isAssignableFrom(DoubleMatrix.class)) {
                    continue;
                }
                try {
                    E built = (E)((BaseNeuralNetwork)candidate.newInstance(this.input, this.numVisible, this.numHidden, this.W, this.hBias, this.vBias, this.gen, this.fanIn, this.dist));
                    built.sparsity = this.sparsity;
                    built.renderWeightsEveryNumEpochs = this.renderWeightsEveryNumEpochs;
                    built.l2 = this.l2;
                    built.momentum = this.momentum;
                    built.useRegularization = this.useRegularization;
                    this.ret = built;
                    return this.ret;
                }
                catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            return this.ret;
        }
    }
}

