/*
 * Decompiled with CFR 0.152.
 */
package hex.optimization;

import hex.glm.ComputationState;
import hex.glm.ConstrainedGLMUtils;
import hex.glm.GLM;
import java.util.Arrays;
import java.util.List;
import water.Iced;
import water.util.ArrayUtils;
import water.util.Log;

public class OptimizationUtils {

    /**
     * Wolfe-condition line search used by the constrained GLM solver.
     *
     * <p>The search direction is {@code betaCnd - beta}. The trial step {@code _alphai} is bisected inside the
     * bracket {@code [_alphal, _alphar]}, or grown by {@code _lambdaLS} while no upper bracket exists, until the
     * trial coefficients satisfy both Wolfe conditions.
     */
    public static final class ExactLineSearch {
        // Sufficient-decrease (first Wolfe) coefficient.
        public final double _betaLS1 = 1.0E-4;
        // Curvature (second Wolfe) coefficient.
        public final double _betaLS2 = 0.99;
        // Growth factor for the trial step while the upper bracket is still +infinity.
        public final double _lambdaLS = 2.0;
        // Lower end of the step bracket.
        public double _alphal;
        // Upper end of the step bracket; +infinity until a trial step fails only the first Wolfe condition.
        public double _alphar;
        // Current trial step.
        public double _alphai;
        // Search direction: candidate coefficients minus the starting coefficients.
        public double[] _direction;
        public int _maxIteration = 50;
        // Coefficients the search started from.
        public double[] _originalBeta;
        // Coefficients chosen by findAlpha or setBetaConstraintsDeriv.
        public double[] _newBeta;
        // Gradient info at the start point; replaced by the accepted point's info when a step is accepted.
        public GLM.GLMGradientInfo _ginfoOriginal;
        // Inner product of the starting gradient with _direction.
        public double _currGradDirIP;
        public String[] _coeffNames;

        public ExactLineSearch(double[] betaCnd, ComputationState state, List<String> coeffNames) {
            this.reset(betaCnd, state, coeffNames);
        }

        /**
         * Re-initializes the search from the state's current coefficients toward {@code betaCnd}.
         *
         * @param betaCnd    candidate coefficients defining the search direction
         * @param state      supplies the starting coefficients and gradient info
         * @param coeffNames coefficient names, used when refreshing constraint values
         */
        public void reset(double[] betaCnd, ComputationState state, List<String> coeffNames) {
            double[] dir = new double[betaCnd.length];
            this._direction = dir;
            ArrayUtils.subtract(betaCnd, state.beta(), dir);
            this._ginfoOriginal = state.ginfo();
            this._originalBeta = state.beta();
            // Start with a unit step and an unbounded bracket.
            this._alphai = 1.0;
            this._alphal = 0.0;
            this._alphar = Double.POSITIVE_INFINITY;
            this._coeffNames = coeffNames.toArray(new String[0]);
            this._currGradDirIP = ArrayUtils.innerProduct(this._ginfoOriginal._gradient, dir);
        }

        /** Sufficient decrease: f(new) <= f(start) + alphai * betaLS1 * (grad . direction). */
        public boolean evaluateFirstWolfe(GLM.GLMGradientInfo ginfoNew) {
            double armijoBound = this._ginfoOriginal._objVal + this._alphai * this._betaLS1 * this._currGradDirIP;
            return ginfoNew._objVal <= armijoBound;
        }

        /** Curvature: grad(new) . direction >= betaLS2 * (grad(start) . direction). */
        public boolean evaluateSecondWolfe(GLM.GLMGradientInfo ginfo) {
            return ArrayUtils.innerProduct(ginfo._gradient, this._direction) >= this._betaLS2 * this._currGradDirIP;
        }

        /**
         * Updates the bracket and the trial step when exactly one Wolfe condition failed.
         *
         * @return {@code true} if the step was changed, {@code false} when both conditions hold or both fail
         */
        public boolean setAlphai(boolean firstWolfe, boolean secondWolfe) {
            if (firstWolfe == secondWolfe) {
                return false;
            }
            if (secondWolfe) {
                // Only sufficient decrease failed: step too long, shrink toward the lower bracket.
                this._alphar = this._alphai;
                this._alphai = 0.5 * (this._alphal + this._alphar);
            } else {
                // Only curvature failed: step too short, bisect if bounded above, otherwise expand.
                this._alphal = this._alphai;
                if (this._alphar < Double.POSITIVE_INFINITY) {
                    this._alphai = 0.5 * (this._alphal + this._alphar);
                } else {
                    this._alphai = this._lambdaLS * this._alphai;
                }
            }
            return true;
        }

