package io.improbable.keanu.algorithms.variational.optimizer.gradient;

import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.variational.optimizer.ConvergenceChecker;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunction;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunctionGradient;
import io.improbable.keanu.algorithms.variational.optimizer.OptimizedResult;
import io.improbable.keanu.algorithms.variational.optimizer.RelativeConvergenceChecker;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;

/* loaded from: input_file:io/improbable/keanu/algorithms/variational/optimizer/gradient/Adam.class */
public class Adam implements GradientOptimizationAlgorithm {
    private final ConvergenceChecker convergenceChecker;
    private final int maxEvaluations;
    private final double alpha;
    private final double beta1;
    private final double beta2;
    private final double epsilon;
    private AdamStatistics statistics;

    /* loaded from: input_file:io/improbable/keanu/algorithms/variational/optimizer/gradient/Adam$AdamBuilder.class */
    public static class AdamBuilder {
        private ConvergenceChecker convergenceChecker = new RelativeConvergenceChecker(ConvergenceChecker.Norm.L2, 1.0E-6d);
        private int maxEvaluations = Integer.MAX_VALUE;
        private double alpha = 0.001d;
        private double beta1 = 0.9d;
        private double beta2 = 0.999d;
        private double epsilon = 1.0E-8d;

        public AdamBuilder maxEvaluations(int i) {
            if (i <= 0) {
                throw new NotStrictlyPositiveException(Integer.valueOf(i));
            }
            this.maxEvaluations = i;
            return this;
        }

        public AdamBuilder convergenceChecker(ConvergenceChecker convergenceChecker) {
            this.convergenceChecker = convergenceChecker;
            return this;
        }

        public AdamBuilder alpha(double d) {
            if (d <= 0.0d) {
                throw new NotStrictlyPositiveException(Double.valueOf(d));
            }
            this.alpha = d;
            return this;
        }

        public AdamBuilder beta1(double d) {
            if (d < 0.0d || d >= 1.0d) {
                throw new IllegalArgumentException("beta1 must be between 0 (inclusive) and 1 (exclusive)");
            }
            this.beta1 = d;
            return this;
        }

        public AdamBuilder beta2(double d) {
            if (d < 0.0d || d >= 1.0d) {
                throw new IllegalArgumentException("beta2 must be between 0 (inclusive) and 1 (exclusive)");
            }
            this.beta2 = d;
            return this;
        }

        public AdamBuilder epsilon(double d) {
            if (d <= 0.0d) {
                throw new NotStrictlyPositiveException(Double.valueOf(d));
            }
            this.epsilon = d;
            return this;
        }

        public Adam build() {
            return new Adam(this.convergenceChecker, this.maxEvaluations, this.alpha, this.beta1, this.beta2, this.epsilon);
        }

        public String toString() {
            return "Adam.AdamBuilder(convergenceChecker=" + this.convergenceChecker + ", maxEvaluations=" + this.maxEvaluations + ", alpha=" + this.alpha + ", beta1=" + this.beta1 + ", beta2=" + this.beta2 + ", epsilon=" + this.epsilon + ")";
        }
    }

    /* loaded from: input_file:io/improbable/keanu/algorithms/variational/optimizer/gradient/Adam$AdamStatistics.class */
    public static class AdamStatistics {
        private final boolean converged;

        public boolean didConverge() {
            return this.converged;
        }

        public AdamStatistics(boolean z) {
            this.converged = z;
        }
    }

    public static AdamBuilder builder() {
        return new AdamBuilder();
    }

