/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.deeplearning4j.optimize.solvers.BaseOptimizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Stochastic Hessian-free (truncated-Newton style) optimizer for {@link MultiLayerNetwork}.
 *
 * <p>Each outer iteration of {@link #optimize()}:
 * <ol>
 *   <li>reads the current parameters and score;</li>
 *   <li>obtains a gradient/preconditioner pair from {@code getBackPropGradient2()};</li>
 *   <li>runs preconditioned conjugate gradient using curvature-vector products from
 *       {@code getBackPropRGradient(...)}, warm-started from the previous solution scaled by
 *       {@link #pi};</li>
 *   <li>backtracks over the CG iterates;</li>
 *   <li>performs a backtracking line search;</li>
 *   <li>updates the network's damping;</li>
 *   <li>writes the new parameters back to the network.</li>
 * </ol>
 *
 * <p>Models that are not a {@link MultiLayerNetwork} are not optimized: {@link #optimize()}
 * returns {@code true} immediately for them.
 *
 * <p>Instances hold mutable per-run state without any synchronization.
 */
public class StochasticHessianFree
extends BaseOptimizer {
    private static final Logger log = LoggerFactory.getLogger(StochasticHessianFree.class);

    boolean converged = false;
    TrainingEvaluator eval;
    double initialStepSize = 1.0;
    double tolerance = 1.0E-5f;
    double gradientTolerance = 0.0;
    int maxIterations = 10000;

    private MultiLayerNetwork network;
    private String myName = "";
    /** Latest CG solution; scaled by {@link #pi} and reused as the next CG starting point. */
    private INDArray ch;
    /** Negated first element of {@code getBackPropGradient2()}; the CG right-hand side. */
    private INDArray gradient;
    /** Parameter vector at the start of the current outer iteration. */
    private INDArray xi;
    /** Factor applied to the previous CG solution before it warm-starts the next solve. */
    private double pi = 0.5;
    // NOTE(review): decrease/boost are handed to network.dampingUpdate(); presumably the
    // damping shrink/grow factors -- confirm against that method.
    private double decrease = 0.99f;
    private double boost = 1.0 / this.decrease;
    /** Extra multiplier applied to the line-search rate when forming the parameter update. */
    private double f = 1.0;
    /** Network score at the start of the current outer iteration. */
    private double score;
    /** Rate chosen by the most recent line search. */
    private double step;

    public StochasticHessianFree(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<IterationListener> iterationListeners, Model model) {
        super(conf, stepFunction, iterationListeners, model);
        this.setup();
    }

    public StochasticHessianFree(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<IterationListener> iterationListeners, Collection<TerminationCondition> terminationConditions, Model model) {
        super(conf, stepFunction, iterationListeners, terminationConditions, model);
        this.setup();
    }

    /**
     * Binds {@link #network} to the model, then sets {@link #ch} to a {@code 1 x n} zero row.
     * Here {@code n} is the length of {@code network.pack()}.
     * Does nothing when the model is not a {@link MultiLayerNetwork}.
     */
    void setup() {
        if (!(this.model instanceof MultiLayerNetwork)) {
            return;
        }
        this.network = (MultiLayerNetwork) this.model;
        this.xi = this.network.pack();
        this.ch = Nd4j.zeros(1, (int) this.xi.length());
    }

    public boolean isConverged() {
        return this.converged;
    }

    /**
     * Runs preconditioned conjugate gradient on {@code A x = b}.
     * The product {@code A v} is evaluated as {@code network.getBackPropRGradient(v)}.
     *
     * @param b right-hand side
     * @param x0 starting point; it is updated in place and used as the running iterate
     * @param preCon preconditioner, applied by element-wise division of the residual
     * @param numIterations maximum number of CG iterations
     * @return indices of the completed iterations, paired with a copy of the iterate after
     *     each one. Both lists are empty if the first iteration hits non-positive curvature,
     *     or if {@code numIterations <= 0}.
     */
    public Pair<List<Integer>, List<INDArray>> conjGradient(INDArray b, INDArray x0, INDArray preCon, int numIterations) {
        List<Integer> iterations = new ArrayList<Integer>();
        List<INDArray> iterates = new ArrayList<INDArray>();
        // Residual r = A x0 - b and preconditioned residual y = r / preCon.
        INDArray r = this.network.getBackPropRGradient(x0).subi(b);
        INDArray y = r.div(preCon);
        double deltaNew = dot(r, y);
        INDArray p = y.neg();
        INDArray x = x0;
        for (int iterationCount = 0; iterationCount < numIterations; ++iterationCount) {
            INDArray ap = this.network.getBackPropRGradient(p);
            double pAp = dot(ap, p);
            if (pAp <= 0.0) {
                // Non-positive curvature along p: alpha would be negative or infinite.
                // Stop and keep the iterates recorded so far.
                log.info("Negative slope: " + pAp + " breaking");
                break;
            }
            double alpha = deltaNew / pAp;
            x.addi(p.mul((Number) alpha));
            // In place: r <- r + alpha * A p.
            r.addi(ap.mul((Number) alpha));
            INDArray yNew = r.div(preCon);
            double deltaOld = deltaNew;
            deltaNew = dot(r, yNew);
            iterations.add(iterationCount);
            iterates.add(x.dup());
            if (deltaNew == 0.0) {
                // deltaNew is the divisor of the next beta; stop before computing 0/0.
                break;
            }
            double beta = deltaNew / deltaOld;
            p = yNew.neg().addi(p.mul((Number) beta));
        }
        return new Pair<List<Integer>, List<INDArray>>(iterations, iterates);
    }

    /**
     * Solves for a new step with {@link #conjGradient}, starting from {@link #ch}.
     * The last iterate is stored back into {@link #ch}.
     * If CG produced no iterate, the current {@link #ch} is used as the only iterate.
     *
     * @return (last iterate, all iterates, last iterate)
     */
    private Triple<INDArray, List<INDArray>, INDArray> runConjugateGradient(INDArray preCon, int numIterations) {
        Pair<List<Integer>, List<INDArray>> cg = this.conjGradient(this.gradient, this.ch, preCon, numIterations);
        List<INDArray> iterates = cg.getSecond();
        if (iterates.isEmpty()) {
            // Callers index the last iterate, so guarantee at least one.
            iterates = new ArrayList<INDArray>();
            iterates.add(this.ch.dup());
        }
        this.ch = iterates.get(iterates.size() - 1);
        return new Triple<INDArray, List<INDArray>, INDArray>(this.ch, iterates, this.ch);
    }

    /**
     * Backtracking line search along {@code p}, starting from {@code params}.
     *
     * <p>The rate starts at 1. It is multiplied by 0.8 until the test below passes, or until
     * 60 trials fail, in which case the rate is reset to 0:
     * <pre>{@code newScore <= score + c * rate * (gradient . p)}</pre>
     * Here {@code c = 0.01}, {@code score} is {@link #score} and {@code gradient} is
     * {@link #gradient}.
     *
     * <p>NOTE(review): {@link #gradient} holds the <em>negated</em> first element of
     * {@code getBackPropGradient2()}. Confirm that the sign of the {@code gradient . p} term
     * matches that method's convention.
     *
     * @param newScore network score at {@code params + p}, the rate-1 trial point
     * @param params parameter vector the search starts from
     * @param p search direction
     * @return the accepted rate, or 0 if no trial satisfied the test
     */
    public double lineSearch(double newScore, INDArray params, INDArray p) {
        final int numSearches = 60;
        // Float literals kept so the constants are numerically identical to the original.
        final double c = 0.01f;
        final double shrink = 0.8f;
        double rate = 1.0;
        // Directional term of the sufficient-decrease test; it does not change during the search.
        double slope = dot(this.gradient, p);
        int j;
        for (j = 0; j < numSearches; ++j) {
            if (j % 10 == 0) {
                log.info("Iteration " + j + " on line search with current rate of " + rate);
            }
            if (newScore <= this.score + c * rate * slope) {
                break;
            }
            rate *= shrink;
            newScore = this.network.score(params.add(p.mul((Number) rate)));
        }
        if (j == numSearches) {
            rate = 0.0;
            log.info("Went too far...reverting rate to 0");
        }
        return rate;
    }

    /**
     * Chooses one of the CG iterates in {@code chs} by scoring the network at
     * {@code params() + chs[i]}, walking from the second-to-last iterate backwards.
     *
     * <p>NOTE(review): two things here look suspect and should be checked against the
     * reference HF backtracking.
     * <ul>
     *   <li>The loop never scores index 0.</li>
     *   <li>When it stops at index {@code i}, it returns {@code chs[i + 1]} but reports the
     *       score of {@code chs[i]}.</li>
     * </ul>
     *
     * @param chs CG iterates, at least one
     * @param p last CG iterate
     * @return the selected iterate paired with the score recorded for it
     */
    public Pair<INDArray, Double> cgBackTrack(List<INDArray> chs, INDArray p) {
        INDArray params = this.network.params();
        double bestScore = this.network.score(p.add(params));
        double currMin = this.network.score();
        int i;
        for (i = chs.size() - 2; i > 0; --i) {
            double candidate = this.network.score(params.add(chs.get(i)));
            if (candidate < bestScore || candidate < currMin) {
                ++i;
                bestScore = candidate;
                break;
            }
        }
        if (i < 0) {
            // Only one iterate: the loop never ran and i started at -1.
            i = 0;
        }
        return new Pair<INDArray, Double>(chs.get(i), bestScore);
    }

    /**
     * Runs {@code conf.getNumIterations()} Hessian-free iterations, updating the network's
     * parameters in place.
     *
     * @return always {@code true}
     */
    @Override
    public boolean optimize() {
        if (!(this.model instanceof MultiLayerNetwork)) {
            return true;
        }
        this.myName = Thread.currentThread().getName();
        if (this.converged) {
            return true;
        }
        for (int i = 0; i < this.conf.getNumIterations(); ++i) {
            // Re-read the state each iteration so the update builds on the previous
            // iteration's parameters rather than on the starting point.
            this.score = this.network.score();
            this.xi = this.network.params();
            Pair<INDArray, INDArray> backPropGradient = this.network.getBackPropGradient2();
            this.gradient = backPropGradient.getFirst().neg();
            INDArray preCon = backPropGradient.getSecond();
            if (this.ch == null) {
                this.setup();
            }
            this.ch.muli((Number) this.pi);
            Triple<INDArray, List<INDArray>, INDArray> cg = this.runConjugateGradient(preCon, this.conf.getNumIterations());
            Pair<INDArray, Double> backTracked = this.cgBackTrack(cg.getSecond(), cg.getFirst());
            INDArray p = backTracked.getFirst();
            double rho = this.network.reductionRatio(p, this.network.score(), backTracked.getSecond(), this.gradient);
            // Score at the full (rate = 1) step, consistent with the trial points lineSearch
            // evaluates (params + rate * p).
            double newScore = this.network.score(this.xi.add(p));
            this.step = this.lineSearch(newScore, this.xi, p);
            this.network.dampingUpdate(rho, this.boost, this.decrease);
            INDArray proposedUpdate = this.xi.add(p.mul((Number) (this.f * this.step)));
            this.network.setParameters(proposedUpdate);
            log.info("Score at iteration " + i + " was " + newScore);
        }
        return true;
    }

    /** Sum over all elements of the element-wise product {@code a * b}. */
    private static double dot(INDArray a, INDArray b) {
        return a.mul(b).sum(Integer.MAX_VALUE).getDouble(0);
    }
}

