/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.variational.optimizer.gradient;

import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunction;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunctionGradient;
import io.improbable.keanu.algorithms.variational.optimizer.OptimizedResult;
import io.improbable.keanu.algorithms.variational.optimizer.Optimizer;
import io.improbable.keanu.algorithms.variational.optimizer.gradient.ApacheFitnessFunctionAdapter;
import io.improbable.keanu.algorithms.variational.optimizer.gradient.ApacheFitnessFunctionGradientAdapter;
import io.improbable.keanu.algorithms.variational.optimizer.gradient.GradientOptimizationAlgorithm;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleValueChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient;
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer;

/**
 * A {@link GradientOptimizationAlgorithm} that <em>maximizes</em> a fitness function using
 * Apache Commons Math's {@link NonLinearConjugateGradientOptimizer}.
 *
 * <p>Convergence is decided by a {@link SimpleValueChecker} built from the configured relative
 * and absolute thresholds. The search is additionally capped at {@code maxEvaluations}
 * evaluations. Instances hold only final configuration fields and are normally created via
 * {@link #builder()}.
 */
public class ConjugateGradient implements GradientOptimizationAlgorithm {
    private final int maxEvaluations;
    private final double relativeThreshold;
    private final double absoluteThreshold;
    private final UpdateFormula updateFormula;

    /**
     * Returns a new builder pre-populated with the defaults documented on
     * {@link ConjugateGradientBuilder}.
     *
     * @return a fresh, independent builder
     */
    public static ConjugateGradientBuilder builder() {
        return new ConjugateGradientBuilder();
    }

    /**
     * Maximizes {@code fitnessFunction} over the given latent variables.
     *
     * <p>The initial guess is taken from the variables via {@code Optimizer.getAsDoubleTensors}
     * (presumably their current values; confirm against {@code Optimizer}). Any exception raised
     * by the Apache optimizer propagates unchanged, for example when the evaluation budget is
     * exhausted.
     *
     * @param latentVariables the variables whose values are optimized
     * @param fitnessFunction the objective to maximize
     * @param fitnessFunctionGradient the gradient of {@code fitnessFunction}
     * @return the optimized values keyed by variable reference, together with the fitness value
     *     reported by the optimizer at that point
     */
    @Override
    public OptimizedResult optimize(List<? extends Variable> latentVariables, FitnessFunction fitnessFunction, FitnessFunctionGradient fitnessFunctionGradient) {
        ObjectiveFunction fitness =
            new ObjectiveFunction(new ApacheFitnessFunctionAdapter(fitnessFunction, latentVariables));
        ObjectiveFunctionGradient gradient =
            new ObjectiveFunctionGradient(new ApacheFitnessFunctionGradientAdapter(fitnessFunctionGradient, latentVariables));

        double[] startingPoint = Optimizer.convertToArrayPoint(Optimizer.getAsDoubleTensors(latentVariables));

        // SimpleValueChecker is a ConvergenceChecker<PointValuePair>, so no raw-type cast is needed.
        NonLinearConjugateGradientOptimizer optimizer = new NonLinearConjugateGradientOptimizer(
            updateFormula.apacheMapping,
            new SimpleValueChecker(relativeThreshold, absoluteThreshold));

        PointValuePair pointValuePair = optimizer.optimize(
            new MaxEval(maxEvaluations),
            fitness,
            gradient,
            GoalType.MAXIMIZE,
            new InitialGuess(startingPoint));

        Map<VariableReference, DoubleTensor> optimizedValues =
            Optimizer.convertFromPoint(pointValuePair.getPoint(), latentVariables);
        return new OptimizedResult(optimizedValues, pointValuePair.getValue());
    }

    /**
     * Creates a conjugate-gradient optimizer with explicit settings.
     *
     * <p>Unlike the builder, this constructor does not range-check the numeric arguments. It only
     * rejects a null update formula, which would otherwise fail later inside
     * {@link #optimize}.
     *
     * @param maxEvaluations the evaluation budget passed to Apache as {@link MaxEval}
     * @param relativeThreshold the relative threshold for {@link SimpleValueChecker}
     * @param absoluteThreshold the absolute threshold for {@link SimpleValueChecker}
     * @param updateFormula the conjugate-direction update formula; must not be null
     * @throws NullPointerException if {@code updateFormula} is null
     */
    public ConjugateGradient(int maxEvaluations, double relativeThreshold, double absoluteThreshold, UpdateFormula updateFormula) {
        this.maxEvaluations = maxEvaluations;
        this.relativeThreshold = relativeThreshold;
        this.absoluteThreshold = absoluteThreshold;
        this.updateFormula = Objects.requireNonNull(updateFormula, "updateFormula");
    }

