/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.robotics.optimization;

import gnu.trove.list.array.TDoubleArrayList;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import us.ihmc.robotics.numericalMethods.GradientDescentModule;
import us.ihmc.robotics.numericalMethods.SingleQueryFunction;
import us.ihmc.robotics.optimization.CostFunction;
import us.ihmc.robotics.optimization.Optimizer;

/**
 * Adapts {@link GradientDescentModule} to the {@link Optimizer} interface by converting between
 * EJML vectors ({@link DMatrixD1}) and the Trove lists ({@link TDoubleArrayList}) the module uses.
 *
 * <p>Instances are stateful: every call to {@link #optimize(DMatrixD1)} replaces the underlying
 * module, and a single scratch matrix is rewritten on every cost-function query. An instance
 * should therefore not be used for more than one optimization at a time.
 */
public class WrappedGradientDescent implements Optimizer {
    /** Module built by the most recent {@link #optimize(DMatrixD1)} call; {@code null} before that. */
    private GradientDescentModule gradientDescentModule;
    private CostFunction costFunction;
    /** Scratch column vector overwritten with the query point on every cost evaluation. */
    private final DMatrixD1 vectorInputToCostFunction = new DMatrixRMaj();
    /** Matrix returned by {@link #getOptimalParameters()}; refilled on every call. */
    private final DMatrixD1 optimalInput = new DMatrixRMaj();
    private double stepSize = 10.0;
    private double learningRate = 0.9;

    /**
     * Sets the cost function minimized by subsequent calls to {@link #optimize(DMatrixD1)}.
     *
     * @param costFunction the function to evaluate at each query point
     */
    @Override
    public void setCostFunction(CostFunction costFunction) {
        this.costFunction = costFunction;
    }

    /**
     * Wraps a matrix-based {@link CostFunction} as the list-based {@link SingleQueryFunction}
     * expected by {@link GradientDescentModule}.
     *
     * @param costFunction the cost function to delegate to
     * @return an adapter that copies each queried list into a column vector and evaluates it
     */
    private SingleQueryFunction createUnwrappedCostFunction(final CostFunction costFunction) {
        return new SingleQueryFunction() {
            @Override
            public double getQuery(TDoubleArrayList values) {
                convertArrayToMatrix(vectorInputToCostFunction, values);
                return costFunction.calculate(vectorInputToCostFunction);
            }
        };
    }

    /**
     * Replaces {@code vector}'s contents with a copy of {@code list}, shaped as a
     * {@code list.size() x 1} column vector.
     */
    private static void convertArrayToMatrix(DMatrixD1 vector, TDoubleArrayList list) {
        // toArray() returns a fresh copy, so the matrix never aliases the list's internal storage.
        vector.setData(list.toArray());
        vector.reshape(list.size(), 1);
    }

    /**
     * Replaces {@code list}'s contents with the elements of {@code vector}, in storage order.
     */
    private static void convertMatrixToArray(DMatrixD1 vector, TDoubleArrayList list) {
        list.reset();
        // Copy only the logical elements. After a reshape to a smaller size, the backing array
        // can be longer than getNumElements(), and its tail must not become extra parameters.
        list.add(vector.data, 0, vector.getNumElements());
    }

    /**
     * Sets the step size passed to the gradient descent module on the next optimization.
     *
     * @param stepSize initial step size (defaults to {@code 10.0})
     */
    public void setInitialStepSize(double stepSize) {
        this.stepSize = stepSize;
    }

    /**
     * Sets the learning rate. Its reciprocal is forwarded to
     * {@link GradientDescentModule#setReducingStepSizeRatio(double)} on the next optimization.
     * NOTE(review): a value of 0 yields an infinite ratio; the accepted range is presumably (0, 1].
     *
     * @param learningRate the learning rate (defaults to {@code 0.9})
     */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Single-step iteration is not implemented by this wrapper.
     *
     * @return always {@code null}; use {@link #optimize(DMatrixD1)} instead
     */
    @Override
    public DMatrixD1 stepOneIteration() {
        return null;
    }

    /**
     * Runs gradient descent from {@code initial} using the configured cost function, step size and
     * learning rate.
     *
     * @param initial starting point; only its first {@code getNumElements()} values are used
     * @return the optimal parameters as a column vector (the same object as
     *     {@link #getOptimalParameters()})
     * @throws IllegalStateException if no cost function has been set
     */
    @Override
    public DMatrixD1 optimize(DMatrixD1 initial) {
        if (this.costFunction == null) {
            throw new IllegalStateException("setCostFunction must be called before optimize");
        }
        TDoubleArrayList initialArray = new TDoubleArrayList();
        convertMatrixToArray(initial, initialArray);

        this.gradientDescentModule = new GradientDescentModule(createUnwrappedCostFunction(this.costFunction), initialArray);
        this.gradientDescentModule.setStepSize(this.stepSize);
        this.gradientDescentModule.setReducingStepSizeRatio(1.0 / this.learningRate);
        this.gradientDescentModule.run();
        return getOptimalParameters();
    }

    /**
     * Returns the optimum found by the last {@link #optimize(DMatrixD1)} call as a column vector.
     * The returned matrix is owned by this instance and is refilled on every call.
     *
     * @return the optimal parameters
     * @throws IllegalStateException if {@link #optimize(DMatrixD1)} has not been called
     */
    @Override
    public DMatrixD1 getOptimalParameters() {
        convertArrayToMatrix(this.optimalInput, requireOptimized().getOptimalInput());
        return this.optimalInput;
    }

    /**
     * Returns the cost at the optimum found by the last {@link #optimize(DMatrixD1)} call.
     *
     * @return the optimal cost reported by the gradient descent module
     * @throws IllegalStateException if {@link #optimize(DMatrixD1)} has not been called
     */
    @Override
    public double getOptimumCost() {
        return requireOptimized().getOptimalQuery();
    }

    /** Returns the current module, failing with a clear message if no optimization has run yet. */
    private GradientDescentModule requireOptimized() {
        if (this.gradientDescentModule == null) {
            throw new IllegalStateException("optimize must be called before querying the optimum");
        }
        return this.gradientDescentModule;
    }
}

