/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rbm;

import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.NeuralNetwork;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.optimize.NeuralNetworkOptimizer;
import org.jblas.DoubleMatrix;

/**
 * Optimizer that asks the wrapped network for its gradient using a step
 * count {@code k} which grows as this optimizer is called repeatedly, and
 * flattens the returned gradient into a caller-supplied array.
 *
 * <p>{@code k} is forwarded as the first element of the parameter array
 * given to {@code network.getGradient}; presumably it is the number of
 * sampling steps (CD-k) for the RBM -- confirm against the network
 * implementation.
 */
public class RBMOptimizer
extends NeuralNetworkOptimizer {
    private static final long serialVersionUID = 3676032651650426749L;

    /** {@code k} grows by one whenever the call counter is a multiple of this value. */
    private static final int CALLS_PER_K_INCREMENT = 10;

    /** Ceiling applied to {@code k} on every call. */
    private static final int MAX_K = 15;

    /** Current step count; any value {@code <= 0} means "not yet seeded from extraParams". */
    protected int k = -1;

    /** Number of times {@link #getValueGradient(double[])} has been entered. */
    protected int numTimesIterated = 0;

    /**
     * Forwards every argument unchanged to the superclass constructor.
     *
     * @param network               passed through to the superclass
     * @param lr                    passed through to the superclass
     * @param trainingParams        passed through to the superclass
     * @param optimizationAlgorithm passed through to the superclass
     * @param lossFunction          passed through to the superclass
     */
    public RBMOptimizer(BaseNeuralNetwork network, double lr, Object[] trainingParams, NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm, NeuralNetwork.LossFunction lossFunction) {
        super(network, lr, trainingParams, optimizationAlgorithm, lossFunction);
    }

    /**
     * Updates {@code k}, requests the gradient from the network, and writes it
     * into {@code buffer} as one flat sequence: weight gradient first, then
     * visible-bias gradient, then hidden-bias gradient.
     *
     * <p>Schedule for {@code k} on each call:
     * <ol>
     *   <li>if {@code k <= 0}, it is seeded from {@code extraParams[0]};</li>
     *   <li>if the call counter is a multiple of 10, {@code k} is incremented;</li>
     *   <li>{@code k} is capped at 15.</li>
     * </ol>
     *
     * @param buffer destination array; it is written from index 0 and must hold
     *               at least the combined length of the three gradient matrices
     */
    @Override
    public void getValueGradient(double[] buffer) {
        // Read (and unbox) on every call, even though it is only used while k is unseeded.
        final int configuredK = (Integer) this.extraParams[0];
        this.numTimesIterated++;

        if (this.k <= 0) {
            this.k = configuredK;
        }
        if (this.numTimesIterated % CALLS_PER_K_INCREMENT == 0) {
            this.k++;
        }
        this.k = Math.min(this.k, MAX_K);

        final NeuralNetworkGradient grad = this.network.getGradient(new Object[]{this.k, this.lr});
        final DoubleMatrix weightPart = grad.getwGradient();
        final DoubleMatrix visibleBiasPart = grad.getvBiasGradient();
        final DoubleMatrix hiddenBiasPart = grad.gethBiasGradient();

        int offset = flattenInto(weightPart, buffer, 0);
        offset = flattenInto(visibleBiasPart, buffer, offset);
        flattenInto(hiddenBiasPart, buffer, offset);
    }

    /**
     * Copies every element of {@code source}, in linear-index order, into
     * {@code target} starting at {@code offset}.
     *
     * @return the index just past the last element written
     */
    private static int flattenInto(DoubleMatrix source, double[] target, int offset) {
        int cursor = offset;
        for (int pos = 0; pos < source.length; pos++) {
            target[cursor++] = source.get(pos);
        }
        return cursor;
    }
}

