/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize;

import cc.mallet.optimize.Optimizable;
import java.io.Serializable;
import org.deeplearning4j.nn.LogisticRegression;
import org.deeplearning4j.nn.gradient.LogisticRegressionGradient;

public class LogisticRegressionOptimizer
implements Optimizable.ByGradientValue,
Serializable {
    private static final long serialVersionUID = 5229426347154854746L;
    private LogisticRegression logReg;
    private double lr;

    public LogisticRegressionOptimizer(LogisticRegression logReg, double lr) {
        this.logReg = logReg;
        this.lr = lr;
    }

    public int getNumParameters() {
        return this.logReg.getW().length + this.logReg.getB().length;
    }

    public void getParameters(double[] buffer) {
        for (int i = 0; i < buffer.length; ++i) {
            buffer[i] = this.getParameter(i);
        }
    }

    public double getParameter(int index) {
        if (index >= this.logReg.getW().length) {
            return this.logReg.getB().get(index - this.logReg.getW().length);
        }
        return this.logReg.getW().get(index);
    }

    public void setParameters(double[] params) {
        for (int i = 0; i < params.length; ++i) {
            this.setParameter(i, params[i]);
        }
    }

    public void setParameter(int index, double value) {
        if (index >= this.logReg.getW().length) {
            this.logReg.getB().put(index - this.logReg.getW().length, value);
        } else {
            this.logReg.getW().put(index, value);
        }
    }

    public void getValueGradient(double[] buffer) {
        LogisticRegressionGradient grad = this.logReg.getGradient(this.lr);
        for (int i = 0; i < buffer.length; ++i) {
            buffer[i] = i < this.logReg.getW().length ? grad.getwGradient().get(i) : grad.getbGradient().get(i - this.logReg.getW().length);
        }
    }

    public double getValue() {
        return -this.logReg.negativeLogLikelihood();
    }
}

