/*
 * Decompiled with CFR 0.152.
 */
package hex.optimization;

import hex.optimization.OptimizationUtils;
import java.util.Arrays;
import java.util.Random;
import water.Iced;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.MathUtils;

/**
 * Limited-memory BFGS (L-BFGS) minimizer.
 *
 * <p>Each iteration builds a search direction from the stored curvature history
 * ({@link History#getSearchDirection}), runs a More-Thuente line search along it, and records the
 * new correction pair in the history. The loop stops when any of these holds:
 * <ul>
 *   <li>the L-infinity norm of the gradient is at most {@code _gradEps};</li>
 *   <li>the relative objective improvement of the last step is at most {@code _objEps};</li>
 *   <li>the starting coefficient array contains NaN/Inf;</li>
 *   <li>the line search reports failure;</li>
 *   <li>the progress monitor asks to stop;</li>
 *   <li>the iteration counter reaches {@code _maxIter}.</li>
 * </ul>
 *
 * <p>The history is created lazily on the first {@code solve} call and reused by later calls on
 * the same instance. Because of that, {@link #setHistorySz} has an effect only before the first
 * solve.
 */
public final class L_BFGS
extends Iced {
    /** Iteration limit; the loop stops when the iteration counter equals this value. */
    int _maxIter = 500;
    /** Convergence threshold on the L-infinity norm of the gradient. */
    double _gradEps = 1.0E-8;
    /** Convergence threshold on the relative objective improvement of a single step. */
    double _objEps = 1.0E-10;
    /** Number of correction pairs (s, y) kept in the history. */
    int _historySz = 20;
    /** Curvature history; created on the first call to solve(). */
    History _hist;

    /** Sets the iteration limit and returns this solver for chaining. */
    public L_BFGS setMaxIter(int m) {
        this._maxIter = m;
        return this;
    }

    /** Sets the gradient-norm convergence threshold and returns this solver for chaining. */
    public L_BFGS setGradEps(double d) {
        this._gradEps = d;
        return this;
    }

    /** Sets the relative-improvement convergence threshold and returns this solver for chaining. */
    public L_BFGS setObjEps(double d) {
        this._objEps = d;
        return this;
    }

    /**
     * Sets the history size and returns this solver for chaining.
     * It only takes effect if no history has been created yet.
     */
    public L_BFGS setHistorySz(int sz) {
        this._historySz = sz;
        return this;
    }

    /**
     * Returns the number of history updates performed so far.
     * Returns 0 if {@code solve} has not created a history yet.
     */
    public int k() {
        return this._hist == null ? 0 : this._hist._k;
    }

    /** Returns the configured iteration limit. */
    public int maxIter() {
        return this._maxIter;
    }

    /**
     * Minimizes the objective exposed by {@code gslvr}, starting from {@code beta}.
     *
     * @param gslvr objective/gradient evaluator
     * @param beta  starting coefficients; also checked for NaN/Inf in the loop condition
     * @param ginfo objective value and gradient at {@code beta}
     * @param pm    called after every accepted step; returning {@code false} stops the solver
     * @return the final coefficients and gradient info, the number of accepted steps, the last
     *         relative improvement, and whether a convergence criterion was met
     */
    public final Result solve(OptimizationUtils.GradientSolver gslvr, double[] beta, OptimizationUtils.GradientInfo ginfo, ProgressMonitor pm) {
        if (this._hist == null) {
            this._hist = new History(this._historySz, beta.length);
        }
        // Lower bound on the initial step handed to the next line search, so that a vanishing
        // accepted step does not seed the following search with (effectively) zero.
        final double minStep = 1.0E-16;
        double rel_improvement = 1.0;
        double[] pk = new double[beta.length];
        OptimizationUtils.MoreThuente lineSearch = new OptimizationUtils.MoreThuente(gslvr, beta, ginfo);
        int iter = 0;
        while (!ArrayUtils.hasNaNsOrInfs(beta)
                && ArrayUtils.linfnorm(ginfo._gradient, false) > this._gradEps
                && rel_improvement > this._objEps
                && iter != this._maxIter) {
            this._hist.getSearchDirection(ginfo._gradient, pk);
            if (!lineSearch.evaluate(pk)) {
                // Line search reported failure; stop at the last accepted iterate.
                break;
            }
            lineSearch.setInitialStep(Math.max(minStep, lineSearch.step()));
            OptimizationUtils.GradientInfo newGinfo = lineSearch.ginfo();
            this._hist.update(pk, newGinfo._gradient, ginfo._gradient);
            // NOTE(review): divides by |f_old|; an objective value of exactly 0 makes this Inf/NaN.
            rel_improvement = (ginfo._objVal - newGinfo._objVal) / Math.abs(ginfo._objVal);
            ginfo = newGinfo;
            // Count the step as soon as it has been applied, so a stop requested by the
            // progress monitor below still reports it.
            ++iter;
            if (!pm.progress(lineSearch.getX(), ginfo)) {
                break;
            }
        }
        boolean converged = ArrayUtils.linfnorm(ginfo._gradient, false) <= this._gradEps || rel_improvement <= this._objEps;
        return new Result(converged, iter, lineSearch.getX(), lineSearch.ginfo(), rel_improvement);
    }

    /**
     * Minimizes starting from {@code coefs}. The starting gradient info is taken from
     * {@code gslvr.getGradient(coefs)}, and the progress monitor never asks to stop.
     */
    public final Result solve(OptimizationUtils.GradientSolver gslvr, double[] coefs) {
        return this.solve(gslvr, coefs, gslvr.getGradient(coefs), new ProgressMonitor(){

            @Override
            public boolean progress(double[] beta, OptimizationUtils.GradientInfo ginfo) {
                return true;
            }
        });
    }

    /**
     * Returns {@code n} coefficients drawn from {@link Random#nextGaussian()}.
     * The generator is seeded with {@code seed}, so the output is reproducible.
     */
    public static double[] startCoefs(int n, long seed) {
        double[] res = MemoryManager.malloc8d(n);
        Random r = new Random(seed);
        for (int i = 0; i < res.length; ++i) {
            res[i] = r.nextGaussian();
        }
        return res;
    }

    /**
     * Ring buffer of the last {@code m} correction pairs (s, y) and their {@code rho = 1 / (s'y)}.
     * The next slot to overwrite is {@code _k % _m}.
     */
    public static final class History
    extends Iced {
        private final double[][] _s;
        private final double[][] _y;
        private final double[] _rho;
        /** Scratch space for the two-loop recursion, indexed by ring slot. */
        private final double[] _alpha;
        /** Capacity (number of stored pairs). */
        final int _m;
        /** Length of each stored vector. */
        final int _n;
        /** Total number of updates performed; also selects the next slot. */
        int _k;

        public History(int m, int n) {
            this._m = m;
            this._alpha = MemoryManager.malloc8d(this._m);
            this._n = n;
            this._s = new double[m][];
            this._y = new double[m][];
            this._rho = MemoryManager.malloc8d(m);
            // NaN marks slots that have never been written.
            Arrays.fill(this._rho, Double.NaN);
            for (int i = 0; i < m; ++i) {
                this._s[i] = MemoryManager.malloc8d(n);
                Arrays.fill(this._s[i], Double.NaN);
                this._y[i] = MemoryManager.malloc8d(n);
                Arrays.fill(this._y[i], Double.NaN);
            }
        }

        /**
         * Maps an offset relative to the current position to a ring slot.
         * Offset 0 is the next slot to write; offset -1 is the most recent pair.
         */
        int getId(int k) {
            return (this._k + k) % this._m;
        }

        /**
         * Stores a new pair: {@code s = pk} and {@code y = gNew - gOld}.
         * NOTE(review): L-BFGS expects s = x_new - x_old, so this assumes the line search leaves
         * {@code pk} holding the accepted step — TODO confirm against MoreThuente.
         * If {@code s'y == 0}, rho becomes infinite; no curvature check is done here.
         */
        private final void update(double[] pk, double[] gNew, double[] gOld) {
            int id = this.getId(0);
            double[] y = this._y[id];
            double[] s = this._s[id];
            for (int i = 0; i < gNew.length; ++i) {
                y[i] = gNew[i] - gOld[i];
            }
            System.arraycopy(pk, 0, s, 0, pk.length);
            this._rho[id] = 1.0 / ArrayUtils.innerProduct(s, y);
            ++this._k;
        }

        /**
         * Computes the search direction {@code -H_k * gradient} into {@code q} using the L-BFGS
         * two-loop recursion, and returns {@code q}.
         * With an empty history the result is simply {@code -gradient}.
         */
        protected final double[] getSearchDirection(double[] gradient, double[] q) {
            System.arraycopy(gradient, 0, q, 0, q.length);
            if (this._k != 0) {
                int k = Math.min(this._k, this._m);
                // First loop: newest to oldest pair.
                for (int i = 1; i <= k; ++i) {
                    int id = this.getId(-i);
                    this._alpha[id] = this._rho[id] * ArrayUtils.innerProduct(this._s[id], q);
                    MathUtils.wadd(q, this._y[id], -this._alpha[id]);
                }
                // Initial Hessian scaling -(s'y)/(y'y) from the most recent pair. The minus sign
                // is folded in here, so the second loop works on the negated vector.
                int lastId = this.getId(-1);
                double[] y = this._y[lastId];
                double Hk0 = -1.0 / (ArrayUtils.innerProduct(y, y) * this._rho[lastId]);
                ArrayUtils.mult(q, Hk0);
                // Second loop: oldest to newest pair. The signs account for q already being negated.
                for (int i = k; i > 0; --i) {
                    int id = this.getId(-i);
                    double beta = this._rho[id] * ArrayUtils.innerProduct(this._y[id], q);
                    MathUtils.wadd(q, this._s[id], -this._alpha[id] - beta);
                }
            } else {
                ArrayUtils.mult(q, -1.0);
            }
            return q;
        }
    }

    /** Outcome of a {@link L_BFGS#solve} call. */
    public static final class Result {
        /** Number of accepted steps. */
        public final int iter;
        /** Final coefficients, as reported by the line search. */
        public final double[] coefs;
        /** Objective value and gradient at {@link #coefs}. */
        public final OptimizationUtils.GradientInfo ginfo;
        /** True if the gradient-norm or relative-improvement criterion was met. */
        public final boolean converged;
        /** Relative objective improvement of the last accepted step (1.0 if no step was taken). */
        public final double rel_improvement;

        public Result(boolean converged, int iter2, double[] coefs, OptimizationUtils.GradientInfo ginfo, double rel_improvement) {
            this.iter = iter2;
            this.coefs = coefs;
            this.ginfo = ginfo;
            this.converged = converged;
            this.rel_improvement = rel_improvement;
        }

        /**
         * Summarizes the result.
         * Coefficients and gradient are printed in full only when there are fewer than 10 of
         * them; otherwise only the gradient's L-infinity norm is printed.
         */
        @Override
        public String toString() {
            if (this.coefs.length < 10) {
                return "L-BFGS_res(converged? " + this.converged + ", iter = " + this.iter + ", obj = " + this.ginfo._objVal + ", rel_improvement = " + this.rel_improvement + ", coefs = " + Arrays.toString(this.coefs) + ", grad = " + Arrays.toString(this.ginfo._gradient) + ")";
            }
            return "L-BFGS_res(converged? " + this.converged + ", iter = " + this.iter + ", obj = " + this.ginfo._objVal + ", rel_improvement = " + this.rel_improvement + ", grad_linf_norm = " + ArrayUtils.linfnorm(this.ginfo._gradient, false) + ")";
        }
    }

    /** Callback invoked after every accepted step. */
    public static interface ProgressMonitor {
        /**
         * Receives the current coefficients and gradient info.
         *
         * @param beta  current coefficients
         * @param ginfo objective value and gradient at {@code beta}
         * @return {@code true} to continue iterating, {@code false} to stop the solver
         */
        public boolean progress(double[] beta, OptimizationUtils.GradientInfo ginfo);
    }
}

