/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.exception.InvalidStepException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.LineOptimizer;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.stepfunctions.DefaultStepFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarSetValue;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class BackTrackLineSearch
implements LineOptimizer {
    private static final Logger logger = LoggerFactory.getLogger((String)BackTrackLineSearch.class.getName());
    private Model function;
    private StepFunction stepFunction = new DefaultStepFunction();
    private ConvexOptimizer optimizer;
    private int maxIterations = 100;
    double stpmax = 100.0;
    private double relTolx = 1.0E-10f;
    private double absTolx = 1.0E-4f;
    final double ALF = 1.0E-4f;

    public BackTrackLineSearch(Model function, StepFunction stepFunction, ConvexOptimizer optimizer) {
        this.function = function;
        this.stepFunction = stepFunction;
        this.optimizer = optimizer;
    }

    public BackTrackLineSearch(Model optimizable, ConvexOptimizer optimizer) {
        this(optimizable, new DefaultStepFunction(), optimizer);
    }

    public void setStpmax(double stpmax) {
        this.stpmax = stpmax;
    }

    public double getStpmax() {
        return this.stpmax;
    }

    public void setRelTolx(double tolx) {
        this.relTolx = tolx;
    }

    public void setAbsTolx(double tolx) {
        this.absTolx = tolx;
    }

    public int getMaxIterations() {
        return this.maxIterations;
    }

    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    @Override
    public double optimize(double initialStep, INDArray x, INDArray line) throws InvalidStepException {
        double sum;
        double fold;
        INDArray oldParameters = x.dup();
        INDArray g = line.dup();
        double alam2 = 0.0;
        double f2 = fold = this.optimizer.score();
        if (logger.isDebugEnabled()) {
            logger.trace("ENTERING BACKTRACK\n");
            logger.trace("Entering BackTrackLinnSearch, value = " + fold + ",\ndirection.oneNorm:" + g.norm1(Integer.MAX_VALUE) + "  direction.infNorm:" + FastMath.max((double)Double.NEGATIVE_INFINITY, (double)Transforms.abs((INDArray)g).max(Integer.MAX_VALUE).getDouble(0)));
        }
        if ((sum = line.norm2(Integer.MAX_VALUE).getDouble(0)) > this.stpmax) {
            logger.warn("attempted step too big. scaling: sum= " + sum + ", stpmax= " + this.stpmax);
            line.muli((Number)(this.stpmax / sum));
        }
        double slope = Nd4j.getBlasWrapper().dot(g, line);
        logger.debug("slope = " + slope);
        if (slope < 0.0) {
            throw new InvalidStepException("Slope = " + slope + " is negative");
        }
        if (slope == 0.0) {
            throw new InvalidStepException("Slope = " + slope + " is zero");
        }
        INDArray maxOldParams = Transforms.abs((INDArray)oldParameters);
        Nd4j.getExecutioner().exec((Op)new ScalarSetValue(maxOldParams, (Number)1));
        INDArray testMatrix = Transforms.abs((INDArray)line).divi(maxOldParams);
        double test = testMatrix.max(Integer.MAX_VALUE).getDouble(0);
        double alamin = this.relTolx / test;
        double alam = 1.0;
        double oldAlam = 0.0;
        for (int iteration = 0; iteration < this.maxIterations; ++iteration) {
            double tmplam;
            double f;
            logger.trace("BackTrack loop iteration " + iteration + " : alam=" + alam + " oldAlam=" + oldAlam);
            logger.trace("before step, x.1norm: " + x.norm1(Integer.MAX_VALUE) + "\nalam: " + alam + "\noldAlam: " + oldAlam);
            assert (alam != oldAlam) : "alam == oldAlam";
            if (this.stepFunction == null) {
                this.stepFunction = new DefaultStepFunction();
            }
            this.stepFunction.step(x, line, new Object[]{alam, oldAlam});
            if (logger.isDebugEnabled()) {
                double norm1 = x.norm1(Integer.MAX_VALUE).getDouble(0);
                logger.debug("after step, x.1norm: " + norm1);
            }
            if (alam < alamin || Nd4j.getExecutioner().execAndReturn((TransformOp)new Eps(oldParameters.linearView(), x.linearView(), x.linearView().dup(), x.length())).sum(Integer.MAX_VALUE).getDouble(0) == (double)x.length()) {
                this.function.setParams(oldParameters);
                this.function.setScore();
                f = this.function.score();
                logger.trace("EXITING BACKTRACK: Jump too small (alamin = " + alamin + "). Exiting and using xold. Value = " + f);
                return 0.0;
            }
            this.function.setParams(x);
            oldAlam = alam;
            this.function.setScore();
            f = this.function.score();
            logger.debug("value = " + f);
            if (f >= fold + (double)1.0E-4f * alam * slope) {
                logger.debug("EXITING BACKTRACK: value=" + f);
                if (f < fold) {
                    throw new IllegalStateException("Function did not increase: f = " + f + " < " + fold + " = fold");
                }
                return alam;
            }
            if (Double.isInfinite(f) || Double.isInfinite(f2)) {
                logger.warn("Value is infinite after jump " + oldAlam + ". f=" + f + ", f2=" + f2 + ". Scaling back step size...");
                tmplam = 0.2 * alam;
                if (alam < alamin) {
                    this.function.setParams(oldParameters);
                    this.function.setScore();
                    f = this.function.score();
                    logger.warn("EXITING BACKTRACK: Jump too small. Exiting and using xold. Value=" + f);
                    return 0.0;
                }
            } else if (alam == 1.0) {
                tmplam = -slope / (2.0 * (f - fold - slope));
            } else {
                double disc;
                double rhs1 = f - fold - alam * slope;
                double rhs2 = f2 - fold - alam2 * slope;
                if (alam - alam2 == 0.0) {
                    throw new IllegalStateException("FAILURE: dividing by alam-alam2. alam=" + alam);
                }
                double a = (rhs1 / FastMath.pow((double)alam, (int)2) - rhs2 / FastMath.pow((double)alam2, (int)2)) / (alam - alam2);
                double b = (-alam2 * rhs1 / (alam * alam) + alam * rhs2 / (alam2 * alam2)) / (alam - alam2);
                tmplam = a == 0.0 ? -slope / (2.0 * b) : ((disc = b * b - 3.0 * a * slope) < 0.0 ? 0.5 * alam : (b <= 0.0 ? (-b + FastMath.sqrt((double)disc)) / (3.0 * a) : -slope / (b + FastMath.sqrt((double)disc))));
                if (tmplam > 0.5 * alam) {
                    tmplam = 0.5 * alam;
                }
            }
            alam2 = alam;
            f2 = f;
            logger.debug("tmplam:" + tmplam);
            alam = Math.max(tmplam, (double)0.1f * alam);
        }
        return 0.0;
    }
}