    /**
     * Fluent builder for {@link ConjugateGradient}.
     *
     * <p>Defaults: {@code maxEvaluations = Integer.MAX_VALUE},
     * {@code relativeThreshold = 1e-8}, {@code absoluteThreshold = 1e-8},
     * {@code updateFormula = POLAK_RIBIERE}.
     */
    public static class ConjugateGradientBuilder {
        private int maxEvaluations = Integer.MAX_VALUE;
        private double relativeThreshold = 1.0E-8;
        private double absoluteThreshold = 1.0E-8;
        private UpdateFormula updateFormula = UpdateFormula.POLAK_RIBIERE;

        /**
         * Sets the maximum number of fitness evaluations.
         *
         * @param maxEvaluations the evaluation budget; must be strictly positive
         * @return this builder
         * @throws NotStrictlyPositiveException if {@code maxEvaluations <= 0}
         */
        public ConjugateGradientBuilder maxEvaluations(int maxEvaluations) {
            if (maxEvaluations <= 0) {
                throw new NotStrictlyPositiveException(maxEvaluations);
            }
            this.maxEvaluations = maxEvaluations;
            return this;
        }

        /**
         * Sets the relative convergence threshold passed to {@link SimpleValueChecker}.
         *
         * @param relativeThreshold a strictly positive threshold
         * @return this builder
         * @throws NotStrictlyPositiveException if {@code relativeThreshold} is not strictly
         *     positive (including NaN)
         */
        public ConjugateGradientBuilder relativeThreshold(double relativeThreshold) {
            // Negated form so NaN is rejected too: a NaN threshold makes every convergence
            // comparison false, so the optimizer would only stop at maxEvaluations.
            if (!(relativeThreshold > 0.0)) {
                throw new NotStrictlyPositiveException(relativeThreshold);
            }
            this.relativeThreshold = relativeThreshold;
            return this;
        }

        /**
         * Sets the absolute convergence threshold passed to {@link SimpleValueChecker}.
         *
         * @param absoluteThreshold a strictly positive threshold
         * @return this builder
         * @throws NotStrictlyPositiveException if {@code absoluteThreshold} is not strictly
         *     positive (including NaN)
         */
        public ConjugateGradientBuilder absoluteThreshold(double absoluteThreshold) {
            // Negated form so NaN is rejected too (see relativeThreshold).
            if (!(absoluteThreshold > 0.0)) {
                throw new NotStrictlyPositiveException(absoluteThreshold);
            }
            this.absoluteThreshold = absoluteThreshold;
            return this;
        }

        /**
         * Sets the formula used to update the conjugate search direction.
         *
         * @param updateFormula the formula to use; must not be null
         * @return this builder
         * @throws NullPointerException if {@code updateFormula} is null
         */
        public ConjugateGradientBuilder updateFormula(UpdateFormula updateFormula) {
            this.updateFormula = Objects.requireNonNull(updateFormula, "updateFormula");
            return this;
        }

        /**
         * Builds an optimizer from the current builder settings.
         *
         * @return a new {@link ConjugateGradient}
         */
        public ConjugateGradient build() {
            return new ConjugateGradient(maxEvaluations, relativeThreshold, absoluteThreshold, updateFormula);
        }

        @Override
        public String toString() {
            return "ConjugateGradient.ConjugateGradientBuilder(maxEvaluations=" + maxEvaluations
                + ", relativeThreshold=" + relativeThreshold
                + ", absoluteThreshold=" + absoluteThreshold
                + ", updateFormula=" + updateFormula + ")";
        }
    }

    /**
     * The conjugate-direction update formula, mapped one-to-one onto the Apache Commons Math
     * {@link NonLinearConjugateGradientOptimizer.Formula} constant of the same name.
     */
    public enum UpdateFormula {
        POLAK_RIBIERE(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE),
        FLETCHER_REEVES(NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES);

        /** The Apache Commons Math formula this constant delegates to. */
        final NonLinearConjugateGradientOptimizer.Formula apacheMapping;

        UpdateFormula(NonLinearConjugateGradientOptimizer.Formula apacheMapping) {
            this.apacheMapping = apacheMapping;
        }
    }
}

