/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.variational.optimizer.nongradient;

import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunction;
import io.improbable.keanu.algorithms.variational.optimizer.OptimizedResult;
import io.improbable.keanu.algorithms.variational.optimizer.Optimizer;
import io.improbable.keanu.algorithms.variational.optimizer.gradient.ApacheFitnessFunctionAdapter;
import io.improbable.keanu.algorithms.variational.optimizer.nongradient.ApacheMathSimpleBoundsCalculator;
import io.improbable.keanu.algorithms.variational.optimizer.nongradient.NonGradientOptimizationAlgorithm;
import io.improbable.keanu.algorithms.variational.optimizer.nongradient.OptimizerBounds;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer;
import org.nd4j.base.Preconditions;

/**
 * Non-gradient optimization algorithm backed by Apache Commons Math's {@link BOBYQAOptimizer}
 * (Powell's Bound Optimization BY Quadratic Approximation).
 *
 * <p>The fitness function is <em>maximized</em> ({@link GoalType#MAXIMIZE}) over the flattened values
 * of the latent variables, starting from their current values and constrained by the bounds produced
 * by {@link ApacheMathSimpleBoundsCalculator}. Instances are normally created through {@link #builder()}.
 */
public class BOBYQA
implements NonGradientOptimizationAlgorithm {
    private final int maxEvaluations;
    private final double boundsRange;
    private final OptimizerBounds optimizerBounds;
    private final double initialTrustRegionRadius;
    private final double stoppingTrustRegionRadius;

    /**
     * Creates a builder pre-populated with this algorithm's default settings.
     *
     * @return a new {@link BOBYQABuilder}
     */
    public static BOBYQABuilder builder() {
        return new BOBYQABuilder();
    }

    /**
     * Runs BOBYQA over the flattened values of {@code latentVariables}.
     *
     * @param latentVariables the variables being optimized; their current values form the start point
     * @param fitnessFunction the function to maximize
     * @return the optimized values keyed by variable reference, together with the fitness at that point
     * @throws IllegalArgumentException if the latent variables have fewer than two dimensions in total
     */
    @Override
    public OptimizedResult optimize(List<? extends Variable> latentVariables, FitnessFunction fitnessFunction) {
        List<long[]> shapes = latentVariables.stream().map(Variable::getShape).collect(Collectors.toList());
        // Fail early with a descriptive message rather than letting the optimizer reject the problem.
        checkThereIsMoreThanOneDimension(shapes);

        BOBYQAOptimizer optimizer = new BOBYQAOptimizer(
            getNumInterpolationPoints(shapes),
            initialTrustRegionRadius,
            stoppingTrustRegionRadius
        );

        ObjectiveFunction fitness = new ObjectiveFunction(new ApacheFitnessFunctionAdapter(fitnessFunction, latentVariables));
        double[] startPoint = Optimizer.convertToArrayPoint(Optimizer.getAsDoubleTensors(latentVariables));

        ApacheMathSimpleBoundsCalculator boundsCalculator = new ApacheMathSimpleBoundsCalculator(boundsRange, optimizerBounds);
        SimpleBounds bounds = boundsCalculator.getBounds(latentVariables, startPoint);

        PointValuePair pointValuePair = optimizer.optimize(
            new MaxEval(maxEvaluations),
            fitness,
            bounds,
            GoalType.MAXIMIZE,
            new InitialGuess(startPoint)
        );

        Map<VariableReference, DoubleTensor> optimizedValues = Optimizer.convertFromPoint(pointValuePair.getPoint(), latentVariables);
        return new OptimizedResult(optimizedValues, pointValuePair.getValue());
    }

    /**
     * Verifies that the combined number of scalar elements across all latent variables is greater than one.
     *
     * @param latentVariablesShapes the shapes of every latent variable
     * @throws IllegalArgumentException if the total number of dimensions is one or fewer
     */
    private void checkThereIsMoreThanOneDimension(List<long[]> latentVariablesShapes) {
        // Accumulate in a long: narrowing to int on every step (as before) could wrap around and
        // produce a bogus total that wrongly fails the check or misreports the dimension count.
        long totalDimensions = 0L;
        for (long[] shape : latentVariablesShapes) {
            totalDimensions += TensorShape.getLength(shape);
        }
        Preconditions.checkArgument(totalDimensions > 1,
            "BOBYQA requires at least two dimensions to perform optimisation. You provided: " + totalDimensions + " dimension.");
    }

    /**
     * Computes the number of interpolation points handed to {@link BOBYQAOptimizer}: {@code 2n + 1},
     * where {@code n} is the total number of latent dimensions.
     *
     * @param latentVariableShapes the shapes of every latent variable
     * @return {@code 2n + 1}
     * @throws ArithmeticException if {@code 2n + 1} does not fit in an {@code int}
     */
    private int getNumInterpolationPoints(List<long[]> latentVariableShapes) {
        // toIntExact fails loudly instead of silently wrapping on overflow.
        return Math.toIntExact(2L * Optimizer.totalNumberOfLatentDimensions(latentVariableShapes) + 1L);
    }

    /**
     * Creates a BOBYQA algorithm with explicit settings; prefer {@link #builder()}.
     *
     * @param maxEvaluations            maximum number of fitness evaluations passed as {@link MaxEval}
     * @param boundsRange               range passed to {@link ApacheMathSimpleBoundsCalculator}
     * @param optimizerBounds           explicit bounds passed to {@link ApacheMathSimpleBoundsCalculator}
     * @param initialTrustRegionRadius  initial trust region radius for {@link BOBYQAOptimizer}
     * @param stoppingTrustRegionRadius stopping trust region radius for {@link BOBYQAOptimizer}
     */
    public BOBYQA(int maxEvaluations, double boundsRange, OptimizerBounds optimizerBounds, double initialTrustRegionRadius, double stoppingTrustRegionRadius) {
        this.maxEvaluations = maxEvaluations;
        this.boundsRange = boundsRange;
        this.optimizerBounds = optimizerBounds;
        this.initialTrustRegionRadius = initialTrustRegionRadius;
        this.stoppingTrustRegionRadius = stoppingTrustRegionRadius;
    }

    /**
     * Fluent builder for {@link BOBYQA}.
     *
     * <p>Defaults: {@code maxEvaluations = Integer.MAX_VALUE}, {@code boundsRange = +Infinity},
     * {@code optimizerBounds = new OptimizerBounds()}, {@code initialTrustRegionRadius = 10.0},
     * {@code stoppingTrustRegionRadius = 1.0E-8}.
     */
    public static class BOBYQABuilder {
        private int maxEvaluations = Integer.MAX_VALUE;
        private double boundsRange = Double.POSITIVE_INFINITY;
        private OptimizerBounds optimizerBounds = new OptimizerBounds();
        private double initialTrustRegionRadius = 10.0;
        private double stoppingTrustRegionRadius = 1.0E-8;

        /** Sets the maximum number of fitness evaluations. */
        public BOBYQABuilder maxEvaluations(int maxEvaluations) {
            this.maxEvaluations = maxEvaluations;
            return this;
        }

        /** Sets the range handed to the bounds calculator. */
        public BOBYQABuilder boundsRange(double boundsRange) {
            this.boundsRange = boundsRange;
            return this;
        }

        /** Sets the explicit optimizer bounds handed to the bounds calculator. */
        public BOBYQABuilder optimizerBounds(OptimizerBounds optimizerBounds) {
            this.optimizerBounds = optimizerBounds;
            return this;
        }

        /** Sets the initial trust region radius. */
        public BOBYQABuilder initialTrustRegionRadius(double initialTrustRegionRadius) {
            this.initialTrustRegionRadius = initialTrustRegionRadius;
            return this;
        }

        /** Sets the trust region radius at which optimization stops. */
        public BOBYQABuilder stoppingTrustRegionRadius(double stoppingTrustRegionRadius) {
            this.stoppingTrustRegionRadius = stoppingTrustRegionRadius;
            return this;
        }

        /**
         * Builds a {@link BOBYQA} from the current builder settings.
         *
         * @return a new {@link BOBYQA}
         */
        public BOBYQA build() {
            return new BOBYQA(this.maxEvaluations, this.boundsRange, this.optimizerBounds, this.initialTrustRegionRadius, this.stoppingTrustRegionRadius);
        }

        @Override
        public String toString() {
            return "BOBYQA.BOBYQABuilder(maxEvaluations=" + this.maxEvaluations + ", boundsRange=" + this.boundsRange + ", optimizerBounds=" + this.optimizerBounds + ", initialTrustRegionRadius=" + this.initialTrustRegionRadius + ", stoppingTrustRegionRadius=" + this.stoppingTrustRegionRadius + ")";
        }
    }
}

