/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize;

import cc.mallet.optimize.Optimizable;
import java.io.Serializable;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Adapts the output (logistic regression) layer of a {@link BaseMultiLayerNetwork} to MALLET's
 * {@link Optimizable.ByGradientValue} interface, and drives fine-tuning of that layer.
 *
 * <p>The flat parameter vector exposed through this interface is the log layer's weight matrix
 * {@code W} (in jblas linear-index order, i.e. {@code W.get(i)}) followed by its bias vector
 * {@code b}. Indices {@code [0, W.length)} address {@code W}; indices
 * {@code [W.length, W.length + b.length)} address {@code b}.
 */
public class MultiLayerNetworkOptimizer implements Optimizable.ByGradientValue, Serializable {
    private static final long serialVersionUID = -3012638773299331828L;
    private static final Logger log = LoggerFactory.getLogger(MultiLayerNetworkOptimizer.class);

    protected BaseMultiLayerNetwork network;
    /** Scale factor applied to the gradients produced by {@link #getValueGradient(double[])}. */
    private final double lr;

    /**
     * @param network the network whose log layer is optimized
     * @param lr scale factor applied to the gradients returned by {@link #getValueGradient}
     */
    public MultiLayerNetworkOptimizer(BaseMultiLayerNetwork network, double lr) {
        this.network = network;
        this.lr = lr;
    }

    /**
     * Fine-tunes the network's log layer on {@code labels}, then optionally back-propagates.
     *
     * <p>Runs a forward pass over the network's current input and samples the last sigmoid
     * layer's hidden units ({@code sample_h_given_v()}) to form the log layer's input. If the
     * network forces the epoch count, the log layer is trained for exactly {@code epochs}
     * iterations. Otherwise {@code trainTillConvergence(lr, epochs)} is used. If the network is
     * configured for back-propagation, {@code backProp(lr, epochs)} runs afterwards.
     *
     * @param labels target matrix; must have the same number of rows as the sampled layer input
     * @param lr learning rate passed to the log layer and to back-propagation (this argument,
     *     not the constructor's {@code lr}, is used here)
     * @param epochs epoch count passed to the log layer training and to back-propagation
     * @throws IllegalStateException if the label row count differs from the layer input row count
     */
    public void optimize(DoubleMatrix labels, double lr, int epochs) {
        this.network.feedForward(this.network.getInput());
        DoubleMatrix layerInput =
                this.network.getSigmoidLayers()[this.network.getSigmoidLayers().length - 1].sample_h_given_v();
        // Validate before touching the log layer so a mismatch leaves its input/labels unchanged.
        if (layerInput.rows != labels.rows) {
            throw new IllegalStateException("Labels not equal to input: " + labels.rows
                    + " label rows vs " + layerInput.rows + " input rows");
        }
        this.network.getLogLayer().setInput(layerInput);
        this.network.getLogLayer().setLabels(labels);
        if (this.network.isForceNumEpochs()) {
            log.info("Training for {} epochs", epochs);
            for (int epoch = 0; epoch < epochs; epoch++) {
                this.network.getLogLayer().train(layerInput, labels, lr);
            }
        } else {
            this.network.getLogLayer().trainTillConvergence(lr, epochs);
        }
        if (this.network.isShouldBackProp()) {
            this.network.backProp(lr, epochs);
        }
    }

    /** Returns the number of log-layer parameters: {@code W.length + b.length}. */
    public int getNumParameters() {
        return this.network.getLogLayer().getW().length + this.network.getLogLayer().getB().length;
    }

    /**
     * Copies the log layer's parameters into {@code buffer}: {@code W} first, then {@code b}.
     *
     * @param buffer destination; must hold at least {@link #getNumParameters()} entries
     */
    public void getParameters(double[] buffer) {
        DoubleMatrix weights = this.network.getLogLayer().getW();
        DoubleMatrix bias = this.network.getLogLayer().getB();
        int offset = copyOut(weights, buffer, 0);
        copyOut(bias, buffer, offset);
    }

