/*
 * Decompiled with CFR 0.152.
 */
package smile.base.mlp;

import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Arrays;
import smile.base.mlp.ActivationFunction;
import smile.base.mlp.Cost;
import smile.base.mlp.HiddenLayerBuilder;
import smile.base.mlp.OutputFunction;
import smile.base.mlp.OutputLayerBuilder;
import smile.math.MathEx;
import smile.math.matrix.Matrix;

/**
 * A dense layer of a multilayer perceptron. It holds an {@code n x p} weight
 * matrix, a length-{@code n} bias vector, and per-thread working buffers used
 * during forward propagation and training.
 *
 * <p>Subclasses supply the activation ({@link #f(double[])}) and the error
 * back-propagation ({@link #backpropagate(double[])}).
 *
 * <p>All working buffers (output, gradients, RMS statistics, momentum updates)
 * are {@link ThreadLocal}, so each calling thread sees its own copies. The
 * shared {@code weight} and {@code bias} are modified by the update methods
 * without any locking.
 */
public abstract class Layer
implements Serializable {
    private static final long serialVersionUID = 2L;
    /** The number of neurons (output size); equals the number of rows of {@code weight}. */
    protected int n;
    /** The input size; equals the number of columns of {@code weight}. */
    protected int p;
    /** The {@code n x p} weight matrix. */
    protected Matrix weight;
    /** The bias vector of length {@code n}. */
    protected double[] bias;
    /** Per-thread output buffer (length {@code n}), filled by {@link #propagate(double[])}. */
    protected transient ThreadLocal<double[]> output;
    /** Per-thread output gradient (length {@code n}), consumed by the gradient/update methods. */
    protected transient ThreadLocal<double[]> outputGradient;
    /** Per-thread accumulated weight gradient ({@code n x p}), cleared by {@link #update}. */
    protected transient ThreadLocal<Matrix> weightGradient;
    /** Per-thread accumulated bias gradient (length {@code n}), cleared by {@link #update}. */
    protected transient ThreadLocal<double[]> biasGradient;
    /** Per-thread running mean of squared weight gradients used by the RMSProp path. */
    protected transient ThreadLocal<Matrix> rmsWeightGradient;
    /** Per-thread running mean of squared bias gradients used by the RMSProp path. */
    protected transient ThreadLocal<double[]> rmsBiasGradient;
    /** Per-thread previous weight step, reused when momentum is enabled. */
    protected transient ThreadLocal<Matrix> weightUpdate;
    /** Per-thread previous bias step, reused when momentum is enabled. */
    protected transient ThreadLocal<double[]> biasUpdate;

    /**
     * Creates a layer with random weights and zero bias.
     *
     * <p>Weights are drawn by {@code Matrix.rand} in {@code [-r, r]} with
     * {@code r = sqrt(6 / (n + p))}, which is the Glorot/Xavier uniform bound.
     *
     * @param n the number of neurons (output size); must be positive.
     * @param p the input size; must be positive.
     * @throws IllegalArgumentException if {@code n} or {@code p} is not positive.
     */
    public Layer(int n, int p) {
        this(initialWeight(n, p), new double[n]);
    }

    /**
     * Creates a layer from existing parameters. The arrays are used directly, not copied.
     *
     * @param weight the {@code n x p} weight matrix.
     * @param bias the bias vector; its length must equal {@code weight.nrows()}.
     * @throws IllegalArgumentException if the bias length does not match the weight rows.
     */
    public Layer(Matrix weight, double[] bias) {
        this.n = weight.nrows();
        this.p = weight.ncols();
        if (bias.length != this.n) {
            throw new IllegalArgumentException(String.format(
                    "Bias length %d does not match the number of weight rows %d", bias.length, this.n));
        }
        this.weight = weight;
        this.bias = bias;
        this.init();
    }

    /** Validates the layer size and builds the randomly initialized weight matrix. */
    private static Matrix initialWeight(int n, int p) {
        if (n <= 0 || p <= 0) {
            throw new IllegalArgumentException(String.format("Invalid layer size: n = %d, p = %d", n, p));
        }
        double r = Math.sqrt(6.0 / (n + p));
        return Matrix.rand(n, p, -r, r);
    }

    /**
     * Restores the serialized fields, checks that they are mutually consistent,
     * and recreates the transient per-thread buffers.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        if (this.weight == null || this.bias == null) {
            throw new InvalidObjectException("Layer weight or bias is missing");
        }
        if (this.weight.nrows() != this.n || this.weight.ncols() != this.p || this.bias.length != this.n) {
            throw new InvalidObjectException(String.format(
                    "Inconsistent layer dimensions: n = %d, p = %d, weight = %d x %d, bias = %d",
                    this.n, this.p, this.weight.nrows(), this.weight.ncols(), this.bias.length));
        }
        this.init();
    }

    /** (Re)creates every per-thread buffer; buffers are allocated lazily on first use in each thread. */
    private void init() {
        this.output = this.newVectorBuffer();
        this.outputGradient = this.newVectorBuffer();
        this.weightGradient = this.newMatrixBuffer();
        this.biasGradient = this.newVectorBuffer();
        this.rmsWeightGradient = this.newMatrixBuffer();
        this.rmsBiasGradient = this.newVectorBuffer();
        this.weightUpdate = this.newMatrixBuffer();
        this.biasUpdate = this.newVectorBuffer();
    }

    /** Returns a thread-local whose initial value is a zero vector of length {@code n}. */
    private ThreadLocal<double[]> newVectorBuffer() {
        return new ThreadLocal<double[]>() {
            @Override
            protected double[] initialValue() {
                return new double[Layer.this.n];
            }
        };
    }

    /** Returns a thread-local whose initial value is a new {@code n x p} matrix. */
    private ThreadLocal<Matrix> newMatrixBuffer() {
        return new ThreadLocal<Matrix>() {
            @Override
            protected Matrix initialValue() {
                return new Matrix(Layer.this.n, Layer.this.p);
            }
        };
    }

    /**
     * Returns the number of neurons.
     *
     * @return the output size {@code n}.
     */
    public int getOutputSize() {
        return this.n;
    }

    /**
     * Returns the input size.
     *
     * @return the input size {@code p}.
     */
    public int getInputSize() {
        return this.p;
    }

    /**
     * Returns the calling thread's output buffer. The buffer is live, not a copy.
     *
     * @return the output vector of length {@code n}.
     */
    public double[] output() {
        return this.output.get();
    }

    /**
     * Returns the calling thread's output-gradient buffer. The buffer is live, not a copy.
     *
     * @return the output gradient vector of length {@code n}.
     */
    public double[] gradient() {
        return this.outputGradient.get();
    }

    /**
     * Computes this layer's output for input {@code x} into the thread's output buffer.
     *
     * <p>The buffer is first set to the bias. Then {@code weight.mv(1.0, x, 1.0, output)}
     * is applied, presumably {@code output = W x + output}. Finally the activation
     * {@link #f(double[])} is applied to the buffer.
     *
     * @param x the input vector of length {@code p}.
     */
    public void propagate(double[] x) {
        double[] out = this.output.get();
        System.arraycopy(this.bias, 0, out, 0, this.n);
        this.weight.mv(1.0, x, 1.0, out);
        this.f(out);
    }

    /**
     * Applies the activation function to {@code x}. The method returns nothing, so
     * implementations must overwrite {@code x} in place; {@link #propagate} relies on this.
     *
     * @param x the pre-activation values; overwritten with the activations.
     */
    public abstract void f(double[] x);

    /**
     * Propagates the error back through this layer. The exact contract of the argument
     * is defined by the subclasses.
     *
     * @param lowerLayerGradient a gradient buffer exchanged with the adjacent layer.
     */
    public abstract void backpropagate(double[] lowerLayerGradient);

    /**
     * Updates the weights and bias immediately from the calling thread's output
     * gradient and input {@code x}, without accumulating into the gradient buffers.
     *
     * <p>If {@code 0 < momentum < 1}, the previous step is scaled by {@code momentum},
     * the new step is added to it, and the combined step is applied. Otherwise a plain
     * step of size {@code learningRate} is taken. Weight decay is then applied as
     * described in {@link #applyWeightDecay(double)}.
     *
     * @param x the input vector of length {@code p} that produced the current output.
     * @param learningRate the step size.
     * @param momentum the momentum factor; used only when strictly inside (0, 1).
     * @param decay the multiplicative weight-decay factor; used only when strictly inside (0.9, 1).
     */
    public void computeGradientUpdate(double[] x, double learningRate, double momentum, double decay) {
        double[] gradient = this.outputGradient.get();
        if (momentum > 0.0 && momentum < 1.0) {
            Matrix dW = this.weightUpdate.get();
            double[] db = this.biasUpdate.get();
            dW.mul(momentum);
            // Rank-one accumulation, presumably dW += learningRate * gradient * x^T.
            dW.add(learningRate, gradient, x);
            this.weight.add(1.0, dW);
            for (int i = 0; i < this.n; i++) {
                db[i] = momentum * db[i] + learningRate * gradient[i];
                this.bias[i] += db[i];
            }
        } else {
            this.weight.add(learningRate, gradient, x);
            for (int i = 0; i < this.n; i++) {
                this.bias[i] += learningRate * gradient[i];
            }
        }
        this.applyWeightDecay(decay);
    }

    /**
     * Accumulates the calling thread's output gradient, paired with input {@code x},
     * into the thread's weight and bias gradient buffers. The buffers are applied
     * and cleared later by {@link #update}.
     *
     * @param x the input vector of length {@code p} that produced the current output.
     */
    public void computeGradient(double[] x) {
        double[] gradient = this.outputGradient.get();
        Matrix wg = this.weightGradient.get();
        double[] bg = this.biasGradient.get();
        wg.add(1.0, gradient, x);
        for (int i = 0; i < this.n; i++) {
            bg[i] += gradient[i];
        }
    }

    /**
     * Applies the gradients accumulated by {@link #computeGradient} over a mini-batch
     * and then clears them.
     *
     * <p>If {@code 0 < rho < 1}, the gradients are first averaged and RMSProp-scaled
     * (see {@link #rmsProp}) and the step size is {@code learningRate}. Otherwise the
     * step size is {@code learningRate / m}. Momentum and weight decay then follow the
     * same rules as {@link #computeGradientUpdate}.
     *
     * @param m the number of samples accumulated; must be positive.
     * @param learningRate the step size.
     * @param momentum the momentum factor; used only when strictly inside (0, 1).
     * @param decay the multiplicative weight-decay factor; used only when strictly inside (0.9, 1).
     * @param rho the RMSProp discount factor; RMSProp is enabled only when strictly inside (0, 1).
     * @param epsilon the RMSProp smoothing term added under the square root.
     * @throws IllegalArgumentException if {@code m} is not positive. Without this check
     *         the weights would be silently filled with NaN/Inf.
     */
    public void update(int m, double learningRate, double momentum, double decay, double rho, double epsilon) {
        if (m <= 0) {
            throw new IllegalArgumentException("Invalid mini-batch size: " + m);
        }
        Matrix wg = this.weightGradient.get();
        double[] bg = this.biasGradient.get();

        double eta;
        if (rho > 0.0 && rho < 1.0) {
            this.rmsProp(wg, bg, m, rho, epsilon);
            eta = learningRate;
        } else {
            eta = learningRate / (double) m;
        }

        if (momentum > 0.0 && momentum < 1.0) {
            Matrix dW = this.weightUpdate.get();
            double[] db = this.biasUpdate.get();
            // Presumably dW = momentum * dW + eta * wg.
            dW.add(momentum, eta, wg);
            for (int i = 0; i < this.n; i++) {
                db[i] = momentum * db[i] + eta * bg[i];
            }
            this.weight.add(1.0, dW);
            MathEx.add(this.bias, db);
        } else {
            this.weight.add(eta, wg);
            for (int i = 0; i < this.n; i++) {
                this.bias[i] += eta * bg[i];
            }
        }
        this.applyWeightDecay(decay);

        // Reset the accumulators for the next mini-batch.
        wg.fill(0.0);
        Arrays.fill(bg, 0.0);
    }

    /**
     * Averages the accumulated gradients over {@code m} samples and rescales them in
     * place, RMSProp style.
     *
     * <p>For each element, the running mean square {@code r} is updated to
     * {@code rho * r + (1 - rho) * g^2}, and then {@code g} is divided by
     * {@code sqrt(epsilon + r)}.
     */
    private void rmsProp(Matrix wg, double[] bg, int m, double rho, double epsilon) {
        wg.div((double) m);
        for (int i = 0; i < this.n; i++) {
            bg[i] /= (double) m;
        }
        Matrix rmsW = this.rmsWeightGradient.get();
        double[] rmsB = this.rmsBiasGradient.get();
        double rho1 = 1.0 - rho;
        // Each element's statistics depend only on that element, so update-then-scale per element.
        for (int j = 0; j < this.p; j++) {
            for (int i = 0; i < this.n; i++) {
                rmsW.set(i, j, rho * rmsW.get(i, j) + rho1 * MathEx.sqr(wg.get(i, j)));
                wg.div(i, j, Math.sqrt(epsilon + rmsW.get(i, j)));
            }
        }
        for (int i = 0; i < this.n; i++) {
            rmsB[i] = rho * rmsB[i] + rho1 * MathEx.sqr(bg[i]);
            bg[i] /= Math.sqrt(epsilon + rmsB[i]);
        }
    }

    /**
     * Multiplies the weights (not the bias) by {@code decay} when {@code 0.9 < decay < 1};
     * otherwise does nothing.
     */
    private void applyWeightDecay(double decay) {
        // NOTE(review): factors <= 0.9 are silently ignored. This is presumably a guard
        // against overly aggressive shrinkage; confirm against the callers.
        if (decay > 0.9 && decay < 1.0) {
            this.weight.mul(decay);
        }
    }

    /**
     * Returns a builder for a hidden layer with the linear activation.
     *
     * @param n the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder linear(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.linear());
    }

    /**
     * Returns a builder for a hidden layer with the rectifier activation.
     *
     * @param n the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder rectifier(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.rectifier());
    }

    /**
     * Returns a builder for a hidden layer with the sigmoid activation.
     *
     * @param n the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder sigmoid(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.sigmoid());
    }

    /**
     * Returns a builder for a hidden layer with the tanh activation.
     *
     * @param n the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder tanh(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.tanh());
    }

    /**
     * Returns a builder for an output layer trained with the mean-squared-error cost.
     *
     * @param n the number of neurons.
     * @param f the output function.
     * @return the layer builder.
     */
    public static OutputLayerBuilder mse(int n, OutputFunction f) {
        return new OutputLayerBuilder(n, f, Cost.MEAN_SQUARED_ERROR);
    }

    /**
     * Returns a builder for an output layer trained with the likelihood cost.
     *
     * @param n the number of neurons.
     * @param f the output function.
     * @return the layer builder.
     */
    public static OutputLayerBuilder mle(int n, OutputFunction f) {
        return new OutputLayerBuilder(n, f, Cost.LIKELIHOOD);
    }
}

