/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.variational.optimizer.gradient;

import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.variational.optimizer.ConvergenceChecker;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunction;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunctionGradient;
import io.improbable.keanu.algorithms.variational.optimizer.OptimizedResult;
import io.improbable.keanu.algorithms.variational.optimizer.RelativeConvergenceChecker;
import io.improbable.keanu.algorithms.variational.optimizer.gradient.GradientOptimizationAlgorithm;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;

/**
 * Gradient-based optimizer that applies the Adam update rule to a list of latent variables.
 *
 * <p>Each iteration keeps exponentially decaying averages of the gradient ({@code m}) and of the
 * squared gradient ({@code v}) and computes the next point as
 * {@code theta + alpha * m / (sqrt(v) * b + epsilon)}, where
 * {@code b = (1 - beta1^t) / sqrt(1 - beta2^t)}. The step is <em>added</em> to the current point,
 * so the optimizer moves in the direction of the supplied gradient.
 *
 * <p>Instances are created through {@link #builder()}. The outcome of the most recent
 * {@link #optimize} call is available from {@link #getStatistics()}.
 */
public class Adam implements GradientOptimizationAlgorithm {
    private final ConvergenceChecker convergenceChecker;
    private final int maxEvaluations;
    private final double alpha;
    private final double beta1;
    private final double beta2;
    private final double epsilon;
    /** Result of the most recent {@link #optimize} call; {@code null} until it has been called. */
    private AdamStatistics statistics;

    /**
     * @return a new builder pre-populated with the default hyper-parameters
     */
    public static AdamBuilder builder() {
        return new AdamBuilder();
    }

    /**
     * Runs Adam iterations starting from the current values of {@code latentVariables} until the
     * convergence checker reports convergence or {@code maxEvaluations} iterations have run.
     *
     * @param latentVariables variables to optimize; their values are cast to {@link DoubleTensor}
     * @param fitnessFunction evaluated once, at the final point, to produce the reported fitness
     * @param fitnessFunctionGradient evaluated once per iteration at the current point
     * @return the final point (keyed by variable reference) and the fitness at that point
     */
    @Override
    public OptimizedResult optimize(List<? extends Variable> latentVariables, FitnessFunction fitnessFunction, FitnessFunctionGradient fitnessFunctionGradient) {
        DoubleTensor[] theta = this.getTheta(latentVariables);
        DoubleTensor[] thetaNext = this.getZeros(theta);
        DoubleTensor[] m = this.getZeros(theta);
        DoubleTensor[] v = this.getZeros(theta);
        Map<VariableReference, DoubleTensor> thetaAsPoint = new HashMap<VariableReference, DoubleTensor>();
        DoubleTensor[] gradients = new DoubleTensor[theta.length];

        // Loop-invariant weights of the newest gradient in the two moving averages.
        final double oneMinusBeta1 = 1.0 - this.beta1;
        final double oneMinusBeta2 = 1.0 - this.beta2;

        // Running values of beta1^t and beta2^t, updated multiplicatively each iteration.
        double beta1T = 1.0;
        double beta2T = 1.0;

        boolean converged = false;
        // Count from zero with a strict bound: with maxEvaluations == Integer.MAX_VALUE a
        // "t <= maxEvaluations; ++t" loop would overflow the counter and never hit the cap.
        int evaluations = 0;
        while (!converged && evaluations < this.maxEvaluations) {
            ++evaluations;
            this.updateGradients(latentVariables, theta, thetaAsPoint, gradients, fitnessFunctionGradient);

            beta1T *= this.beta1;
            beta2T *= this.beta2;
            double b = (1.0 - beta1T) / Math.sqrt(1.0 - beta2T);

            for (int i = 0; i < theta.length; ++i) {
                m[i] = m[i].times(this.beta1).plusInPlace(gradients[i].times(oneMinusBeta1));
                v[i] = v[i].times(this.beta2).plusInPlace(gradients[i].pow(2.0).timesInPlace(oneMinusBeta2));
                // sqrt() and times() return fresh tensors here, so the in-place calls only
                // touch temporaries and never m[i], v[i] or theta[i] themselves.
                DoubleTensor denominator = v[i].sqrt().timesInPlace(b).plusInPlace(this.epsilon);
                thetaNext[i] = theta[i].plus(m[i].times(this.alpha).divInPlace(denominator));
            }

            converged = this.convergenceChecker.hasConverged(theta, thetaNext);

            // Swap buffers: the freshly computed point becomes the current point.
            DoubleTensor[] previous = theta;
            theta = thetaNext;
            thetaNext = previous;
        }

        this.updatePoint(latentVariables, theta, thetaAsPoint);
        double logProb = fitnessFunction.getFitnessAt(thetaAsPoint);
        this.statistics = new AdamStatistics(converged);
        return new OptimizedResult(thetaAsPoint, logProb);
    }

