/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.feedforward.rbm;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BasePretrainNetwork;
import org.deeplearning4j.util.Dropout;
import org.deeplearning4j.util.RBMUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class RBM
extends BasePretrainNetwork<org.deeplearning4j.nn.conf.layers.RBM> {
    /**
     * Random source passed to {@code Nd4j.randn(...)} wherever this class draws Gaussian noise.
     * Declared {@code transient}, so default Java serialization skips it.
     */
    private transient Random rng = Nd4j.getRandom();
    /**
     * Written by {@code propUp}, {@code fit} and {@code iterate} when the visible unit is GAUSSIAN:
     * the variance along dimension 0 of the given matrix, divided by a row count.
     */
    protected INDArray sigma;
    /**
     * Not assigned anywhere in this class except when copied onto the layer returned by
     * {@code transpose()}.
     */
    protected INDArray hiddenSigma;

    /**
     * Creates the layer from a configuration only; all work is delegated to the superclass constructor.
     *
     * @param conf network configuration handed to the superclass
     */
    public RBM(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Creates the layer from a configuration and an input matrix; all work is delegated to the
     * superclass constructor.
     *
     * @param conf  network configuration handed to the superclass
     * @param input input matrix handed to the superclass
     */
    public RBM(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Applies one parameter update: obtains the current gradient via {@code gradient()} and
     * subtracts, in place, each gradient entry from the parameter stored under the same key.
     * Keys are processed in the order {@code "bB"}, {@code "b"}, {@code "W"}.
     */
    public void contrastiveDivergence() {
        Gradient current = this.gradient();
        for (String key : new String[]{"bB", "b", "W"}) {
            INDArray param = this.getParam(key);
            param.subi(current.gradientForVariable().get(key));
        }
    }

    /**
     * Runs k steps of Gibbs sampling starting from the hidden sample of the current input,
     * then stores the resulting gradient and score on this layer.
     *
     * <p>The gradient map receives three entries: {@code "W"} (input^T * positive hidden sample
     * minus negative visible sample^T * negative hidden mean), {@code "b"} (hidden bias) and
     * {@code "bB"} (visible bias). The score is set through {@code setScoreWithZ} using the
     * difference between the input and the last negative visible sample.
     *
     * @throws IllegalStateException if the configured k is smaller than 1, because no negative
     *                               phase statistics would exist to build the gradient from
     */
    @Override
    public void computeGradientAndScore() {
        org.deeplearning4j.nn.conf.layers.RBM layerConf = (org.deeplearning4j.nn.conf.layers.RBM)this.layerConf();
        int k = layerConf.getK();
        if (k < 1) {
            // Previously this surfaced as a bare NullPointerException on nvSamples further down.
            throw new IllegalStateException("RBM requires k >= 1 Gibbs sampling steps, but k was " + k);
        }

        // Positive phase: hidden (mean, sample) pair for the actual input.
        Pair<INDArray, INDArray> probHidden = this.sampleHiddenGivenVisible(this.input());
        INDArray positiveHiddenSample = probHidden.getSecond();

        // Negative phase: each step feeds the previous hidden sample back into the chain.
        INDArray nvSamples = null;
        INDArray nhMeans = null;
        INDArray chainState = positiveHiddenSample;
        for (int i = 0; i < k; ++i) {
            Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> step = this.gibbhVh(chainState);
            nvSamples = step.getFirst().getSecond();
            nhMeans = step.getSecond().getFirst();
            chainState = step.getSecond().getSecond();
        }

        // NOTE(review): transposei() is ND4J's in-place transpose variant; presumably the stored
        // input is still usable for the input.sub(...) below - confirm against the ND4J version in use.
        INDArray wGradient = this.input().transposei().mmul(positiveHiddenSample)
                .subi(nvSamples.transpose().mmul(nhMeans));

        INDArray hBiasGradient;
        if (layerConf.getSparsity() != 0.0) {
            // Sparsity configured: gradient is (sparsity - hidden sample), summed along dimension 0.
            hBiasGradient = positiveHiddenSample.rsub(layerConf.getSparsity()).sum(new int[]{0});
        } else {
            hBiasGradient = positiveHiddenSample.sub(nhMeans).sum(new int[]{0});
        }

        INDArray delta = this.input.sub(nvSamples);
        INDArray vBiasGradient = delta.sum(new int[]{0});

        DefaultGradient ret = new DefaultGradient();
        ret.gradientForVariable().put("bB", vBiasGradient);
        ret.gradientForVariable().put("b", hBiasGradient);
        ret.gradientForVariable().put("W", wGradient);
        this.gradient = ret;
        this.setScoreWithZ(delta);
    }

    /**
     * Builds the transposed layer: starts from {@code super.transpose()}, swaps the unit types
     * and swaps the two bias vectors.
     *
     * <p>The new hidden unit is {@code RBMUtil.inverse} of this layer's visible unit and the new
     * visible unit is {@code RBMUtil.inverse} of this layer's hidden unit; when {@code inverse}
     * returns {@code null} the layer's own hidden/visible unit is kept instead. The returned
     * layer's {@code "bB"} is a copy of this layer's {@code "b"} and vice versa. {@code sigma}
     * and {@code hiddenSigma} are shared by reference, not copied.
     *
     * @return the transposed layer
     */
    @Override
    public Layer transpose() {
        RBM r = (RBM)super.transpose();
        org.deeplearning4j.nn.conf.layers.RBM sourceConf = (org.deeplearning4j.nn.conf.layers.RBM)this.layerConf();

        // The enum types are fully qualified because inside this class the simple name RBM
        // denotes the layer itself, not the configuration class that declares the enums.
        org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit h = RBMUtil.inverse(sourceConf.getVisibleUnit());
        org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit v = RBMUtil.inverse(sourceConf.getHiddenUnit());
        if (h == null) {
            h = sourceConf.getHiddenUnit();
        }
        if (v == null) {
            v = sourceConf.getVisibleUnit();
        }

        org.deeplearning4j.nn.conf.layers.RBM targetConf = (org.deeplearning4j.nn.conf.layers.RBM)r.layerConf();
        targetConf.setHiddenUnit(h);
        targetConf.setVisibleUnit(v);

        // Hidden and visible biases trade places in the transposed layer.
        INDArray hiddenBiasCopy = this.getParam("b").dup();
        INDArray visibleBiasCopy = this.getParam("bB").dup();
        r.setParam("bB", hiddenBiasCopy);
        r.setParam("b", visibleBiasCopy);

        r.sigma = this.sigma;
        r.hiddenSigma = this.hiddenSigma;
        return r;
    }

    /**
     * Computes the hidden activation for {@code v} via {@code propUp} and derives a hidden
     * "sample" from it according to the configured hidden unit type.
     *
     * <ul>
     *   <li>RECTIFIED: a draw from {@code createNormal(mean, 1.0)} is scaled by
     *       {@code sqrt(sigmoid(mean))}, added to the mean and clamped at 0.</li>
     *   <li>GAUSSIAN: the mean plus {@code Nd4j.randn} noise of the same rows/columns.</li>
     *   <li>SOFTMAX: the softmax transform of the mean (no random draw).</li>
     *   <li>BINARY: a draw from {@code createBinomial(1, mean)}.</li>
     * </ul>
     *
     * @param v visible-layer matrix
     * @return pair of (hidden mean, hidden sample)
     * @throws IllegalStateException if the hidden unit type is none of the four above
     */
    @Override
    public Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray v) {
        INDArray h1Mean = this.propUp(v);
        INDArray h1Sample;
        switch (((org.deeplearning4j.nn.conf.layers.RBM)this.layerConf()).getHiddenUnit()) {
            case RECTIFIED: {
                INDArray sqrtSigH1Mean = Transforms.sqrt(Transforms.sigmoid(h1Mean));
                INDArray noise = Nd4j.getDistributions().createNormal(h1Mean, 1.0).sample(h1Mean.shape());
                noise.muli(sqrtSigH1Mean);
                h1Sample = Transforms.max(h1Mean.add(noise), 0.0);
                break;
            }
            case GAUSSIAN: {
                h1Sample = h1Mean.add(Nd4j.randn(h1Mean.rows(), h1Mean.columns(), this.rng));
                break;
            }
            case SOFTMAX: {
                h1Sample = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", h1Mean));
                break;
            }
            case BINARY: {
                h1Sample = Nd4j.getDistributions().createBinomial(1, h1Mean).sample(h1Mean.shape());
                break;
            }
            default: {
                // Message lists every type the switch actually handles.
                throw new IllegalStateException("Hidden unit type must be one of rectified linear, gaussian, softmax or binary");
            }
        }
        return new Pair<INDArray, INDArray>(h1Mean, h1Sample);
    }

    /**
     * Performs one hidden -> visible -> hidden Gibbs step.
     *
     * @param h hidden-layer matrix the step starts from
     * @return pair whose first element is the visible (mean, sample) pair produced from {@code h}
     *         and whose second element is the hidden (mean, sample) pair produced from that
     *         visible sample
     */
    public Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> gibbhVh(INDArray h) {
        Pair<INDArray, INDArray> visible = this.sampleVisibleGivenHidden(h);
        // The hidden side is re-sampled from the visible sample, not from the visible mean.
        Pair<INDArray, INDArray> hidden = this.sampleHiddenGivenVisible(visible.getSecond());
        return new Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>(visible, hidden);
    }

    /**
     * Computes the visible activation for {@code h} via {@code propDown} and derives a visible
     * "sample" from it according to the configured visible unit type.
     *
     * @param h hidden-layer matrix
     * @return pair of (visible mean, visible sample)
     * @throws IllegalStateException if the visible unit type is not GAUSSIAN, LINEAR, SOFTMAX or BINARY
     */
    @Override
    public Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray h) {
        INDArray mean = this.propDown(h);
        INDArray sample;
        switch (((org.deeplearning4j.nn.conf.layers.RBM)this.layerConf()).getVisibleUnit()) {
            case GAUSSIAN:
                // Mean plus randn noise with the same rows/columns.
                sample = mean.add(Nd4j.randn(mean.rows(), mean.columns(), this.rng));
                break;
            case LINEAR:
                sample = Nd4j.getDistributions().createNormal(mean, 1.0).sample(mean.shape());
                break;
            case SOFTMAX:
                // No random draw: the "sample" is the softmax transform of the mean.
                sample = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", mean));
                break;
            case BINARY:
                sample = Nd4j.getDistributions().createBinomial(1, mean).sample(mean.shape());
                break;
            default:
                throw new IllegalStateException("Visible type must be one of Binary, Gaussian, SoftMax or Linear");
        }
        return new Pair<INDArray, INDArray>(mean, sample);
    }

    /**
     * Propagates a visible-layer matrix up to the hidden layer.
     *
     * <p>Computes {@code v * W + b} (bias added row-wise) and then applies the hidden unit type:
     * RECTIFIED clamps at 0, GAUSSIAN adds {@code Nd4j.randn} noise in place, BINARY applies a
     * sigmoid and SOFTMAX applies the softmax transform. As a side effect, when the visible unit
     * is GAUSSIAN, {@code sigma} is refreshed from {@code v}.
     *
     * @param v        visible-layer matrix
     * @param training when true, and drop-connect is enabled with a dropout value above 0, the
     *                 weights are replaced by {@code Dropout.applyDropConnect(this, "W")}
     * @return the hidden activation
     * @throws IllegalStateException if the hidden unit type is none of the four handled types
     */
    public INDArray propUp(INDArray v, boolean training) {
        org.deeplearning4j.nn.conf.layers.RBM layerConf = (org.deeplearning4j.nn.conf.layers.RBM)this.layerConf();

        INDArray W = this.getParam("W");
        if (training && this.conf.isUseDropConnect() && this.conf.getLayer().getDropOut() > 0.0) {
            W = Dropout.applyDropConnect(this, "W");
        }
        INDArray hBias = this.getParam("b");

        // Enum is fully qualified: inside this class the simple name RBM denotes the layer itself.
        if (layerConf.getVisibleUnit() == org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN) {
            // Normalise by the row count of the matrix the variance was taken from (as fit/iterate do),
            // rather than by this.input, which may be unset or have a different number of rows.
            this.sigma = v.var(new int[]{0}).divi(v.rows());
        }

        INDArray preSig = v.mmul(W).addiRowVector(hBias);
        switch (layerConf.getHiddenUnit()) {
            case RECTIFIED: {
                return Transforms.max(preSig, 0.0);
            }
            case GAUSSIAN: {
                preSig.addi(Nd4j.randn(preSig.rows(), preSig.columns(), this.rng));
                return preSig;
            }
            case BINARY: {
                return Transforms.sigmoid(preSig);
            }
            case SOFTMAX: {
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", preSig));
            }
        }
        throw new IllegalStateException("Hidden unit type should be one of binary, gaussian, rectified linear or softmax");
    }

    /**
     * Convenience overload: propagates {@code v} up with the training flag set to {@code true}.
     *
     * @param v visible-layer matrix
     * @return the result of {@code propUp(v, true)}
     */
    public INDArray propUp(INDArray v) {
        return this.propUp(v, true);
    }

    /**
     * Propagates a hidden-layer matrix down to the visible layer.
     *
     * <p>Computes {@code h * W^T + bB} (bias added row-wise) and then applies the visible unit
     * type: GAUSSIAN adds, in place, a draw from {@code createNormal(vMean, 1.0)}; LINEAR returns
     * the value unchanged; BINARY applies a sigmoid; SOFTMAX applies the softmax transform.
     *
     * @param h hidden-layer matrix
     * @return the visible activation
     * @throws IllegalStateException if the visible unit type is none of the four handled types
     */
    public INDArray propDown(INDArray h) {
        INDArray W = this.getParam("W").transpose();
        INDArray vBias = this.getParam("bB");
        INDArray vMean = h.mmul(W).addiRowVector(vBias);
        switch (((org.deeplearning4j.nn.conf.layers.RBM)this.layerConf()).getVisibleUnit()) {
            case GAUSSIAN: {
                INDArray sample = Nd4j.getDistributions().createNormal(vMean, 1.0).sample(vMean.shape());
                vMean.addi(sample);
                return vMean;
            }
            case LINEAR: {
                return vMean;
            }
            case BINARY: {
                return Transforms.sigmoid(vMean);
            }
            case SOFTMAX: {
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", vMean));
            }
        }
        // Message lists every type the switch actually handles.
        throw new IllegalStateException("Visible unit type should be one of binary, gaussian, linear or softmax");
    }

    /**
     * Activates the layer on its stored input by propagating it up to the hidden layer.
     *
     * @param training when true and the configured dropout value is above 0,
     *                 {@code Dropout.applyDropout} is invoked on the stored input first; the flag
     *                 is also forwarded to {@code propUp}
     * @return the hidden activation of the stored input
     */
    @Override
    public INDArray activate(boolean training) {
        boolean dropInput = training && this.conf.getLayer().getDropOut() > 0.0;
        if (dropInput) {
            Dropout.applyDropout(this.input, this.conf.getLayer().getDropOut());
        }
        return this.propUp(this.input, training);
    }

    /**
     * Fits the layer on {@code input} by delegating to {@code super.fit(input)}. When the visible
     * unit is GAUSSIAN, {@code sigma} is first set to the variance of {@code input} along
     * dimension 0 divided by its row count.
     *
     * @param input training matrix
     */
    @Override
    public void fit(INDArray input) {
        // Enum is fully qualified: inside this class the simple name RBM denotes the layer itself,
        // not the configuration class that declares VisibleUnit.
        org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit visibleUnit =
                ((org.deeplearning4j.nn.conf.layers.RBM)this.layerConf()).getVisibleUnit();
        if (visibleUnit == org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN) {
            this.sigma = input.var(new int[]{0}).divi(input.rows());
        }
        super.fit(input);
    }

    /**
     * Runs one training iteration on {@code input}: stores a copy of it as the layer input,
     * applies dropout via {@code applyDropOutIfNecessary(true)} and performs one
     * {@code contrastiveDivergence()} update. When the visible unit is GAUSSIAN, {@code sigma} is
     * first set to the variance of {@code input} along dimension 0 divided by its row count.
     *
     * @param input training matrix; it is duplicated, so the caller's array is not the one stored
     */
    @Override
    public void iterate(INDArray input) {
        // Enum is fully qualified: inside this class the simple name RBM denotes the layer itself,
        // not the configuration class that declares VisibleUnit.
        org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit visibleUnit =
                ((org.deeplearning4j.nn.conf.layers.RBM)this.layerConf()).getVisibleUnit();
        if (visibleUnit == org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN) {
            this.sigma = input.var(new int[]{0}).divi(input.rows());
        }
        this.input = input.dup();
        this.applyDropOutIfNecessary(true);
        this.contrastiveDivergence();
    }
}