        /**
         * Accepts {@code betaCnd} directly and refreshes the constraint values and gradient information for it.
         */
        public void setBetaConstraintsDeriv(double[] lambdaEqual, double[] lambdaLessThan, ComputationState state, ConstrainedGLMUtils.LinearConstraints[] equalityConstraints, ConstrainedGLMUtils.LinearConstraints[] lessThanEqualToConstraints, GLM.GLMGradientSolver gradientSolver, double[] betaCnd) {
            this._newBeta = betaCnd;
            ConstrainedGLMUtils.updateConstraintValues(betaCnd, Arrays.asList(this._coeffNames), equalityConstraints, lessThanEqualToConstraints);
            ConstrainedGLMUtils.calculateConstraintSquare(state, equalityConstraints, lessThanEqualToConstraints);
            state.updateConstraintInfo(equalityConstraints, lessThanEqualToConstraints);
            this._ginfoOriginal = ConstrainedGLMUtils.calGradient(betaCnd, state, gradientSolver, lambdaEqual, lambdaLessThan, equalityConstraints, lessThanEqualToConstraints);
        }

        /**
         * Searches for a step satisfying both Wolfe conditions.
         *
         * @return {@code true} if such a step was found; {@code _newBeta} and {@code _ginfoOriginal} then hold the
         *         accepted point. On failure they are only updated when the last trial's squared gradient norm is
         *         within {@code state._csGLMState._epsilonkCSSquare}.
         */
        public boolean findAlpha(double[] lambdaEqual, double[] lambdaLessThan, ComputationState state, ConstrainedGLMUtils.LinearConstraints[] equalityConstraints, ConstrainedGLMUtils.LinearConstraints[] lessThanEqualToConstraints, GLM.GLMGradientSolver gradientSolver) {
            if (this._currGradDirIP > 0.0) {
                // Not a descent direction: keep the starting coefficients.
                this._newBeta = this._originalBeta;
                return false;
            }
            double[] scaledDirection = new double[this._originalBeta.length];
            for (int attempt = 0; attempt < this._maxIteration; attempt++) {
                ArrayUtils.mult(this._direction, scaledDirection, this._alphai);
                double[] trialBeta = ArrayUtils.add(scaledDirection, this._originalBeta);
                ConstrainedGLMUtils.updateConstraintValues(trialBeta, Arrays.asList(this._coeffNames), equalityConstraints, lessThanEqualToConstraints);
                ConstrainedGLMUtils.calculateConstraintSquare(state, equalityConstraints, lessThanEqualToConstraints);
                state.updateConstraintInfo(equalityConstraints, lessThanEqualToConstraints);
                GLM.GLMGradientInfo trialGinfo = ConstrainedGLMUtils.calGradient(trialBeta, state, gradientSolver, lambdaEqual, lambdaLessThan, equalityConstraints, lessThanEqualToConstraints);
                double gradNormSq = ArrayUtils.innerProduct(trialGinfo._gradient, trialGinfo._gradient);
                boolean acceptableGradient = gradNormSq <= state._csGLMState._epsilonkCSSquare;
                boolean firstWolfe = this.evaluateFirstWolfe(trialGinfo);
                boolean secondWolfe = this.evaluateSecondWolfe(trialGinfo);
                if (firstWolfe && secondWolfe) {
                    this._newBeta = trialBeta;
                    this._ginfoOriginal = trialGinfo;
                    return true;
                }
                boolean stepUpdated = this.setAlphai(firstWolfe, secondWolfe);
                if (!stepUpdated || this._alphar < 1.0E-12) {
                    // Search cannot progress (no update possible, or the bracket collapsed to ~0).
                    if (acceptableGradient) {
                        this._newBeta = trialBeta;
                        this._ginfoOriginal = trialGinfo;
                    }
                    return false;
                }
            }
            return false;
        }
    }