    @Override // io.improbable.keanu.algorithms.variational.optimizer.gradient.GradientOptimizationAlgorithm
    public OptimizedResult optimize(List<? extends Variable> list, FitnessFunction fitnessFunction, FitnessFunctionGradient fitnessFunctionGradient) {
        DoubleTensor[] theta = getTheta(list);
        DoubleTensor[] zeros = getZeros(theta);
        DoubleTensor[] zeros2 = getZeros(theta);
        DoubleTensor[] zeros3 = getZeros(theta);
        boolean z = false;
        HashMap hashMap = new HashMap();
        DoubleTensor[] doubleTensorArr = new DoubleTensor[theta.length];
        double d = 1.0d;
        double d2 = 1.0d;
        for (int i = 1; !z && i <= this.maxEvaluations; i++) {
            updateGradients(list, theta, hashMap, doubleTensorArr, fitnessFunctionGradient);
            d *= this.beta1;
            d2 *= this.beta2;
            double sqrt = (1.0d - d) / Math.sqrt(1.0d - d2);
            for (int i2 = 0; i2 < theta.length; i2++) {
                zeros2[i2] = (DoubleTensor) zeros2[i2].times2(this.beta1).plusInPlace(doubleTensorArr[i2].times2(1.0d - this.beta1));
                zeros3[i2] = (DoubleTensor) zeros3[i2].times2(this.beta2).plusInPlace((DoubleTensor) doubleTensorArr[i2].pow2(2.0d).timesInPlace((DoubleTensor) Double.valueOf(1.0d - this.beta2)));
                zeros[i2] = (DoubleTensor) theta[i2].plus((DoubleTensor) zeros2[i2].times2(this.alpha).divInPlace((DoubleTensor) ((DoubleTensor) zeros3[i2].sqrt().timesInPlace((DoubleTensor) Double.valueOf(sqrt))).plusInPlace((DoubleTensor) Double.valueOf(this.epsilon))));
            }
            z = this.convergenceChecker.hasConverged(theta, zeros);
            DoubleTensor[] doubleTensorArr2 = theta;
            theta = zeros;
            zeros = doubleTensorArr2;
        }
        updatePoint(list, theta, hashMap);
        double fitnessAt = fitnessFunction.getFitnessAt(hashMap);
        this.statistics = new AdamStatistics(z);
        return new OptimizedResult(hashMap, fitnessAt);
    }

    private void updateGradients(List<? extends Variable> list, DoubleTensor[] doubleTensorArr, Map<VariableReference, DoubleTensor> map, DoubleTensor[] doubleTensorArr2, FitnessFunctionGradient fitnessFunctionGradient) {
        updatePoint(list, doubleTensorArr, map);
        updateGradients(list, fitnessFunctionGradient.getGradientsAt(map), doubleTensorArr2);
    }

    private DoubleTensor[] getTheta(List<? extends Variable> list) {
        DoubleTensor[] doubleTensorArr = new DoubleTensor[list.size()];
        for (int i = 0; i < doubleTensorArr.length; i++) {
            doubleTensorArr[i] = (DoubleTensor) list.get(i).getValue();
        }
        return doubleTensorArr;
    }

    private DoubleTensor[] getZeros(DoubleTensor[] doubleTensorArr) {
        DoubleTensor[] doubleTensorArr2 = new DoubleTensor[doubleTensorArr.length];
        for (int i = 0; i < doubleTensorArr2.length; i++) {
            doubleTensorArr2[i] = DoubleTensor.zeros(doubleTensorArr[i].getShape());
        }
        return doubleTensorArr2;
    }

    private void updateGradients(List<? extends Variable> list, Map<? extends VariableReference, DoubleTensor> map, DoubleTensor[] doubleTensorArr) {
        for (int i = 0; i < list.size(); i++) {
            doubleTensorArr[i] = map.get(list.get(i).getReference());
        }
    }

    private void updatePoint(List<? extends Variable> list, DoubleTensor[] doubleTensorArr, Map<VariableReference, DoubleTensor> map) {
        for (int i = 0; i < doubleTensorArr.length; i++) {
            map.put(list.get(i).getReference(), doubleTensorArr[i]);
        }
    }

    private Adam(ConvergenceChecker convergenceChecker, int i, double d, double d2, double d3, double d4) {
        this.convergenceChecker = convergenceChecker;
        this.maxEvaluations = i;
        this.alpha = d;
        this.beta1 = d2;
        this.beta2 = d3;
        this.epsilon = d4;
    }

    public AdamStatistics getStatistics() {
        return this.statistics;
    }
}
