/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import com.google.common.base.Function;
import org.deeplearning4j.exception.InvalidStepException;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.LineOptimizerMatrix;
import org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.deeplearning4j.optimize.solvers.VectorizedBackTrackLineSearch;
import org.deeplearning4j.util.OptimizerMatrix;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Conjugate-gradient ascent over an {@link OptimizableByGradientValueMatrix}.
 *
 * <p>Each iteration runs a line search (the "line maximizer") along the current direction
 * {@code xi}, then re-evaluates the objective value and gradient. It then builds the next
 * direction with a Polak-Ribiere update, where {@code g} is the previous gradient:
 * {@code gam = xi.(xi - g) / g.g} and {@code h = xi + gam * h}.
 * If the resulting direction is not an ascent direction ({@code xi.h <= 0}), the search
 * falls back to plain gradient ascent.
 *
 * <p>{@link #optimize(int)} stops when any of the following holds:
 * <ul>
 *   <li>the relative change in objective value is within {@code tolerance};</li>
 *   <li>the gradient 2-norm drops below {@code gradientTolerance};</li>
 *   <li>the cumulative iteration count exceeds {@code maxIterations};</li>
 *   <li>the configured {@link TrainingEvaluator} asks to stop.</li>
 * </ul>
 *
 * <p>Instances carry mutable search state and perform no synchronization.
 */
public class VectorizedNonZeroStoppingConjugateGradient implements OptimizerMatrix {
    private static final Logger logger = LoggerFactory.getLogger(VectorizedNonZeroStoppingConjugateGradient.class);

    /**
     * Initial step used by constructors that do not take one. It is a float literal, kept so
     * the effective value stays identical to earlier releases.
     */
    private static final double DEFAULT_INITIAL_STEP_SIZE = 0.01f;

    /** Set once a stopping criterion fires; {@link #optimize(int)} then returns {@code true} immediately. */
    boolean converged = false;
    OptimizableByGradientValueMatrix optimizable;
    VectorizedBackTrackLineSearch lineMaximizer;
    TrainingEvaluator eval;
    /** Step handed to the line search on the first iteration after the search state is (re)initialized. */
    double initialStepSize = 1.0;
    // Float literals: the effective values are the nearest floats to 1e-5 (slightly below it).
    double tolerance = 1.0E-5f;
    double gradientTolerance = 1.0E-5f;
    int maxIterations = 10000;
    private String myName = "";
    private IterationListener listener;
    double fp;      // objective value at the start of the current iteration
    double gg;      // g . g (previous gradient), denominator of gam
    double gam;     // Polak-Ribiere coefficient dgg / gg
    double dgg;     // xi . (xi - g), numerator of gam
    double step;    // step passed to / returned by the line search
    double fret;    // objective value after the current line search
    INDArray xi;    // latest gradient; after the update, the next line-search direction
    INDArray g;     // gradient from the previous iteration
    INDArray h;     // conjugate direction
    int j;          // not read or written within this class
    int iterations; // iterations completed since the search state was last initialized
    /** Small constant added to the denominator of the relative value-change test. */
    final double eps = 1.0E-10f;

    /**
     * Creates an optimizer using the default backtracking line search.
     *
     * @param function        objective to optimize
     * @param initialStepSize step handed to the line search on the first iteration
     */
    public VectorizedNonZeroStoppingConjugateGradient(OptimizableByGradientValueMatrix function, double initialStepSize) {
        this.initialStepSize = initialStepSize;
        this.optimizable = function;
        this.lineMaximizer = new VectorizedBackTrackLineSearch(function);
        this.lineMaximizer.setAbsTolx(this.tolerance);
    }

    /**
     * Creates an optimizer with the default initial step and the given listener.
     *
     * @param function objective to optimize
     * @param listener notified after iterations; may be {@code null}
     */
    public VectorizedNonZeroStoppingConjugateGradient(OptimizableByGradientValueMatrix function, IterationListener listener) {
        this(function, DEFAULT_INITIAL_STEP_SIZE);
        this.listener = listener;
    }

    /**
     * Creates an optimizer with the given initial step and listener.
     *
     * @param function        objective to optimize
     * @param initialStepSize step handed to the line search on the first iteration
     * @param listener        notified after iterations; may be {@code null}
     */
    public VectorizedNonZeroStoppingConjugateGradient(OptimizableByGradientValueMatrix function, double initialStepSize, IterationListener listener) {
        this(function, initialStepSize);
        this.listener = listener;
    }

    /**
     * Creates an optimizer whose line search uses the given step function.
     *
     * @param function     objective to optimize
     * @param stepFunction step function passed to the line search
     */
    public VectorizedNonZeroStoppingConjugateGradient(OptimizableByGradientValueMatrix function, StepFunction stepFunction) {
        this(function, stepFunction, DEFAULT_INITIAL_STEP_SIZE);
    }

    /**
     * Creates an optimizer whose line search uses the given step function.
     *
     * @param function        objective to optimize
     * @param stepFunction    step function passed to the line search
     * @param initialStepSize step handed to the line search on the first iteration
     */
    public VectorizedNonZeroStoppingConjugateGradient(OptimizableByGradientValueMatrix function, StepFunction stepFunction, double initialStepSize) {
        this.initialStepSize = initialStepSize;
        this.optimizable = function;
        this.lineMaximizer = new VectorizedBackTrackLineSearch(function, stepFunction);
        this.lineMaximizer.setAbsTolx(this.tolerance);
    }

    /**
     * Creates an optimizer with a step function, the default initial step, and a listener.
     *
     * @param function     objective to optimize
     * @param stepFunction step function passed to the line search
     * @param listener     notified after iterations; may be {@code null}
     */
    public VectorizedNonZeroStoppingConjugateGradient(OptimizableByGradientValueMatrix function, StepFunction stepFunction, IterationListener listener) {
        this(function, stepFunction, DEFAULT_INITIAL_STEP_SIZE);
        this.listener = listener;
    }

    /**
     * Creates an optimizer with a step function, initial step, and listener.
     *
     * @param function        objective to optimize
     * @param initialStepSize step handed to the line search on the first iteration
     * @param stepFunction    step function passed to the line search
     * @param listener        notified after iterations; may be {@code null}
     */
    public VectorizedNonZeroStoppingConjugateGradient(OptimizableByGradientValueMatrix function, double initialStepSize, StepFunction stepFunction, IterationListener listener) {
        this(function, stepFunction, initialStepSize);
        this.listener = listener;
    }

    /**
     * Creates an optimizer with the default line search and default initial step.
     *
     * @param function objective to optimize
     */
    public VectorizedNonZeroStoppingConjugateGradient(OptimizableByGradientValueMatrix function) {
        this(function, DEFAULT_INITIAL_STEP_SIZE);
    }

    @Override
    public boolean isConverged() {
        return this.converged;
    }

    /**
     * Replaces the line search.
     *
     * @param lineMaximizer must be a {@link VectorizedBackTrackLineSearch}; any other
     *                      implementation fails the cast with a {@link ClassCastException}
     */
    public void setLineMaximizer(LineOptimizerMatrix lineMaximizer) {
        this.lineMaximizer = (VectorizedBackTrackLineSearch) lineMaximizer;
    }

    /** Sets the step used for the first line search after the search state is (re)initialized. */
    public void setInitialStepSize(double initialStepSize) {
        this.initialStepSize = initialStepSize;
    }

    /** Returns the step used for the first line search after (re)initialization. */
    public double getInitialStepSize() {
        return this.initialStepSize;
    }

    /** Returns the step produced by the most recent line search (same as {@link #getStep()}). */
    public double getStepSize() {
        return this.step;
    }

    /** Runs {@link #optimize(int)} with {@code maxIterations} iterations. */
    @Override
    public boolean optimize() {
        return this.optimize(this.maxIterations);
    }

    /**
     * Sets the relative value-change tolerance used for convergence. A non-positive value
     * disables the value-based test.
     *
     * <p>The line search's absolute x-tolerance is not updated here; it is fixed from the
     * initial tolerance when the optimizer is constructed.
     */
    @Override
    public void setTolerance(double t) {
        this.tolerance = t;
    }

    /**
     * Runs up to {@code numIterations} conjugate-gradient iterations.
     *
     * @param numIterations maximum number of iterations for this call
     * @return {@code true} if a stopping criterion fired (convergence, the cumulative iteration
     *         limit, or the training evaluator); {@code false} if the iteration budget ran out
     */
    @Override
    public boolean optimize(int numIterations) {
        this.myName = Thread.currentThread().getName();
        if (this.converged) {
            return true;
        }
        long last = System.currentTimeMillis();
        if (this.xi == null) {
            initializeSearchState();
        }
        for (int iterationCount = 0; iterationCount < numIterations; ++iterationCount) {
            long curr = System.currentTimeMillis();
            logger.info("{} ConjugateGradient: At iteration {}, cost = {} -{}",
                    new Object[] {this.myName, this.iterations, this.fp, curr - last});
            last = curr;
            this.optimizable.setCurrentIteration(iterationCount);
            try {
                this.step = this.lineMaximizer.optimize(this.xi, iterationCount, this.step);
            } catch (InvalidStepException e) {
                // Despite the message, the loop does not break here: the value and gradient are
                // re-evaluated below. NOTE(review): presumably the line search leaves the
                // parameters unchanged, so the value test below ends the run; confirm.
                logger.warn("Breaking: negative slope");
                logger.debug("Line search rejected the step", e);
            }
            this.fret = this.optimizable.getValue();
            this.xi = this.optimizable.getValueGradient(iterationCount);
            if (hasValueConverged()) {
                logger.info("ConjugateGradient converged: old value= {} new value= {} tolerance={}",
                        new Object[] {this.fp, this.fret, this.tolerance});
                return markConverged(iterationCount);
            }
            this.fp = this.fret;
            double twoNorm = this.xi.norm2(Integer.MAX_VALUE).getDouble(0);
            if (twoNorm < this.gradientTolerance) {
                logger.info("ConjugateGradient converged: gradient two norm {}, less than {}", twoNorm, this.gradientTolerance);
                return markConverged(iterationCount);
            }
            updateSearchDirection();
            ++this.iterations;
            if (this.iterations > this.maxIterations) {
                logger.info("Passed max number of iterations");
                return markConverged(iterationCount);
            }
            fireIterationDone(iterationCount);
            if (this.eval != null && this.eval.shouldStop(this.iterations)) {
                return true;
            }
        }
        return false;
    }

    /** Evaluates the starting point and seeds g, h, and the first line-search step. */
    private void initializeSearchState() {
        this.fp = this.optimizable.getValue();
        assert (!Double.isNaN(this.fp) && !Double.isInfinite(this.fp)) : "Function appears to be NaN or infinite, please check your parameters.";
        this.xi = this.optimizable.getValueGradient(0);
        this.g = this.xi.dup();
        this.h = this.xi.dup();
        // Without this the configured initial step was never used and the first search got 0.
        this.step = this.initialStepSize;
        this.iterations = 0;
    }

    /** Returns true when the objective change is within {@code tolerance}, relative to its magnitude. */
    private boolean hasValueConverged() {
        return 0.0 < this.tolerance
                && 2.0 * Math.abs(this.fret - this.fp)
                        <= this.tolerance * (Math.abs(this.fret) + Math.abs(this.fp) + this.eps);
    }

    /**
     * Applies the Polak-Ribiere update.
     *
     * <p>It computes gam from the new gradient {@code xi} and the previous gradient {@code g},
     * and forms the conjugate direction {@code h = xi + gam * h}. Any NaN entries of {@code h}
     * are replaced by {@code Nd4j.EPS_THRESHOLD}. {@code xi} becomes {@code h}, unless that is
     * not an ascent direction, in which case {@code h} is reset to the gradient.
     */
    private void updateSearchDirection() {
        this.gg = Transforms.pow(this.g, (Number) 2).sum(Integer.MAX_VALUE).getDouble(0);
        this.dgg = this.xi.mul(this.xi.sub(this.g)).sum(Integer.MAX_VALUE).getDouble(0);
        // A zero previous gradient would make gam infinite or NaN; restart along the gradient.
        this.gam = this.gg == 0.0 ? 0.0 : this.dgg / this.gg;
        this.g = this.xi.dup();
        this.h = this.xi.add(this.h.mul((Number) this.gam));
        BooleanIndexing.applyWhere(this.h, (Condition) Conditions.isNan(), (Function) new Value((Number) Nd4j.EPS_THRESHOLD));
        LinAlgExceptions.assertValidNum(this.h);
        if (Nd4j.getBlasWrapper().dot(this.xi, this.h) > 0.0) {
            this.xi = this.h.dup();
        } else {
            logger.warn("Reverting back to GA");
            this.h = this.xi.dup();
        }
    }

    /** Marks the run as converged, then notifies the listener so it observes the converged state. */
    private boolean markConverged(int iterationCount) {
        this.converged = true;
        fireIterationDone(iterationCount);
        return true;
    }

    /** Notifies the listener, if one is registered, that an iteration finished. */
    private void fireIterationDone(int iterationCount) {
        if (this.listener != null) {
            this.listener.iterationDone(iterationCount);
        }
    }

    @Override
    public void setTrainingEvaluator(TrainingEvaluator eval) {
        this.eval = eval;
    }

    /**
     * Discards the search direction so the next {@link #optimize(int)} call re-initializes
     * from the current value and gradient. This does not clear the converged flag: once
     * converged, {@code optimize} still returns {@code true} immediately.
     */
    public void reset() {
        this.xi = null;
    }

    /** Returns the cumulative iteration limit. */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /** Sets the cumulative iteration limit; exceeding it marks the run as converged. */
    @Override
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /** Returns the current conjugate direction. */
    public INDArray getH() {
        return this.h;
    }

    /** Overrides the current conjugate direction. */
    public void setH(INDArray h) {
        this.h = h;
    }

    /** Returns the gradient from the previous iteration. */
    public INDArray getG() {
        return this.g;
    }

    /** Overrides the previous-iteration gradient. */
    public void setG(INDArray g) {
        this.g = g;
    }

    /** Returns the current line-search direction (null before initialization or after {@link #reset()}). */
    public INDArray getXi() {
        return this.xi;
    }

    /** Overrides the line-search direction; a non-null value skips re-initialization in {@code optimize}. */
    public void setXi(INDArray xi) {
        this.xi = xi;
    }

    /** Returns the objective value after the most recent line search. */
    public double getFret() {
        return this.fret;
    }

    /** Overrides the post-line-search objective value. */
    public void setFret(double fret) {
        this.fret = fret;
    }

    /** Returns the step produced by the most recent line search. */
    public double getStep() {
        return this.step;
    }

    /** Overrides the step passed to the next line search. */
    public void setStep(double step) {
        this.step = step;
    }

    /** Returns the most recent Polak-Ribiere numerator {@code xi.(xi - g)}. */
    public double getDgg() {
        return this.dgg;
    }

    /** Overrides the stored Polak-Ribiere numerator. */
    public void setDgg(double dgg) {
        this.dgg = dgg;
    }

    /** Returns the most recent conjugate-gradient coefficient. */
    public double getGam() {
        return this.gam;
    }

    /** Overrides the stored conjugate-gradient coefficient. */
    public void setGam(double gam) {
        this.gam = gam;
    }

    /** Returns the most recent Polak-Ribiere denominator {@code g.g}. */
    public double getGg() {
        return this.gg;
    }

    /** Overrides the stored Polak-Ribiere denominator. */
    public void setGg(double gg) {
        this.gg = gg;
    }

    /** Returns the objective value at the start of the current iteration. */
    public double getFp() {
        return this.fp;
    }

    /** Overrides the start-of-iteration objective value. */
    public void setFp(double fp) {
        this.fp = fp;
    }
}