    /**
     * More-Thuente line search, modeled on the MINPACK {@code cvsrch}/{@code cstep} routines.
     *
     * <p>Starting at {@code _beta} with objective/gradient {@code _ginfox}, {@link #evaluate(double[])} looks for a
     * step along a descent direction satisfying the sufficient-decrease condition (scaled by {@code _ftol}) and the
     * directional-derivative condition (scaled by {@code _gtol}). On success {@code _beta} and {@code _ginfox} move
     * to the accepted point. On failure both still describe the starting point, so the instance can be reused.
     */
    public static final class MoreThuente
    implements LineSearchSolver {
        // Current interval in which the next trial step must lie.
        double _stMin;
        double _stMax;
        double _initialStep = 1.0;
        // Relative improvement a point needs to be remembered as a fallback (see _betGradient).
        double _minRelativeImprovement = 1.0E-8;
        private final GradientSolver _gslvr;
        private double[] _beta;
        double _xtol = 1.0E-8;
        double _ftol = 0.1;
        double _gtol = 0.1;
        // Extrapolation factor applied while the minimizer is not yet bracketed.
        double _xtrapf = 4.0;
        // Endpoint "x" of the interval: best step so far, its objective and directional derivative.
        double _fvx;
        double _dgx;
        double _stx;
        // Best sufficiently-improving point seen, used as a fallback for statuses 3 and 4.
        double _bestStep;
        GradientInfo _betGradient;
        double _bestPsiVal;
        // Gradient info paired with _stx.
        GradientInfo _ginfox;
        // Other endpoint "y" of the interval.
        double _fvy;
        double _dgy;
        double _sty;
        // True once the minimizer is bracketed between _stx and _sty.
        boolean _brackt;
        // True when the next step must be kept well inside the bracket.
        boolean _bound;
        // Index into messages describing why the last search stopped.
        int _returnStatus;
        public final String[] messages = new String[]{"In progress or not evaluated", "The sufficient decrease condition and the directional derivative condition hold.", "Relative width of the interval of uncertainty is at most xtol.", "Number of calls to gradient solver has reached the limit.", "The step is at the lower bound stpmin.", "The step is at the upper bound stpmax.", "Rounding errors prevent further progress, ftol/gtol tolerances may be too small.", "Non-negative differential."};
        private int _iter;
        int _maxfev = 20;
        double _maxStep = 1.0E10;
        double _minStep = 1.0E-10;

        public MoreThuente(GradientSolver gslvr, double[] betaStart) {
            this(gslvr, betaStart, gslvr.getGradient(betaStart), 0.1, 0.1, 0.01);
        }

        public MoreThuente(GradientSolver gslvr, double[] betaStart, GradientInfo ginfo) {
            this(gslvr, betaStart, ginfo, 0.1, 0.1, 1.0E-8);
        }

        /**
         * @throws IllegalArgumentException if {@code ginfo} carries no gradient
         */
        public MoreThuente(GradientSolver gslvr, double[] betaStart, GradientInfo ginfo, double ftol, double gtol, double xtol) {
            this._gslvr = gslvr;
            this._beta = betaStart;
            this._ginfox = ginfo;
            if (ginfo._gradient == null) {
                throw new IllegalArgumentException("GradientInfo for MoreThuente line search solver must include gradient");
            }
            this._ftol = ftol;
            this._gtol = gtol;
            this._xtol = xtol;
        }

        @Override
        public MoreThuente setInitialStep(double t) {
            this._initialStep = t;
            return this;
        }

        @Override
        public int nfeval() {
            return this._iter;
        }

        @Override
        public double getObj() {
            return this.ginfo()._objVal;
        }

        @Override
        public double[] getX() {
            return this._beta;
        }

        /**
         * Computes the next trial step (MINPACK {@code cstep}) and updates the interval endpoints.
         *
         * @param ginfo objective/gradient at the trial step {@code stp}
         * @param dg    directional derivative at {@code stp}
         * @param stp   current trial step
         * @param off   slope subtracted from objective and derivative values; non-zero when working on the
         *              modified function {@code f(stp) - stp * off}
         * @return the safeguarded next step, or {@link Double#NaN} if {@code stp} is outside the bracket or moving
         *         from {@code _stx} toward {@code stp} is not a descent direction
         */
        private double nextStep(GradientInfo ginfo, double dg, double stp, double off) {
            double nextStep;
            double fvp = ginfo._objVal - stp * off;
            double dgp = dg - off;
            double fvx = this._fvx - this._stx * off;
            double fvy = this._fvy - this._sty * off;
            double stx = this._stx;
            double sty = this._sty;
            double dgx = this._dgx - off;
            double dgy = this._dgy - off;
            if (this._brackt && (stp <= Math.min(stx, sty) || stp >= Math.max(stx, sty)) || dgx * (stp - stx) >= 0.0) {
                return Double.NaN;
            }
            double theta = 3.0 * (fvx - fvp) / (stp - stx) + dgx + dgp;
            double s = Math.max(Math.max(Math.abs(theta), Math.abs(dgx)), Math.abs(dgp));
            double sInv = 1.0 / s;
            double ts = theta * sInv;
            // Scaled computation of the cubic interpolant's discriminant to limit overflow.
            double gamma = s * Math.sqrt(Math.max(0.0, ts * ts - dgx * sInv * (dgp * sInv)));
            if (fvp > fvx) {
                // Case 1: higher function value, so the minimizer is bracketed. Use the cubic step if closer to stx,
                // otherwise the midpoint between the cubic and quadratic steps.
                if (stp < stx) {
                    gamma = -gamma;
                }
                this._bound = true;
                this._brackt = true;
                double p = gamma - dgx + theta;
                double q = gamma - dgx + gamma + dgp;
                double r = p / q;
                double stpc = stx + r * (stp - stx);
                double stpq = stx + dgx / ((fvx - fvp) / (stp - stx) + dgx) / 2.0 * (stp - stx);
                nextStep = Math.abs(stpc - stx) < Math.abs(stpq - stx) ? stpc : stpc + (stpq - stpc) / 2.0;
            } else if (dgp * dgx < 0.0) {
                // Case 2: lower value, derivatives of opposite sign, so bracketed. Take the farther of cubic/secant.
                if (stp > stx) {
                    gamma = -gamma;
                }
                this._bound = false;
                this._brackt = true;
                double p = gamma - dgp + theta;
                double q = gamma - dgp + gamma + dgx;
                double r = p / q;
                double stpc = stp + r * (stx - stp);
                double stpq = stp + dgp / (dgp - dgx) * (stx - stp);
                nextStep = Math.abs(stpc - stp) > Math.abs(stpq - stp) ? stpc : stpq;
            } else if (Math.abs(dgp) < Math.abs(dgx)) {
                // Case 3: lower value, same-sign derivatives, derivative magnitude decreasing.
                if (stp > stx) {
                    gamma = -gamma;
                }
                this._bound = true;
                double p = gamma - dgp + theta;
                double q = gamma + dgx - dgp + gamma;
                double r = p / q;
                double stpc = r < 0.0 && gamma != 0.0 ? stp + r * (stx - stp) : (stp > stx ? this._stMax : this._stMin);
                double stpq = stp + dgp / (dgp - dgx) * (stx - stp);
                nextStep = this._brackt ? (Math.abs(stp - stpc) < Math.abs(stp - stpq) ? stpc : stpq) : (Math.abs(stp - stpc) > Math.abs(stp - stpq) ? stpc : stpq);
            } else {
                // Case 4: lower value, same-sign derivatives, derivative magnitude not decreasing.
                this._bound = false;
                if (this._brackt) {
                    theta = 3.0 * (fvp - fvy) / (sty - stp) + dgy + dgp;
                    gamma = Math.sqrt(theta * theta - dgy * dgp);
                    if (stp > sty) {
                        gamma = -gamma;
                    }
                    double p = gamma - dgp + theta;
                    double q = gamma - dgp + gamma + dgy;
                    double r = p / q;
                    nextStep = stp + r * (sty - stp);
                } else {
                    nextStep = stp > stx ? this._stMax : this._stMin;
                }
            }
            // Update the interval; _ginfox always stays paired with _stx.
            if (fvp > fvx) {
                this._sty = stp;
                this._fvy = ginfo._objVal;
                this._dgy = dg;
            } else {
                if (dgp * dgx < 0.0) {
                    this._sty = this._stx;
                    this._fvy = this._fvx;
                    this._dgy = this._dgx;
                }
                this._stx = stp;
                this._fvx = ginfo._objVal;
                this._dgx = dg;
                this._ginfox = ginfo;
            }
            if (nextStep > this._stMax) {
                nextStep = this._stMax;
            }
            if (nextStep < this._stMin) {
                nextStep = this._stMin;
            }
            if (this._brackt && this._bound) {
                nextStep = this._sty > this._stx ? Math.min(this._stx + 0.66 * (this._sty - this._stx), nextStep) : Math.max(this._stx + 0.66 * (this._sty - this._stx), nextStep);
            }
            return nextStep;
        }

        public String toString() {
            return "MoreThuente line search, iter = " + this._iter + ", status = " + this.messages[this._returnStatus] + ", step = " + this._stx + ", I = [" + this._stMin + ", " + this._stMax + "], grad = " + this._dgx + ", bestObj = " + this._fvx;
        }

        /**
         * Runs the line search from {@code _beta} along {@code direction}.
         *
         * @param direction search direction; must be a descent direction for the search to start
         * @return {@code true} if the returned point improves the objective. In that case {@code getX()} and
         *         {@code ginfo()} describe the new point; otherwise both still describe the starting point.
         */
        @Override
        public boolean evaluate(double[] direction) {
            final GradientInfo startGinfo = this._ginfox;
            final double oldObjval = startGinfo._objVal;
            double step = this._initialStep;
            // Per-search step cap. It is lowered locally on NaN/Inf evaluations so that the configured _maxStep
            // is not permanently shrunk for later searches on this instance.
            double maxStep = this._maxStep;
            this._bound = false;
            this._brackt = false;
            this._sty = 0.0;
            this._stx = 0.0;
            this._stMax = 0.0;
            this._stMin = 0.0;
            this._betGradient = null;
            this._bestPsiVal = Double.POSITIVE_INFINITY;
            this._bestStep = 0.0;
            // Threshold a fallback point must beat. Math.abs keeps it below the start objective when that is negative.
            double maxObj = oldObjval - this._minRelativeImprovement * Math.abs(oldObjval);
            double dgInit = ArrayUtils.innerProduct(startGinfo._gradient, direction);
            double dgtest = dgInit * this._ftol;
            if (dgtest > 1.0E-4) {
                Log.warn("MoreThuente LS: got possitive differential " + dgtest);
            }
            if (dgtest >= 0.0) {
                this._returnStatus = 7;
                return false;
            }
            double[] beta = new double[this._beta.length];
            double width = maxStep - this._minStep;
            double oldWidth = 2.0 * width;
            boolean stage1 = true;
            this._fvx = this._fvy = oldObjval;
            this._dgx = this._dgy = dgInit;
            this._iter = 0;
            while (true) {
                if (this._brackt) {
                    this._stMin = Math.min(this._stx, this._sty);
                    this._stMax = Math.max(this._stx, this._sty);
                } else {
                    this._stMin = this._stx;
                    this._stMax = step + this._xtrapf * (step - this._stx);
                }
                step = Math.min(step, maxStep);
                step = Math.max(step, this._minStep);
                double maxFval = oldObjval + step * dgtest;
                for (int i = 0; i < beta.length; ++i) {
                    beta[i] = this._beta[i] + step * direction[i];
                }
                GradientInfo newGinfo = this._gslvr.getGradient(beta);
                if (newGinfo._objVal < maxObj && (this._betGradient == null || newGinfo._objVal - maxFval < this._bestPsiVal)) {
                    this._bestPsiVal = newGinfo._objVal - maxFval;
                    this._betGradient = newGinfo;
                    this._bestStep = step;
                }
                ++this._iter;
                if (this._iter < this._maxfev && !Double.isNaN(step) && (Double.isNaN(newGinfo._objVal) || Double.isInfinite(newGinfo._objVal) || ArrayUtils.hasNaNsOrInfs(newGinfo._gradient))) {
                    // Evaluation blew up: treat this step as an upper bracket and retry with half the step.
                    this._brackt = true;
                    this._sty = step;
                    maxStep = step;
                    this._fvy = Double.POSITIVE_INFINITY;
                    this._dgy = Double.MAX_VALUE;
                    step *= 0.5;
                    continue;
                }
                double dgp = ArrayUtils.innerProduct(newGinfo._gradient, direction);
                if (Double.isNaN(step) || this._brackt && (step <= this._stMin || step >= this._stMax)) {
                    this._returnStatus = 6;
                    break;
                }
                if (step == maxStep && newGinfo._objVal <= maxFval && dgp <= dgtest) {
                    this._returnStatus = 5;
                    this._stx = step;
                    this._ginfox = newGinfo;
                    break;
                }
                if (step == this._minStep && (newGinfo._objVal > maxFval || dgp >= dgtest)) {
                    this._returnStatus = 4;
                    if (this._betGradient != null) {
                        this._stx = this._bestStep;
                        this._ginfox = this._betGradient;
                        break;
                    }
                    this._stx = step;
                    this._ginfox = newGinfo;
                    break;
                }
                if (this._iter >= this._maxfev) {
                    this._returnStatus = 3;
                    if (this._betGradient != null) {
                        this._stx = this._bestStep;
                        this._ginfox = this._betGradient;
                        break;
                    }
                    this._stx = step;
                    this._ginfox = newGinfo;
                    break;
                }
                if (this._brackt && this._stMax - this._stMin <= this._xtol * this._stMax) {
                    // Interval too narrow: return the current trial point, keeping _stx paired with _ginfox
                    // so the final beta and the reported gradient describe the same point.
                    this._stx = step;
                    this._ginfox = newGinfo;
                    this._returnStatus = 2;
                    break;
                }
                if (newGinfo._objVal < maxFval && Math.abs(dgp) <= -this._gtol * dgInit) {
                    this._stx = step;
                    this._dgx = dgp;
                    this._fvx = newGinfo._objVal;
                    this._ginfox = newGinfo;
                    this._returnStatus = 1;
                    break;
                }
                // Work on the modified function until a point with sufficient decrease and non-negative
                // modified slope has been seen (MINPACK's stage 1).
                stage1 = stage1 && (newGinfo._objVal > maxFval || dgp < dgtest);
                boolean useAugmentedFunction = stage1 && newGinfo._objVal <= this._fvx && newGinfo._objVal > maxFval;
                double off = useAugmentedFunction ? dgtest : 0.0;
                double nextStep = this.nextStep(newGinfo, dgp, step, off);
                if (this._brackt) {
                    // Force bisection if the interval did not shrink enough.
                    if (Math.abs(this._sty - this._stx) >= 0.66 * oldWidth) {
                        nextStep = this._stx + 0.5 * (this._sty - this._stx);
                    }
                    oldWidth = width;
                    width = Math.abs(this._sty - this._stx);
                }
                step = nextStep;
            }
            boolean succ = this._ginfox._objVal < oldObjval;
            if (succ) {
                for (int i = 0; i < beta.length; ++i) {
                    beta[i] = this._beta[i] + this._stx * direction[i];
                }
                this._beta = beta;
            } else {
                // _beta is unchanged, so keep getObj()/ginfo() (and the next evaluate()) consistent with it.
                this._ginfox = startGinfo;
            }
            return succ;
        }

        @Override
        public double step() {
            return this._stx;
        }

        @Override
        public GradientInfo ginfo() {
            return this._ginfox;
        }
    }

