package org.allenai.ml.optimize;

import com.gs.collections.api.tuple.Pair;
import com.gs.collections.impl.tuple.Tuples;
import java.util.ArrayList;
import java.util.List;
import org.allenai.ml.linalg.Vector;

/**
 * Approximate multiplication by an inverse Hessian, used to turn a gradient into a search
 * direction. Implementations may accumulate curvature information through {@link #update}.
 */
@FunctionalInterface
public interface QuasiNewton {
    /**
     * Returns the product of the current inverse-Hessian approximation with {@code vector}.
     *
     * @param vector the vector to multiply (typically a gradient)
     * @return the approximated product
     */
    Vector implictMultiply(Vector vector);

    /**
     * Records a new pair of successive differences. The default implementation ignores it.
     *
     * @param inputDelta difference between successive inputs (s)
     * @param gradDelta difference between successive gradients (y)
     */
    default void update(Vector inputDelta, Vector gradDelta) {
    }

    /**
     * Returns the identity approximation: {@link #implictMultiply} hands back its argument
     * unchanged (the same instance), and {@link #update} is a no-op.
     */
    static QuasiNewton gradientDescent() {
        return vector -> vector;
    }

    /**
     * Returns a limited-memory BFGS approximation that retains the {@code historySize} most
     * recent (input diff, gradient diff) pairs passed to {@link #update} and applies them with
     * the standard two-loop recursion.
     *
     * @param historySize maximum number of pairs to keep; with 0 no pair is retained and
     *     {@link #implictMultiply} returns an unscaled copy of its argument
     * @throws IllegalArgumentException if {@code historySize} is negative
     */
    static QuasiNewton lbfgs(final int historySize) {
        if (historySize < 0) {
            // A negative size would make update() call history.remove(-1) and crash later.
            throw new IllegalArgumentException("historySize must be non-negative: " + historySize);
        }
        return new QuasiNewton() {
            // (s, y) pairs, NEWEST FIRST: update() prepends and trims from the tail.
            private final List<Pair<Vector, Vector>> history = new ArrayList<>();
            private static final double EPS = 1.0E-200d;

            /** Initial scaling gamma = (s.y) / (y.y) from the newest pair, or 1 with no history. */
            private double initialScale() {
                if (history.isEmpty()) {
                    return 1.0d;
                }
                Vector inputDelta = history.get(0).getOne();
                Vector gradDelta = history.get(0).getTwo();
                double sy = gradDelta.dotProduct(inputDelta);
                double yy = gradDelta.l2NormSquared();
                assert yy > 0.0d : "Shouldn't have gotten a 0 diff between successive gradients";
                return sy / yy;
            }

            @Override
            public Vector implictMultiply(Vector vector) {
                int m = history.size();
                double[] sy = new double[m];     // s_k . y_k for each stored pair
                double[] alpha = new double[m];
                Vector q = vector.copy();
                // First loop runs newest -> oldest (index 0 is the newest pair).
                for (int k = 0; k < m; k++) {
                    Vector s = history.get(k).getOne();
                    Vector y = history.get(k).getTwo();
                    sy[k] = s.dotProduct(y);
                    // NOTE(review): update() does not itself reject s.y == 0 or s.y < 0 pairs.
                    assert sy[k] != 0.0d
                        : "Input diff and derivative diff can't be orthogonal by construction";
                    alpha[k] = s.dotProduct(q) / sy[k];
                    q = q.add(-alpha[k], y);
                }
                q.scaleInPlace(initialScale());
                Vector r = q;
                // Second loop runs oldest -> newest, so the newest pair's secant condition holds.
                for (int k = m - 1; k >= 0; k--) {
                    Vector s = history.get(k).getOne();
                    Vector y = history.get(k).getTwo();
                    double beta = y.dotProduct(r) / sy[k];
                    r = r.add(alpha[k] - beta, s);
                }
                return r;
            }

            @Override
            public void update(Vector inputDelta, Vector gradDelta) {
                if (inputDelta.l2NormSquared() < EPS || gradDelta.l2NormSquared() < EPS) {
                    throw new IllegalArgumentException("Too small a diff between successive input or gradient.Should have already converged already");
                }
                history.add(0, Tuples.pair(inputDelta, gradDelta));
                // Drop the oldest pairs beyond the configured memory.
                while (history.size() > historySize) {
                    history.remove(history.size() - 1);
                }
            }
        };
    }
}