    /**
     * Publishes {@code theta} into {@code thetaAsPoint}, evaluates the gradient there and copies
     * the per-variable gradients into {@code gradients} in the order of {@code ordered}.
     */
    private void updateGradients(List<? extends Variable> ordered, DoubleTensor[] theta, Map<VariableReference, DoubleTensor> thetaAsPoint, DoubleTensor[] gradients, FitnessFunctionGradient fitnessFunctionGradient) {
        this.updatePoint(ordered, theta, thetaAsPoint);
        this.updateGradients(ordered, fitnessFunctionGradient.getGradientsAt(thetaAsPoint), gradients);
    }

    /**
     * Collects the current value of every variable, in list order.
     * Each value is cast to {@link DoubleTensor}, so a non-DoubleTensor value fails with a
     * {@link ClassCastException}.
     */
    private DoubleTensor[] getTheta(List<? extends Variable> latentVertices) {
        DoubleTensor[] theta = new DoubleTensor[latentVertices.size()];
        for (int i = 0; i < theta.length; ++i) {
            theta[i] = (DoubleTensor) latentVertices.get(i).getValue();
        }
        return theta;
    }

    /**
     * Builds an array of zero tensors, one per entry of {@code values}, each with the same shape
     * as the corresponding entry.
     */
    private DoubleTensor[] getZeros(DoubleTensor[] values) {
        DoubleTensor[] zeros = new DoubleTensor[values.length];
        for (int i = 0; i < zeros.length; ++i) {
            zeros[i] = DoubleTensor.zeros(values[i].getShape());
        }
        return zeros;
    }

    /**
     * Copies gradients from a reference-keyed map into {@code gradients}, ordered like
     * {@code ordered}. A variable with no entry in {@code newGradients} yields a {@code null}
     * slot.
     */
    private void updateGradients(List<? extends Variable> ordered, Map<? extends VariableReference, DoubleTensor> newGradients, DoubleTensor[] gradients) {
        for (int i = 0; i < ordered.size(); ++i) {
            gradients[i] = newGradients.get(ordered.get(i).getReference());
        }
    }

    /**
     * Writes {@code values[i]} into {@code valuesAsPoint} under the reference of
     * {@code ordered.get(i)}, replacing any previous entry for that reference.
     */
    private void updatePoint(List<? extends Variable> ordered, DoubleTensor[] values, Map<VariableReference, DoubleTensor> valuesAsPoint) {
        for (int i = 0; i < values.length; ++i) {
            valuesAsPoint.put(ordered.get(i).getReference(), values[i]);
        }
    }

    private Adam(ConvergenceChecker convergenceChecker, int maxEvaluations, double alpha, double beta1, double beta2, double epsilon) {
        this.convergenceChecker = convergenceChecker;
        this.maxEvaluations = maxEvaluations;
        this.alpha = alpha;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.epsilon = epsilon;
    }

    /**
     * @return statistics of the most recent {@link #optimize} call, or {@code null} if
     *         {@code optimize} has not been called on this instance yet
     */
    public AdamStatistics getStatistics() {
        return this.statistics;
    }

    /**
     * Fluent builder for {@link Adam}. Defaults: relative L2 convergence check with threshold
     * {@code 1e-6}, {@code maxEvaluations = Integer.MAX_VALUE}, {@code alpha = 0.001},
     * {@code beta1 = 0.9}, {@code beta2 = 0.999}, {@code epsilon = 1e-8}.
     */
    public static class AdamBuilder {
        private ConvergenceChecker convergenceChecker = new RelativeConvergenceChecker(ConvergenceChecker.Norm.L2, 1.0E-6);
        private int maxEvaluations = Integer.MAX_VALUE;
        private double alpha = 0.001;
        private double beta1 = 0.9;
        private double beta2 = 0.999;
        private double epsilon = 1.0E-8;

