package io.improbable.keanu.algorithms.variational.optimizer.gradient;

import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunction;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunctionGradient;
import io.improbable.keanu.algorithms.variational.optimizer.OptimizedResult;
import io.improbable.keanu.algorithms.variational.optimizer.Optimizer;
import java.util.List;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleValueChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient;
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer;

/* loaded from: input_file:io/improbable/keanu/algorithms/variational/optimizer/gradient/ConjugateGradient.class */
/**
 * Gradient optimization algorithm backed by Apache Commons Math's
 * {@link NonLinearConjugateGradientOptimizer}.
 *
 * <p>Instances are immutable once constructed. Use {@link #builder()} to obtain an instance with
 * validated settings; the builder defaults are {@code Integer.MAX_VALUE} evaluations, relative and
 * absolute thresholds of {@code 1e-8}, and the {@link UpdateFormula#POLAK_RIBIERE} update formula.
 */
public class ConjugateGradient implements GradientOptimizationAlgorithm {
    private final int maxEvaluations;
    private final double relativeThreshold;
    private final double absoluteThreshold;
    private final UpdateFormula updateFormula;

    /**
     * Creates a conjugate gradient algorithm with explicit settings.
     *
     * <p>Only {@code updateFormula} is validated here; the numeric arguments are stored as given.
     * Prefer {@link #builder()}, which also validates the numeric settings.
     *
     * @param i             maximum number of evaluations, passed to {@link MaxEval}
     * @param d             relative threshold passed to {@link SimpleValueChecker}
     * @param d2            absolute threshold passed to {@link SimpleValueChecker}
     * @param updateFormula update formula handed to the underlying optimizer; must not be null
     * @throws NullPointerException if {@code updateFormula} is null
     */
    public ConjugateGradient(int i, double d, double d2, UpdateFormula updateFormula) {
        // Fail here with a clear message rather than with an anonymous NPE inside optimize().
        if (updateFormula == null) {
            throw new NullPointerException("updateFormula must not be null");
        }
        this.maxEvaluations = i;
        this.relativeThreshold = d;
        this.absoluteThreshold = d2;
        this.updateFormula = updateFormula;
    }

    /**
     * @return a new builder pre-populated with the default settings
     */
    public static ConjugateGradientBuilder builder() {
        return new ConjugateGradientBuilder();
    }

    /**
     * Runs the Apache non-linear conjugate gradient optimizer with {@link GoalType#MAXIMIZE}.
     *
     * <p>The initial guess is derived from {@code list} via
     * {@code Optimizer.getAsDoubleTensors} and {@code Optimizer.convertToArrayPoint}.
     *
     * @param list                    variables to optimize over
     * @param fitnessFunction         objective, wrapped in an {@code ApacheFitnessFunctionAdapter}
     * @param fitnessFunctionGradient gradient of the objective, wrapped in an
     *                                {@code ApacheFitnessFunctionGradientAdapter}
     * @return the point returned by the optimizer, converted back with
     *         {@code Optimizer.convertFromPoint}, together with the objective value at that point
     */
    @Override
    public OptimizedResult optimize(List<? extends Variable> list, FitnessFunction fitnessFunction, FitnessFunctionGradient fitnessFunctionGradient) {
        NonLinearConjugateGradientOptimizer optimizer = new NonLinearConjugateGradientOptimizer(
            updateFormula.apacheMapping,
            new SimpleValueChecker(relativeThreshold, absoluteThreshold)
        );

        PointValuePair optimum = optimizer.optimize(
            new MaxEval(maxEvaluations),
            new ObjectiveFunction(new ApacheFitnessFunctionAdapter(fitnessFunction, list)),
            new ObjectiveFunctionGradient(new ApacheFitnessFunctionGradientAdapter(fitnessFunctionGradient, list)),
            GoalType.MAXIMIZE,
            new InitialGuess(Optimizer.convertToArrayPoint(Optimizer.getAsDoubleTensors(list)))
        );

        // Unbox explicitly so the same (point, double) OptimizedResult constructor is selected.
        double fitness = optimum.getValue();
        return new OptimizedResult(Optimizer.convertFromPoint(optimum.getPoint(), list), fitness);
    }

    /**
     * Fluent builder for {@link ConjugateGradient} that validates each setting as it is supplied.
     */
    public static class ConjugateGradientBuilder {
        private int maxEvaluations = Integer.MAX_VALUE;
        private double relativeThreshold = 1e-8;
        private double absoluteThreshold = 1e-8;
        private UpdateFormula updateFormula = UpdateFormula.POLAK_RIBIERE;

        /**
         * @param i maximum number of evaluations; must be strictly positive
         * @return this builder
         * @throws NotStrictlyPositiveException if {@code i <= 0}
         */
        public ConjugateGradientBuilder maxEvaluations(int i) {
            if (i <= 0) {
                throw new NotStrictlyPositiveException(i);
            }
            this.maxEvaluations = i;
            return this;
        }

        /**
         * @param d relative convergence threshold; must be strictly positive and not NaN
         * @return this builder
         * @throws NotStrictlyPositiveException if {@code d} is NaN, zero or negative
         */
        public ConjugateGradientBuilder relativeThreshold(double d) {
            this.relativeThreshold = requireStrictlyPositive(d);
            return this;
        }

        /**
         * @param d absolute convergence threshold; must be strictly positive and not NaN
         * @return this builder
         * @throws NotStrictlyPositiveException if {@code d} is NaN, zero or negative
         */
        public ConjugateGradientBuilder absoluteThreshold(double d) {
            this.absoluteThreshold = requireStrictlyPositive(d);
            return this;
        }

        /**
         * @param updateFormula update formula to use; must not be null
         * @return this builder
         * @throws NullPointerException if {@code updateFormula} is null
         */
        public ConjugateGradientBuilder updateFormula(UpdateFormula updateFormula) {
            if (updateFormula == null) {
                throw new NullPointerException("updateFormula must not be null");
            }
            this.updateFormula = updateFormula;
            return this;
        }

        /**
         * @return a {@link ConjugateGradient} configured with this builder's current settings
         */
        public ConjugateGradient build() {
            return new ConjugateGradient(maxEvaluations, relativeThreshold, absoluteThreshold, updateFormula);
        }

        @Override
        public String toString() {
            return "ConjugateGradient.ConjugateGradientBuilder(maxEvaluations=" + maxEvaluations
                + ", relativeThreshold=" + relativeThreshold
                + ", absoluteThreshold=" + absoluteThreshold
                + ", updateFormula=" + updateFormula + ")";
        }

        /**
         * Returns {@code value} if it is strictly positive, otherwise throws.
         *
         * <p>The check is written as {@code !(value > 0)} rather than {@code value <= 0} so that
         * NaN, for which every comparison is false, is rejected too. A NaN threshold could never
         * satisfy a {@code <=} convergence comparison.
         */
        private static double requireStrictlyPositive(double value) {
            if (!(value > 0.0d)) {
                throw new NotStrictlyPositiveException(value);
            }
            return value;
        }
    }

    /**
     * Update formulas supported by this algorithm, each mapped to the corresponding
     * {@link NonLinearConjugateGradientOptimizer.Formula} constant.
     */
    public enum UpdateFormula {
        POLAK_RIBIERE(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE),
        FLETCHER_REEVES(NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES);

        // Kept package-private (as before) but now final: enum constants must not be re-mapped.
        final NonLinearConjugateGradientOptimizer.Formula apacheMapping;

        UpdateFormula(NonLinearConjugateGradientOptimizer.Formula formula) {
            this.apacheMapping = formula;
        }
    }
}
