/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rbm;

import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.rbm.RBM;
import org.deeplearning4j.rbm.RBMOptimizer;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/**
 * Restricted Boltzmann machine with Gaussian visible units and noisy rectified
 * linear hidden units.
 *
 * <p>Each visible unit has a standard deviation held in {@link #sigma}, which the
 * constructor initializes to all ones. Sparsity is disabled for this model. When
 * AdaGrad is enabled, the weight AdaGrad master step size is lowered to
 * {@code 1e-4} and learning-rate decay is switched on.
 */
public class GaussianRectifiedLinearRBM
extends RBM {
    private static final long serialVersionUID = 5186639601076269003L;

    /** Per-visible-unit standard deviations; initialized to a vector of ones of length nVisible. */
    private DoubleMatrix sigma;

    /**
     * No-arg constructor.
     * NOTE(review): presumably instantiated reflectively via {@link Builder}'s {@code clazz}; confirm.
     */
    private GaussianRectifiedLinearRBM() {
    }

    /**
     * Creates the model and delegates parameter setup to {@link RBM}.
     *
     * @param input    training data
     * @param nVisible number of visible units (also the length of {@link #sigma})
     * @param nHidden  number of hidden units
     * @param W        weight matrix
     * @param hbias    hidden bias
     * @param vbias    visible bias
     * @param rng      random number generator
     * @param fanIn    fan-in value forwarded to the base class
     * @param dist     distribution forwarded to the base class
     */
    private GaussianRectifiedLinearRBM(DoubleMatrix input, int nVisible, int nHidden, DoubleMatrix W, DoubleMatrix hbias, DoubleMatrix vbias, RandomGenerator rng, double fanIn, RealDistribution dist) {
        super(input, nVisible, nHidden, W, hbias, vbias, rng, fanIn, dist);
        if (this.useAdaGrad) {
            // Smaller master step size with decay for the weight AdaGrad.
            this.wAdaGrad.setMasterStepSize(1.0E-4);
            this.wAdaGrad.setDecayLr(true);
        }
        this.sigma = DoubleMatrix.ones(nVisible);
        this.applySparsity = false;
    }

    /**
     * Trains the model with a fresh {@link RBMOptimizer} using a tolerance of {@code 1e-6}.
     *
     * @param learningRate learning rate handed to the optimizer
     * @param k            forwarded to the optimizer as an extra parameter
     *                     (presumably the number of Gibbs/CD steps — confirm)
     * @param input        training data; when {@code null}, the previously stored input is used
     */
    @Override
    public void trainTillConvergence(double learningRate, int k, DoubleMatrix input) {
        if (input != null) {
            this.input = input;
        }
        this.optimizer = new RBMOptimizer(this, learningRate, new Object[]{k, learningRate}, this.optimizationAlgo, this.lossFunction);
        this.optimizer.setTolerance(1.0E-6);
        // Train on the stored input so a null argument falls back to the previously
        // supplied data instead of handing null to the optimizer.
        this.optimizer.train(this.input);
    }

    /**
     * Computes hidden pre-activations {@code (v / sigma) W + hBias}, applied row by row.
     *
     * @param v visible activations, one example per row
     * @return hidden pre-activations (a newly allocated matrix)
     */
    @Override
    public DoubleMatrix propUp(DoubleMatrix v) {
        return v.divRowVector(this.sigma).mmul(this.W).addiRowVector(this.hBias);
    }

    /**
     * Samples the noisy rectified-linear hidden units given the visible units.
     *
     * @param v visible activations
     * @return pair of (hidden pre-activations, rectified noisy sample)
     */
    @Override
    public Pair<DoubleMatrix, DoubleMatrix> sampleHiddenGivenVisible(DoubleMatrix v) {
        DoubleMatrix h1Mean = this.propUp(v);
        DoubleMatrix sigH1Mean = MatrixUtil.sigmoid(h1Mean);
        DoubleMatrix noise = MatrixUtil.normal(this.getRng(), h1Mean, 1.0).mul(MatrixFunctions.sqrt(sigH1Mean));
        // add() rather than addi(): the in-place form overwrote h1Mean, so the returned
        // "mean" was the very same object as the noisy, rectified (and dropped-out) sample.
        DoubleMatrix h1Sample = h1Mean.add(noise);
        // Rectify negative activations to zero.
        // NOTE(review): relies on MatrixUtil.max updating its argument in place, as the original call did.
        MatrixUtil.max(0.0, h1Sample);
        this.applyDropOutIfNecessary(h1Sample);
        return new Pair<DoubleMatrix, DoubleMatrix>(h1Mean, h1Sample);
    }

    /**
     * Computes the mean of the Gaussian visible units given hidden activations:
     * {@code vBias + sigma ∘ (h Wᵀ)}, applied row by row.
     *
     * @param h hidden activations
     * @return visible means (a newly allocated matrix)
     */
    @Override
    public DoubleMatrix propDown(DoubleMatrix h) {
        // mmul allocates a fresh matrix, so the in-place row-vector ops below are safe.
        // The previous form multiplied by (vBias + sigma), treating the visible bias
        // as a scale factor instead of an additive offset.
        return h.mmul(this.W.transpose()).muliRowVector(this.sigma).addiRowVector(this.vBias);
    }

    /**
     * Samples the Gaussian visible units given the hidden units.
     *
     * @param h hidden activations
     * @return pair of (visible means, visible sample)
     */
    @Override
    public Pair<DoubleMatrix, DoubleMatrix> sampleVisibleGivenHidden(DoubleMatrix h) {
        DoubleMatrix v1Mean = this.propDown(h);
        // NOTE(review): assumes MatrixUtil.normal draws around v1Mean with std 1; the
        // result is then scaled per unit by sigma — confirm against MatrixUtil.
        DoubleMatrix v1Sample = MatrixUtil.normal(this.getRng(), v1Mean, 1.0).mulRowVector(this.sigma);
        return new Pair<DoubleMatrix, DoubleMatrix>(v1Mean, v1Sample);
    }

    /** Builder that configures {@link BaseNeuralNetwork.Builder} to produce this model type. */
    public static class Builder
    extends BaseNeuralNetwork.Builder<GaussianRectifiedLinearRBM> {
        public Builder() {
            this.clazz = GaussianRectifiedLinearRBM.class;
        }
    }
}

