/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.BaseOptimizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * Limited-memory BFGS (L-BFGS) optimizer.
 *
 * <p>Keeps at most {@code m} (currently 4) of the most recent curvature pairs
 * {@code s_k = x_{k+1} - x_k} (parameter change) and {@code y_k = g_{k+1} - g_k}
 * (gradient change), plus {@code rho_k = 1 / (s_k . y_k)}, in the search state under the keys
 * {@code "s"}, {@code "y"} and {@code "rho"}. After each step it runs the two-loop recursion
 * (Nocedal &amp; Wright, Algorithm 7.4) to overwrite the {@code "searchDirection"} entry with an
 * approximation of {@code H_k * g}.
 */
public class LBFGS extends BaseOptimizer {
    private static final long serialVersionUID = 9148732140255034888L;
    /** Maximum number of (s, y, rho) history entries retained between steps. */
    private int m = 4;

    /**
     * Creates an L-BFGS optimizer. All arguments are passed unchanged to {@link BaseOptimizer}.
     */
    public LBFGS(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<TrainingListener> trainingListeners, Model model) {
        super(conf, stepFunction, trainingListeners, model);
    }

    /**
     * Creates an L-BFGS optimizer with explicit termination conditions. All arguments are passed
     * unchanged to {@link BaseOptimizer}.
     */
    public LBFGS(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<TrainingListener> trainingListeners, Collection<TerminationCondition> terminationConditions, Model model) {
        super(conf, stepFunction, trainingListeners, terminationConditions, model);
    }

    /**
     * Initializes the search state. Calls the base implementation, then adds empty history lists
     * {@code "s"}, {@code "y"}, {@code "rho"} and a copy of {@code "params"} as
     * {@code "oldparams"}.
     *
     * @param pair gradient/score pair forwarded to the base implementation
     */
    @Override
    public void setupSearchState(Pair<Gradient, Double> pair) {
        super.setupSearchState(pair);
        INDArray params = (INDArray) this.searchState.get("params");
        this.searchState.put("s", new LinkedList<INDArray>());
        this.searchState.put("y", new LinkedList<INDArray>());
        this.searchState.put("rho", new LinkedList<Double>());
        // Snapshot (not a view) of the parameters, used to form s_k = x_{k+1} - x_k in postStep.
        this.searchState.put("oldparams", params.dup());
    }

    /**
     * Ensures a {@code "searchDirection"} entry exists. If it is absent, it is initialized to a
     * copy of the current gradient {@code "g"}.
     */
    @Override
    public void preProcessLine() {
        if (!this.searchState.containsKey("searchDirection")) {
            this.searchState.put("searchDirection", ((INDArray) this.searchState.get("g")).dup());
        }
    }

