/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.da;

import java.io.Serializable;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.sda.DenoisingAutoEncoderOptimizer;
import org.deeplearning4j.util.MathUtils;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/**
 * Denoising autoencoder with tied weights.
 *
 * <p>The encoder computes {@code sigmoid(x * W + hBias)}. The decoder computes
 * {@code sigmoid(y * W^T + vBias)}, so both directions share the single weight matrix {@code W}.
 * Training randomly masks entries of the input and learns to reconstruct the uncorrupted input.
 */
public class DenoisingAutoEncoder extends BaseNeuralNetwork implements Serializable {
    private static final long serialVersionUID = -6445530486350763837L;

    /**
     * Creates an instance without initialising any state in this class.
     * NOTE(review): presumably used by {@link Builder} through reflection on {@code clazz}; confirm.
     */
    public DenoisingAutoEncoder() {
    }

    /**
     * Creates an autoencoder. Every argument is forwarded unchanged to the
     * {@link BaseNeuralNetwork} constructor.
     */
    public DenoisingAutoEncoder(DoubleMatrix input, int nVisible, int nHidden, DoubleMatrix W,
            DoubleMatrix hbias, DoubleMatrix vbias, RandomGenerator rng, double fanIn,
            RealDistribution dist) {
        super(input, nVisible, nHidden, W, hbias, vbias, rng, fanIn, dist);
    }

    /**
     * Returns a masked copy of {@code x}.
     *
     * <p>Each entry is multiplied by an independent draw of
     * {@code MathUtils.binomial(rng, 1, 1 - corruptionLevel)}. This is presumably a Bernoulli draw,
     * so each entry would be kept with probability {@code 1 - corruptionLevel} and zeroed otherwise.
     *
     * @param x the matrix to corrupt; it is not modified
     * @param corruptionLevel the corruption probability forwarded (as {@code 1 - corruptionLevel})
     *     to the binomial draw
     * @return a new matrix with the same dimensions as {@code x}
     */
    public DoubleMatrix getCorruptedInput(DoubleMatrix x, double corruptionLevel) {
        double keepProbability = 1.0 - corruptionLevel;
        DoubleMatrix mask = DoubleMatrix.zeros(x.rows, x.columns);
        for (int i = 0; i < x.rows; ++i) {
            for (int j = 0; j < x.columns; ++j) {
                mask.put(i, j, (double) MathUtils.binomial(this.rng, 1, keepProbability));
            }
        }
        // mask is a fresh local matrix, so multiplying it in place leaves x untouched.
        return mask.muli(x);
    }

    /**
     * Computes the cross-entropy reconstruction loss on a freshly corrupted copy of {@code input}.
     *
     * <p>The loss is the negated mean over columns of the per-column sums of
     * {@code x * log(z) + (1 - x) * log(1 - z)}. When {@code useRegularization} is set, the L2
     * penalty {@code (l2 / 2) * sum(W^2)} is added.
     * NOTE(review): this sums over rows and averages over columns; confirm that is the intended
     * normalisation.
     *
     * @param corruptionLevel corruption probability passed to {@link #getCorruptedInput}
     * @return the (regularised) negative log-likelihood
     */
    public double negativeLoglikelihood(double corruptionLevel) {
        DoubleMatrix corrupted = this.getCorruptedInput(this.input, corruptionLevel);
        DoubleMatrix y = this.getHiddenValues(corrupted);
        DoubleMatrix z = this.getReconstructedInput(y);
        DoubleMatrix logLikelihood = this.input.mul(MatrixUtil.log(z))
                .add(MatrixUtil.oneMinus(this.input).mul(MatrixUtil.log(MatrixUtil.oneMinus(z))));
        double loss = -logLikelihood.columnSums().mean();
        if (this.useRegularization) {
            // (l2 / 2) * ||W||^2 is the penalty whose gradient, l2 * W, getGradient subtracts.
            // The previous 2 / l2 factor was inverted and divided by zero when l2 == 0.
            loss += this.l2 / 2.0 * MatrixFunctions.pow(this.W, 2.0).sum();
        }
        return loss;
    }

    /**
     * Encodes {@code x} as {@code sigmoid(x * W + hBias)}.
     *
     * @param x input rows
     * @return hidden activations, one row per input row
     */
    public DoubleMatrix getHiddenValues(DoubleMatrix x) {
        return MatrixUtil.sigmoid(x.mmul(this.W).addRowVector(this.hBias));
    }

    /**
     * Decodes hidden activations as {@code sigmoid(y * W^T + vBias)}, using the tied weights.
     *
     * @param y hidden activations
     * @return reconstructed visible values, one row per row of {@code y}
     */
    public DoubleMatrix getReconstructedInput(DoubleMatrix y) {
        return MatrixUtil.sigmoid(y.mmul(this.W.transpose()).addRowVector(this.vBias));
    }

    /**
     * Convenience overload of {@link #trainTillConvergence(DoubleMatrix, double, Object[])}.
     * It packs {@code corruptionLevel} and {@code lr} into the parameter array expected by
     * {@link #getGradient}.
     *
     * @param x new training input, or {@code null} to keep the current {@code input}
     * @param lr learning rate
     * @param corruptionLevel corruption probability used for each gradient evaluation
     */
    public void trainTillConvergence(DoubleMatrix x, double lr, double corruptionLevel) {
        this.trainTillConvergence(x, lr, new Object[]{corruptionLevel, lr});
    }

