/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rbm;

import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.rbm.RBM;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/**
 * Restricted Boltzmann machine variant with continuous visible units.
 *
 * <p>Given a visible pre-activation {@code a} (see {@link #propDown(DoubleMatrix)}), each visible
 * unit is modelled on {@code [0, 1]} with density proportional to {@code exp(a * v)}, i.e. an
 * exponential distribution truncated to the unit interval. {@link #sampleVGivenH(DoubleMatrix)}
 * returns that distribution's mean together with an inverse-CDF sample. All other behaviour is
 * inherited from {@link RBM}.
 */
public class CRBM extends RBM {
    private static final long serialVersionUID = 598767790003731193L;

    /**
     * Activations with magnitude below this use low-order Taylor expansions. The closed forms
     * lose precision to cancellation near {@code a = 0} and are {@code 0/0} at exactly zero.
     */
    private static final double SMALL_ACTIVATION = 1.0e-6;

    /** Creates a CRBM via {@link RBM}'s no-argument constructor. */
    public CRBM() {
    }

    /**
     * Creates a CRBM. Every argument is forwarded unchanged to the corresponding {@link RBM}
     * constructor.
     */
    public CRBM(DoubleMatrix input, int n_visible, int n_hidden, DoubleMatrix W, DoubleMatrix hbias, DoubleMatrix vbias, RandomGenerator rng, double fanIn, RealDistribution dist) {
        super(input, n_visible, n_hidden, W, hbias, vbias, rng, fanIn, dist);
    }

    /**
     * Computes the visible-layer pre-activation {@code h * W^T + vBias}.
     *
     * @param h hidden-layer values; {@code vBias} is added to every row
     * @return the visible pre-activation matrix
     */
    @Override
    public DoubleMatrix propDown(DoubleMatrix h) {
        return h.mmul(this.W.transpose()).addRowVector(this.vBias);
    }

    /**
     * Computes the mean of, and draws a sample from, the visible layer given hidden values.
     *
     * <p>For each visible pre-activation {@code a}, the mean is {@code 1 / (1 - exp(-a)) - 1 / a}.
     * The sample is {@code log(1 + u * (exp(a) - 1)) / a} for a uniform variate {@code u}. Both
     * are evaluated in numerically stable forms, and the limits at {@code a = 0} are used there.
     *
     * @param h hidden-layer values
     * @return a pair of (per-unit mean, per-unit sample), each shaped like {@code propDown(h)}
     */
    @Override
    public Pair<DoubleMatrix, DoubleMatrix> sampleVGivenH(DoubleMatrix h) {
        DoubleMatrix aH = this.propDown(h);
        // One uniform variate per visible unit for inverse-CDF sampling.
        // NOTE(review): MatrixUtil.uniform presumably draws from [0, 1); samples are clamped anyway.
        DoubleMatrix u = MatrixUtil.uniform(this.rng, aH.rows, aH.columns);
        DoubleMatrix v1Mean = new DoubleMatrix(aH.rows, aH.columns);
        DoubleMatrix v1Sample = new DoubleMatrix(aH.rows, aH.columns);
        // All three matrices share a shape, so one linear index addresses the same unit in each.
        for (int i = 0; i < aH.length; ++i) {
            double a = aH.get(i);
            v1Mean.put(i, truncatedExpMean(a));
            v1Sample.put(i, truncatedExpSample(a, u.get(i)));
        }
        return new Pair<DoubleMatrix, DoubleMatrix>(v1Mean, v1Sample);
    }

    /**
     * Returns the mean of the density proportional to {@code exp(a * v)} on {@code [0, 1]}, which is
     * {@code 1 / (1 - exp(-a)) - 1 / a}.
     */
    private static double truncatedExpMean(double a) {
        if (Math.abs(a) < SMALL_ACTIVATION) {
            // Series around a = 0: 1/2 + a/12 - a^3/720 + ...; the exact value at a = 0 is 1/2.
            return 0.5 + a / 12.0;
        }
        // -expm1(-a) == 1 - exp(-a), without cancellation for small |a| and without overflow.
        return 1.0 / -Math.expm1(-a) - 1.0 / a;
    }

    /**
     * Evaluates the inverse CDF of the density proportional to {@code exp(a * v)} on
     * {@code [0, 1]}, {@code log(1 + u * (exp(a) - 1)) / a}, at the uniform variate {@code u}.
     * The result is clamped to {@code [0, 1]}.
     */
    private static double truncatedExpSample(double a, double u) {
        double v;
        if (Math.abs(a) < SMALL_ACTIVATION) {
            // Series around a = 0; at a = 0 the distribution is uniform, so v == u.
            v = u + a * u * (1.0 - u) / 2.0;
        } else if (a < 0.0) {
            // expm1(a) lies in (-1, 0): no overflow, and log1p avoids cancellation.
            v = Math.log1p(u * Math.expm1(a)) / a;
        } else {
            // Uses 1 + u*(exp(a) - 1) == exp(a) * (1 + (1 - u) * expm1(-a)) so exp(a) never overflows.
            v = 1.0 + Math.log1p((1.0 - u) * Math.expm1(-a)) / a;
        }
        // Absorb round-off at the ends (e.g. u == 0 with huge a yields -Infinity before clamping).
        return Math.min(1.0, Math.max(0.0, v));
    }

    /** Builder that produces {@link CRBM} instances. */
    public static class Builder
    extends BaseNeuralNetwork.Builder<CRBM> {
        /** Creates a builder whose target class is {@link CRBM}. */
        public Builder() {
            this.clazz = CRBM.class;
        }
    }
}

