/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater;

import com.google.common.base.Function;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.impl.accum.Norm2;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.AbsValueGreaterThan;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.GradientUpdaterAggregator;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Base {@link Updater} that keeps one {@link GradientUpdater} per parameter (variable) name and
 * drives the per-iteration update pipeline: gradient normalization/clipping ({@link #preApply}),
 * learning-rate and momentum schedules ({@link #applyLrDecayPolicy}), the variable's updater rule,
 * and regularization plus mini-batch scaling ({@link #postApply}).
 *
 * <p>Subclasses supply the {@link GradientUpdater} for each variable via
 * {@link #init(String, INDArray, Layer)}.
 */
public abstract class BaseUpdater implements Updater {
    /** Per-variable gradient updaters, keyed by parameter name. */
    protected Map<String, GradientUpdater> updaterForVariable = new HashMap<String, GradientUpdater>();

    /**
     * Converts the raw gradients held by {@code gradient} into the updates to apply.
     *
     * <p>First normalizes/clips all gradients ({@link #preApply}). Then, for each variable, it
     * applies the learning-rate/momentum schedule, runs the variable's {@link GradientUpdater}, and
     * applies regularization and mini-batch scaling. The variable's entry in {@code gradient} is
     * replaced with the result.
     *
     * @param layer         layer whose configuration and parameters drive the update
     * @param gradient      gradients keyed by parameter name; modified in place
     * @param iteration     current iteration, passed to the schedules and the updater rule
     * @param miniBatchSize divisor applied when the configuration is in mini-batch mode
     */
    @Override
    public void update(Layer layer, Gradient gradient, int iteration, int miniBatchSize) {
        preApply(layer, gradient, iteration);

        // Both depend only on the layer configuration, so read them once rather than per variable.
        LearningRatePolicy decay = layer.conf().getLearningRatePolicy();
        boolean hasDecay = decay != null && decay != LearningRatePolicy.None;
        boolean nesterovs = isNesterovs(layer);

        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String paramName = entry.getKey();
            INDArray paramGradient = entry.getValue();
            // Nesterovs always goes through this path so its momentum schedule is applied even
            // when no learning-rate policy is configured.
            if (hasDecay || nesterovs) {
                applyLrDecayPolicy(decay, layer, iteration, paramName);
            }
            GradientUpdater updater = init(paramName, paramGradient, layer);
            INDArray step = updater.getGradient(paramGradient, iteration);
            postApply(layer, step, paramName, miniBatchSize);
            gradient.setGradientFor(paramName, step);
        }
    }

    /**
     * Applies regularization and mini-batch scaling to one variable's update, in place.
     *
     * <p>When regularization is enabled, adds {@code l2 * params} and {@code l1 * sign(params)}
     * for positive coefficients. When the configuration is in mini-batch mode, the whole result
     * (including the regularization terms) is then divided by {@code miniBatchSize}.
     *
     * @param layer         layer owning the parameter
     * @param gradient      the update for {@code param}; modified in place
     * @param param         parameter name
     * @param miniBatchSize divisor used in mini-batch mode
     */
    public void postApply(Layer layer, INDArray gradient, String param, int miniBatchSize) {
        NeuralNetConfiguration conf = layer.conf();
        INDArray params = layer.getParam(param);
        if (conf.isUseRegularization()) {
            double l2 = conf.getL2ByParam(param);
            if (l2 > 0.0) {
                gradient.addi(params.mul(l2));
            }
            double l1 = conf.getL1ByParam(param);
            if (l1 > 0.0) {
                gradient.addi(Transforms.sign(params).muli(l1));
            }
        }
        if (conf.isMiniBatch()) {
            gradient.divi(miniBatchSize);
        }
    }

    /**
     * Applies the layer's momentum schedule and pushes the current learning rate and momentum to
     * the variable's updater, if that updater already exists.
     *
     * <p>If the schedule has an entry for {@code iteration}, the layer's momentum is set to it.
     * Otherwise the layer's current momentum is used. Either way, the updater is always refreshed,
     * so a decayed learning rate still reaches it between scheduled momentum changes.
     *
     * @param layer     layer whose configuration holds the schedule
     * @param iteration current iteration
     * @param variable  parameter name whose updater is refreshed
     */
    public void applyMomentumDecayPolicy(Layer layer, int iteration, String variable) {
        NeuralNetConfiguration conf = layer.conf();
        Map<Integer, Double> schedule = conf.getLayer().getMomentumSchedule();
        double momentum;
        if (schedule != null && schedule.containsKey(iteration)) {
            momentum = schedule.get(iteration);
            conf.getLayer().setMomentum(momentum);
        } else {
            momentum = conf.getLayer().getMomentum();
        }
        GradientUpdater updater = updaterForVariable.get(variable);
        if (updater != null) {
            updater.update(conf.getLearningRateByParam(variable), momentum);
        }
    }

    /**
     * Updates the learning rate stored in the configuration for {@code variable} according to
     * {@code decay}, then propagates it to the variable's updater.
     *
     * <p>For Nesterovs, propagation goes through {@link #applyMomentumDecayPolicy}, which also sends
     * the momentum. For other updaters, only the learning rate is pushed. A {@code null} or
     * {@code None} policy leaves the rate unchanged.
     *
     * @param decay     learning-rate policy to apply; may be {@code null}
     * @param layer     layer whose configuration is read and updated
     * @param iteration current iteration
     * @param variable  parameter name
     */
    public void applyLrDecayPolicy(LearningRatePolicy decay, Layer layer, int iteration, String variable) {
        NeuralNetConfiguration conf = layer.conf();
        double decayRate = conf.getLrPolicyDecayRate();
        // NOTE(review): lr is the variable's *current* rate, which earlier calls may already have
        // decayed. The multiplicative policies therefore compound across iterations instead of
        // being computed from the initial rate. Confirm this is intended.
        double lr = conf.getLearningRateByParam(variable);
        if (decay != null) {
            switch (decay) {
                case Exponential:
                    conf.setLearningRateByParam(variable, lr * Math.pow(decayRate, iteration));
                    break;
                case Inverse:
                    conf.setLearningRateByParam(variable,
                            lr / Math.pow(1.0 + decayRate * iteration, conf.getLrPolicyPower()));
                    break;
                case Step:
                    conf.setLearningRateByParam(variable,
                            lr * Math.pow(decayRate, Math.floor((double) iteration / conf.getLrPolicySteps())));
                    break;
                case Poly:
                    conf.setLearningRateByParam(variable,
                            lr * Math.pow(1.0 - (double) iteration / (double) conf.getNumIterations(),
                                    conf.getLrPolicyPower()));
                    break;
                case Sigmoid:
                    conf.setLearningRateByParam(variable,
                            lr / (1.0 + Math.exp(-decayRate * ((double) iteration - conf.getLrPolicySteps()))));
                    break;
                case Schedule: {
                    Map<Integer, Double> schedule = conf.getLayer().getLearningRateSchedule();
                    if (schedule != null && schedule.containsKey(iteration)) {
                        conf.setLearningRateByParam(variable, schedule.get(iteration));
                    }
                    break;
                }
                default:
                    // None (or any policy without a formula here): learning rate stays as is.
                    break;
            }
        }

        if (isNesterovs(layer)) {
            applyMomentumDecayPolicy(layer, iteration, variable);
        } else {
            GradientUpdater updater = updaterForVariable.get(variable);
            if (updater != null) {
                updater.update(conf.getLearningRateByParam(variable));
            }
        }
    }

    /**
     * Normalizes or clips all gradients in place, according to the layer's configured
     * {@link GradientNormalization} strategy. Does nothing for {@code null} or {@code None}.
     *
     * @param layer     layer whose configuration selects the strategy and threshold
     * @param gradient  gradients to normalize; modified in place
     * @param iteration current iteration (unused by the current strategies)
     * @throws RuntimeException if the strategy is not handled here
     */
    public void preApply(Layer layer, Gradient gradient, int iteration) {
        GradientNormalization normalization = layer.conf().getLayer().getGradientNormalization();
        if (normalization == null || normalization == GradientNormalization.None) {
            return;
        }
        final double threshold = layer.conf().getLayer().getGradientNormalizationThreshold();
        switch (normalization) {
            case RenormalizeL2PerLayer: {
                double layerL2 = layerL2Norm(gradient);
                // An all-zero gradient has nothing to rescale; dividing by 0 would produce NaN.
                if (layerL2 > 0.0) {
                    for (INDArray g : gradient.gradientForVariable().values()) {
                        g.divi(layerL2);
                    }
                }
                break;
            }
            case RenormalizeL2PerParamType: {
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = g.norm2Number().doubleValue();
                    // Same zero-norm guard as above, applied per parameter array.
                    if (l2 > 0.0) {
                        g.divi(l2);
                    }
                }
                break;
            }
            case ClipElementWiseAbsoluteValue: {
                Condition absValueCondition = new AbsValueGreaterThan(threshold);
                // Only applied to elements with |x| > threshold, so the sign alone picks the bound.
                Function<Number, Number> clipFn = new Function<Number, Number>() {
                    @Override
                    public Number apply(Number number) {
                        return number.doubleValue() > threshold ? threshold : -threshold;
                    }
                };
                for (INDArray g : gradient.gradientForVariable().values()) {
                    BooleanIndexing.applyWhere(g, absValueCondition, clipFn);
                }
                break;
            }
            case ClipL2PerLayer: {
                double layerL2 = layerL2Norm(gradient);
                if (layerL2 > threshold) {
                    double scalingFactor = threshold / layerL2;
                    for (INDArray g : gradient.gradientForVariable().values()) {
                        g.muli(scalingFactor);
                    }
                }
                break;
            }
            case ClipL2PerParamType: {
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = g.norm2Number().doubleValue();
                    if (l2 > threshold) {
                        g.divi(l2 / threshold);
                    }
                }
                break;
            }
            default:
                throw new RuntimeException("Unknown (or not implemented) gradient normalization strategy: " + normalization);
        }
    }

    /**
     * Returns the L2 norm of all of {@code gradient}'s arrays taken together, computed as the
     * square root of the sum of each array's squared L2 norm.
     */
    private static double layerL2Norm(Gradient gradient) {
        double sumSquares = 0.0;
        for (INDArray g : gradient.gradientForVariable().values()) {
            double l2 = g.norm2Number().doubleValue();
            sumSquares += l2 * l2;
        }
        return FastMath.sqrt(sumSquares);
    }

    /** Returns whether {@code layer} is configured to use the Nesterov momentum updater. */
    private static boolean isNesterovs(Layer layer) {
        return layer.conf().getLayer().getUpdater() == org.deeplearning4j.nn.conf.Updater.NESTEROVS;
    }

    /** Subclass initialization hook; its semantics are defined by each implementation. */
    public abstract void init();

    /**
     * Returns the {@link GradientUpdater} to use for one variable.
     *
     * <p>NOTE(review): presumably this creates the updater on first use and caches it in
     * {@link #updaterForVariable}; confirm in the subclasses.
     *
     * @param var1 parameter name
     * @param var2 that parameter's gradient
     * @param var3 owning layer
     */
    public abstract GradientUpdater init(String var1, INDArray var2, Layer var3);

    /** Two updaters are equal when they are both {@code BaseUpdater}s with equal per-variable updater maps. */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof BaseUpdater)) {
            return false;
        }
        return updaterForVariable.equals(((BaseUpdater) other).updaterForVariable);
    }

    /** Consistent with {@link #equals(Object)}: derived from the per-variable updater map. */
    @Override
    public int hashCode() {
        return updaterForVariable.hashCode();
    }

    /**
     * Creates a new instance of this updater's runtime class and gives it copies of the
     * per-variable updaters.
     *
     * <p>The instance is created through the class's public no-argument constructor, which every
     * concrete subclass must therefore provide. Each updater is copied by round-tripping it through
     * {@code getAggregator(true).getUpdater()}; presumably this yields an independent instance,
     * but confirm that against {@code GradientUpdaterAggregator}.
     *
     * @throws RuntimeException wrapping the reflective failure if instantiation fails
     */
    @Override
    public Updater clone() {
        Map<String, GradientUpdater> copies = new HashMap<String, GradientUpdater>();
        for (Map.Entry<String, GradientUpdater> entry : updaterForVariable.entrySet()) {
            copies.put(entry.getKey(), entry.getValue().getAggregator(true).getUpdater());
        }
        BaseUpdater copy;
        try {
            copy = getClass().getConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
        copy.updaterForVariable = copies;
        return copy;
    }

    /**
     * Base {@link UpdaterAggregator} that keeps one {@link GradientUpdaterAggregator} per variable
     * name, in insertion order. Subclasses can call {@link #setUpdaterState} to install the
     * aggregated updaters into a {@link BaseUpdater}.
     */
    protected static abstract class UpdaterAggregatorImpl implements UpdaterAggregator {
        /** Per-variable aggregators, keyed by parameter name. */
        protected Map<String, GradientUpdaterAggregator> aggregatorMap = new LinkedHashMap<String, GradientUpdaterAggregator>();

        protected UpdaterAggregatorImpl() {
        }

        /**
         * Folds each of {@code updater}'s per-variable updaters into this aggregator. A variable
         * seen for the first time starts a new aggregator from that updater.
         *
         * @param updater must be a {@link BaseUpdater}
         */
        @Override
        public void aggregate(Updater updater) {
            BaseUpdater base = (BaseUpdater) updater;
            for (Map.Entry<String, GradientUpdater> entry : base.updaterForVariable.entrySet()) {
                String name = entry.getKey();
                GradientUpdaterAggregator existing = aggregatorMap.get(name);
                if (existing == null) {
                    aggregatorMap.put(name, entry.getValue().getAggregator(true));
                } else {
                    existing.aggregate(entry.getValue());
                }
            }
        }

        /**
         * Combines another aggregator's per-variable state into this one.
         *
         * <p>A variable present only in {@code aggregator} is adopted as is (the aggregator object
         * is shared, not copied).
         *
         * @param aggregator must be an {@code UpdaterAggregatorImpl}
         */
        @Override
        public void merge(UpdaterAggregator aggregator) {
            UpdaterAggregatorImpl other = (UpdaterAggregatorImpl) aggregator;
            if (aggregatorMap == null) {
                aggregatorMap = other.aggregatorMap;
                return;
            }
            if (other.aggregatorMap == null) {
                return;
            }
            for (Map.Entry<String, GradientUpdaterAggregator> entry : other.aggregatorMap.entrySet()) {
                GradientUpdaterAggregator mine = aggregatorMap.get(entry.getKey());
                if (mine == null) {
                    // Previously this dereferenced null; adopt the other side's state instead.
                    aggregatorMap.put(entry.getKey(), entry.getValue());
                } else {
                    mine.combine(entry.getValue());
                }
            }
        }

        /**
         * Replaces {@code updater}'s per-variable updaters with ones produced by this aggregator's
         * per-variable aggregators, preserving their order.
         *
         * @param updater target updater; its map is replaced
         * @return {@code updater}, for chaining
         */
        protected Updater setUpdaterState(BaseUpdater updater) {
            updater.updaterForVariable = new LinkedHashMap<String, GradientUpdater>();
            if (aggregatorMap != null) {
                for (Map.Entry<String, GradientUpdaterAggregator> entry : aggregatorMap.entrySet()) {
                    updater.updaterForVariable.put(entry.getKey(), entry.getValue().getUpdater());
                }
            }
            return updater;
        }
    }
}