    /**
     * Returns the parameter at flat {@code index} (see the class documentation for the layout).
     *
     * @param index flat parameter index in {@code [0, getNumParameters())}
     * @return the corresponding entry of {@code W} or {@code b}
     */
    public double getParameter(int index) {
        DoubleMatrix weights = this.network.getLogLayer().getW();
        if (index >= weights.length) {
            // Bias entries follow all weight entries, so offset by the weight count.
            return this.network.getLogLayer().getB().get(index - weights.length);
        }
        return weights.get(index);
    }

    /**
     * Overwrites the log layer's parameters in place from {@code params}: {@code W} first, then
     * {@code b}.
     *
     * @param params source; must hold at least {@link #getNumParameters()} entries
     */
    public void setParameters(double[] params) {
        DoubleMatrix weights = this.network.getLogLayer().getW();
        DoubleMatrix bias = this.network.getLogLayer().getB();
        int offset = copyIn(params, 0, weights);
        copyIn(params, offset, bias);
    }

    /**
     * Sets the parameter at flat {@code index} (see the class documentation for the layout).
     *
     * @param index flat parameter index in {@code [0, getNumParameters())}
     * @param value new value
     */
    public void setParameter(int index, double value) {
        DoubleMatrix weights = this.network.getLogLayer().getW();
        if (index >= weights.length) {
            // Bias entries follow all weight entries, so offset by the weight count.
            this.network.getLogLayer().getB().put(index - weights.length, value);
        } else {
            weights.put(index, value);
        }
    }

    /**
     * Writes the log layer's gradient into {@code buffer}, laid out like {@link #getParameters}.
     *
     * <p>Computes {@code p = softmax(input * W + b)} and {@code dy = labels - p}. The weight
     * gradient is {@code input^T * dy * lr} and the bias gradient is {@code columnMeans(dy) * lr}.
     *
     * <p>NOTE(review): the weight term is summed over rows while the bias term is averaged, and
     * both are scaled by {@code lr}. MALLET optimizers maximize {@link #getValue()}, which returns
     * {@code negativeLogLikelihood()}. Confirm that this sign and scaling convention is intended.
     *
     * @param buffer destination; must hold at least {@link #getNumParameters()} entries
     */
    public void getValueGradient(double[] buffer) {
        DoubleMatrix input = this.network.getLogLayer().getInput();
        DoubleMatrix weights = this.network.getLogLayer().getW();
        DoubleMatrix bias = this.network.getLogLayer().getB();
        DoubleMatrix pYGivenX = MatrixUtil.softmax(input.mmul(weights).addRowVector(bias));
        DoubleMatrix dy = this.network.getLogLayer().getLabels().sub(pYGivenX);
        DoubleMatrix weightGradient = input.transpose().mmul(dy).mul(this.lr);
        DoubleMatrix biasGradient = dy.columnMeans().mul(this.lr);
        int offset = copyOut(weightGradient, buffer, 0);
        copyOut(biasGradient, buffer, offset);
    }

    /** Returns the network's {@code negativeLogLikelihood()}. */
    public double getValue() {
        return this.network.negativeLogLikelihood();
    }

    /** Copies every entry of {@code source} (linear order) into {@code buffer} at {@code offset}; returns the next free offset. */
    private static int copyOut(DoubleMatrix source, double[] buffer, int offset) {
        for (int i = 0; i < source.length; i++) {
            buffer[offset + i] = source.get(i);
        }
        return offset + source.length;
    }

    /** Fills {@code target} (linear order) from {@code values} starting at {@code offset}; returns the next unread offset. */
    private static int copyIn(double[] values, int offset, DoubleMatrix target) {
        for (int i = 0; i < target.length; i++) {
            target.put(i, values[offset + i]);
        }
        return offset + target.length;
    }
}