        /**
         * @param maxEvaluations upper bound on the number of iterations; must be &gt; 0
         * @throws NotStrictlyPositiveException if {@code maxEvaluations <= 0}
         */
        public AdamBuilder maxEvaluations(int maxEvaluations) {
            if (maxEvaluations <= 0) {
                throw new NotStrictlyPositiveException(maxEvaluations);
            }
            this.maxEvaluations = maxEvaluations;
            return this;
        }

        /**
         * @param convergenceChecker decides after each iteration whether to stop; must not be null
         * @throws NullPointerException if {@code convergenceChecker} is {@code null}
         */
        public AdamBuilder convergenceChecker(ConvergenceChecker convergenceChecker) {
            // Fail here rather than with an anonymous NPE in the middle of optimize().
            if (convergenceChecker == null) {
                throw new NullPointerException("convergenceChecker must not be null");
            }
            this.convergenceChecker = convergenceChecker;
            return this;
        }

        /**
         * @param alpha step size; must be &gt; 0 and not NaN
         * @throws NotStrictlyPositiveException if {@code alpha} is NaN or {@code <= 0}
         */
        public AdamBuilder alpha(double alpha) {
            // NaN compares false with everything, so it has to be rejected explicitly.
            if (Double.isNaN(alpha) || alpha <= 0.0) {
                throw new NotStrictlyPositiveException(alpha);
            }
            this.alpha = alpha;
            return this;
        }

        /**
         * @param beta1 decay rate of the gradient average; must be in {@code [0, 1)}
         * @throws IllegalArgumentException if {@code beta1} is NaN or outside {@code [0, 1)}
         */
        public AdamBuilder beta1(double beta1) {
            if (Double.isNaN(beta1) || beta1 < 0.0 || beta1 >= 1.0) {
                throw new IllegalArgumentException("beta1 must be between 0 (inclusive) and 1 (exclusive)");
            }
            this.beta1 = beta1;
            return this;
        }

        /**
         * @param beta2 decay rate of the squared-gradient average; must be in {@code [0, 1)}
         * @throws IllegalArgumentException if {@code beta2} is NaN or outside {@code [0, 1)}
         */
        public AdamBuilder beta2(double beta2) {
            if (Double.isNaN(beta2) || beta2 < 0.0 || beta2 >= 1.0) {
                throw new IllegalArgumentException("beta2 must be between 0 (inclusive) and 1 (exclusive)");
            }
            this.beta2 = beta2;
            return this;
        }

        /**
         * @param epsilon constant added to the denominator of the update; must be &gt; 0 and not NaN
         * @throws NotStrictlyPositiveException if {@code epsilon} is NaN or {@code <= 0}
         */
        public AdamBuilder epsilon(double epsilon) {
            if (Double.isNaN(epsilon) || epsilon <= 0.0) {
                throw new NotStrictlyPositiveException(epsilon);
            }
            this.epsilon = epsilon;
            return this;
        }

        /**
         * @return a new {@link Adam} configured with the builder's current settings
         */
        public Adam build() {
            return new Adam(this.convergenceChecker, this.maxEvaluations, this.alpha, this.beta1, this.beta2, this.epsilon);
        }

        @Override
        public String toString() {
            return "Adam.AdamBuilder(convergenceChecker=" + this.convergenceChecker + ", maxEvaluations=" + this.maxEvaluations + ", alpha=" + this.alpha + ", beta1=" + this.beta1 + ", beta2=" + this.beta2 + ", epsilon=" + this.epsilon + ")";
        }
    }

    /**
     * Immutable summary of one {@link Adam#optimize} run.
     */
    public static class AdamStatistics {
        private final boolean converged;

        /**
         * @return {@code true} if the convergence checker reported convergence on the last
         *         iteration, {@code false} if the run stopped at the evaluation limit
         */
        public boolean didConverge() {
            return this.converged;
        }

        public AdamStatistics(boolean converged) {
            this.converged = converged;
        }
    }
}