    /**
     * Performs a single update step on {@code x}.
     *
     * <p>The gradient from {@link #getGradient} is added to {@code vBias}, {@code W} and
     * {@code hBias}, all modified in place.
     *
     * @param x training input; it replaces {@code input}
     * @param lr learning rate
     * @param corruptionLevel corruption probability for this step
     */
    public void train(DoubleMatrix x, double lr, double corruptionLevel) {
        this.input = x;
        NeuralNetworkGradient gradient = this.getGradient(new Object[]{corruptionLevel, lr});
        this.vBias.addi(gradient.getvBiasGradient());
        this.W.addi(gradient.getwGradient());
        this.hBias.addi(gradient.gethBiasGradient());
    }

    /**
     * Encodes and then decodes {@code x} without any corruption.
     *
     * @param x input rows
     * @return the reconstruction of {@code x}
     */
    @Override
    public DoubleMatrix reconstruct(DoubleMatrix x) {
        return this.getReconstructedInput(this.getHiddenValues(x));
    }

    /**
     * Installs a new {@link DenoisingAutoEncoderOptimizer} and runs its {@code train}.
     *
     * @param input new training input, or {@code null} to keep the current {@code input};
     *     it is passed to the optimizer as-is, even when {@code null}
     * @param lr learning rate handed to the optimizer
     * @param params {@code {corruptionLevel, lr}}, as read by {@link #getGradient}
     */
    @Override
    public void trainTillConvergence(DoubleMatrix input, double lr, Object[] params) {
        if (input != null) {
            this.input = input;
        }
        this.optimizer = new DenoisingAutoEncoderOptimizer(this, lr, params);
        this.optimizer.train(input);
    }

    /**
     * Returns {@link #negativeLoglikelihood} for the corruption level in {@code params[0]}.
     *
     * @param params {@code params[0]} must be a {@link Number} (corruption level)
     */
    @Override
    public double lossFunction(Object[] params) {
        return this.negativeLoglikelihood(doubleParam(params, 0));
    }

    /**
     * Performs a single update step.
     *
     * @param input training input
     * @param lr learning rate
     * @param params {@code params[0]} must be a {@link Number} (corruption level)
     */
    @Override
    public void train(DoubleMatrix input, double lr, Object[] params) {
        this.train(input, lr, doubleParam(params, 0));
    }

    /**
     * Computes the update for {@code W}, {@code vBias} and {@code hBias} on a corrupted copy of
     * {@code input}.
     *
     * <p>The returned values are added to the parameters by {@link #train(DoubleMatrix, double, double)},
     * so they point in the direction that increases the reconstruction log-likelihood.
     *
     * @param params {@code params[0]} = corruption level and {@code params[1]} = learning rate,
     *     both {@link Number}s
     * @return weight update (scaled by {@code lr} and divided by the number of rows) and the
     *     column-mean bias updates
     */
    @Override
    public synchronized NeuralNetworkGradient getGradient(Object[] params) {
        double corruptionLevel = doubleParam(params, 0);
        double lr = doubleParam(params, 1);
        DoubleMatrix tildeX = this.getCorruptedInput(this.input, corruptionLevel);
        DoubleMatrix y = this.getHiddenValues(tildeX);
        DoubleMatrix z = this.getReconstructedInput(y);

        // Error signal at the visible layer.
        DoubleMatrix visibleDelta = this.input.sub(z);
        // Error propagated back through the tied weights. y * (1 - y) is the sigmoid derivative.
        // NOTE(review): the sparsity branch uses y * (y - sparsity) instead; confirm intent.
        DoubleMatrix hiddenDelta = this.sparsity == 0.0
                ? visibleDelta.mmul(this.W).mul(y).mul(MatrixUtil.oneMinus(y))
                : visibleDelta.mmul(this.W).mul(y).mul(y.add(-this.sparsity));

        // Tied weights receive both an encoder and a decoder contribution.
        DoubleMatrix wGradient = tildeX.transpose().mmul(hiddenDelta)
                .add(visibleDelta.transpose().mmul(y));
        wGradient.muli(lr);
        if (this.useRegularization) {
            // mul, not muli: muli would rescale the model's own weights as a side effect.
            wGradient.subi(this.W.mul(this.l2));
        }
        if (this.momentum != 0.0) {
            // Only scales this step's gradient; no velocity from earlier steps is kept here.
            wGradient.muli(1.0 - this.momentum);
        }
        wGradient.divi((double) this.input.rows);

        // NOTE(review): unlike W, the bias updates are not scaled by lr; confirm whether the
        // optimizer compensates for this.
        DoubleMatrix hBiasGradient = hiddenDelta.columnMeans();
        DoubleMatrix vBiasGradient = visibleDelta.columnMeans();
        return new NeuralNetworkGradient(wGradient, vBiasGradient, hBiasGradient);
    }

    /**
     * Reads {@code params[index]} as a double.
     *
     * <p>Accepts any {@link Number}, not just {@link Double}, so callers can pass boxed ints or
     * floats.
     */
    private static double doubleParam(Object[] params, int index) {
        return ((Number) params[index]).doubleValue();
    }

    /**
     * Builder whose target class is {@link DenoisingAutoEncoder}. The construction logic is
     * inherited from {@link BaseNeuralNetwork.Builder}.
     */
    public static class Builder extends BaseNeuralNetwork.Builder<DenoisingAutoEncoder> {
        public Builder() {
            this.clazz = DenoisingAutoEncoder.class;
        }
    }
}