    /**
     * Backtracking line search on the penalized objective {@code objective + l1pen * ArrayUtils.l1norm(beta, true)}.
     *
     * <p>Tries the steps 1, 0.33, 0.33^2, ... along the direction and accepts the first that lowers the penalized
     * objective.
     */
    public static final class SimpleBacktrackingLS
    implements LineSearchSolver {
        private double[] _beta;
        // Factor by which the step shrinks after each rejected trial.
        final double _stepDec = 0.33;
        private double _step;
        private final GradientSolver _gslvr;
        private GradientInfo _ginfo;
        // Penalized objective at _beta.
        private double _objVal;
        final double _l1pen;
        int _maxfev = 20;
        // NOTE(review): not read by evaluate(), which derives its own minimum step from the direction.
        double _minStep = 1.0E-4;

        public SimpleBacktrackingLS(GradientSolver gslvr, double[] betaStart, double l1pen) {
            this(gslvr, betaStart, l1pen, gslvr.getObjective(betaStart));
        }

        public SimpleBacktrackingLS(GradientSolver gslvr, double[] betaStart, double l1pen, GradientInfo ginfo) {
            this._gslvr = gslvr;
            this._beta = betaStart;
            this._ginfo = ginfo;
            this._l1pen = l1pen;
            this._objVal = ginfo._objVal + l1pen * ArrayUtils.l1norm(betaStart, true);
        }

        /** Evaluation counting is not tracked by this solver. */
        @Override
        public int nfeval() {
            return -1;
        }

        @Override
        public double getObj() {
            return this._objVal;
        }

        @Override
        public double[] getX() {
            return this._beta;
        }

        /** The first trial step is always 1; the argument is ignored. */
        @Override
        public LineSearchSolver setInitialStep(double s) {
            return this;
        }

        /**
         * Backtracks along {@code direction} until the penalized objective decreases.
         *
         * @return {@code true} if an improving step was accepted (state moves to that point), {@code false} otherwise
         */
        @Override
        public boolean evaluate(double[] direction) {
            // Smallest step worth trying: the step at which the largest |direction| component moves by 1e-4 (cap 1).
            double smallestStep = 1.0;
            for (double component : direction) {
                double candidate = Math.abs(1.0E-4 / component);
                if (candidate < smallestStep) {
                    smallestStep = candidate;
                }
            }
            double[] trialBeta = direction.clone();
            double step = 1.0;
            int attempts = 0;
            while (attempts < this._maxfev && step >= smallestStep) {
                GradientInfo trial = this._gslvr.getObjective(ArrayUtils.wadd(this._beta, direction, trialBeta, step));
                double penalizedObj = trial._objVal + this._l1pen * ArrayUtils.l1norm(trialBeta, true);
                if (penalizedObj < this._objVal) {
                    this._ginfo = trial;
                    this._objVal = penalizedObj;
                    this._beta = trialBeta;
                    this._step = step;
                    return true;
                }
                attempts++;
                step *= this._stepDec;
            }
            return false;
        }

        @Override
        public double step() {
            return this._step;
        }

        @Override
        public GradientInfo ginfo() {
            return this._ginfo;
        }

        public String toString() {
            return "";
        }
    }