    /**
     * Updates the L-BFGS history with the step just taken and recomputes the search direction.
     *
     * <p>It performs these steps in order:
     * <ul>
     *   <li>Forms {@code s_k = params - oldparams} and {@code y_k = gradient - g}.</li>
     *   <li>Pushes {@code s_k}, {@code y_k} and {@code rho_k} to the front of the history,
     *       dropping the oldest entry once {@code m} entries are stored.</li>
     *   <li>Overwrites {@code "searchDirection"} in place with the two-loop-recursion product
     *       {@code H_k * gradient}.</li>
     *   <li>Copies the current parameters and gradient into {@code "oldparams"} and
     *       {@code "g"}.</li>
     * </ul>
     *
     * <p>NOTE(review): the direction is not negated here. Presumably the configured step function
     * applies the descent sign, as with the gradient copy created in {@link #preProcessLine()}.
     * TODO confirm. This method also assumes {@code "searchDirection"} was already created by
     * {@code preProcessLine()}.
     *
     * @param gradient gradient at the current parameters {@code model.params()}
     * @throws IllegalStateException if the history lists fall out of sync or a stored vector's
     *     length differs from the search direction's
     */
    @Override
    public void postStep(INDArray gradient) {
        INDArray previousParameters = (INDArray) this.searchState.get("oldparams");
        INDArray parameters = this.model.params();
        INDArray previousGradient = (INDArray) this.searchState.get("g");
        // Element types match what setupSearchState stores under these keys.
        @SuppressWarnings("unchecked")
        LinkedList<Double> rho = (LinkedList<Double>) this.searchState.get("rho");
        @SuppressWarnings("unchecked")
        LinkedList<INDArray> s = (LinkedList<INDArray>) this.searchState.get("s");
        @SuppressWarnings("unchecked")
        LinkedList<INDArray> y = (LinkedList<INDArray>) this.searchState.get("y");

        INDArray sCurrent;
        INDArray yCurrent;
        if (s.size() >= this.m) {
            // History is full: recycle the oldest arrays in place rather than allocating new ones.
            sCurrent = s.removeLast();
            yCurrent = y.removeLast();
            rho.removeLast();
            sCurrent.assign(parameters).subi(previousParameters); // s_k = x_{k+1} - x_k
            yCurrent.assign(gradient).subi(previousGradient);     // y_k = g_{k+1} - g_k
        } else {
            sCurrent = parameters.sub(previousParameters);
            yCurrent = gradient.sub(previousGradient);
        }

        // The curvature terms must come from the newest pair: rho_k = 1 / (s_k . y_k) and
        // gamma_k = (s_k . y_k) / (y_k . y_k). Both denominators are offset by EPS_THRESHOLD,
        // presumably to avoid dividing by exactly zero.
        double sy = Nd4j.getBlasWrapper().dot(sCurrent, yCurrent) + Nd4j.EPS_THRESHOLD;
        double yy = Nd4j.getBlasWrapper().dot(yCurrent, yCurrent) + Nd4j.EPS_THRESHOLD;

        // Most recent pair is kept at the head of each list.
        rho.addFirst(1.0 / sy);
        s.addFirst(sCurrent);
        y.addFirst(yCurrent);
        if (s.size() != y.size()) {
            throw new IllegalStateException("Gradient and parameter sizes are not equal");
        }

        int numVectors = Math.min(this.m, s.size());
        // alpha[j] belongs to the j-th newest pair (alpha[0] = most recent).
        double[] alpha = new double[numVectors];

        // searchDir plays the role of q in the first loop and r in the second (reused in place).
        INDArray searchDir = (INDArray) this.searchState.get("searchDirection");
        searchDir.assign(gradient);

        // First loop: newest -> oldest.  alpha_i = rho_i * s_i.q ;  q -= alpha_i * y_i
        Iterator<INDArray> sIter = s.iterator();
        Iterator<INDArray> yIter = y.iterator();
        Iterator<Double> rhoIter = rho.iterator();
        for (int i = 0; i < numVectors; ++i) {
            INDArray si = sIter.next();
            INDArray yi = yIter.next();
            double rhoi = rhoIter.next();
            if (si.length() != searchDir.length()) {
                throw new IllegalStateException("Gradients and parameters length not equal");
            }
            alpha[i] = rhoi * Nd4j.getBlasWrapper().dot(si, searchDir);
            Nd4j.getBlasWrapper().level1().axpy(searchDir.length(), -alpha[i], yi, searchDir);
        }

        // Initial Hessian approximation H_0 = gamma * I.
        double gamma = sy / yy;
        searchDir.muli(gamma);

        // Second loop: oldest -> newest.  beta = rho_i * y_i.r ;  r += (alpha_i - beta) * s_i
        sIter = s.descendingIterator();
        yIter = y.descendingIterator();
        rhoIter = rho.descendingIterator();
        for (int i = 0; i < numVectors; ++i) {
            INDArray si = sIter.next();
            INDArray yi = yIter.next();
            double rhoi = rhoIter.next();
            // Iteration i visits the (numVectors - 1 - i)-th newest pair, so look up its own alpha.
            double alphai = alpha[numVectors - 1 - i];
            double beta = rhoi * Nd4j.getBlasWrapper().dot(yi, searchDir);
            Nd4j.getBlasWrapper().level1().axpy(searchDir.length(), alphai - beta, si, searchDir);
        }

        previousParameters.assign(parameters);
        previousGradient.assign(gradient);
    }
}