    public static interface LineSearchSolver {
        public boolean evaluate(double[] var1);

        public double step();

        public GradientInfo ginfo();

        public LineSearchSolver setInitialStep(double var1);

        public int nfeval();

        public double getObj();

        public double[] getX();
    }

    /**
     * Evaluates the objective being minimized at given coefficients.
     */
    public interface GradientSolver {
        /** Returns the objective and its gradient at {@code beta}. */
        GradientInfo getGradient(double[] beta);

        /**
         * Returns the objective at {@code beta}; implementations may leave the gradient unset
         * (NOTE(review): confirm per implementation).
         */
        GradientInfo getObjective(double[] beta);
    }

    /**
     * Objective value and gradient at a point, as produced by a {@link GradientSolver}.
     * The gradient may be {@code null} when only the objective was requested.
     */
    public static class GradientInfo
    extends Iced {
        public double _objVal;
        public double[] _gradient;

        public GradientInfo(double objVal, double[] grad) {
            this._objVal = objVal;
            this._gradient = grad;
        }

        /**
         * Returns {@code true} when the objective is finite and the gradient has no NaN or infinite entries.
         * An infinite objective is rejected just like an infinite gradient entry.
         */
        public boolean isValid() {
            if (Double.isNaN(this._objVal) || Double.isInfinite(this._objVal)) {
                return false;
            }
            return !ArrayUtils.hasNaNsOrInfs(this._gradient);
        }

        public String toString() {
            return " objVal = " + this._objVal + ", " + Arrays.toString(this._gradient);
        }
    }
}

