/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater;

import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.impl.accum.Norm2;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.learning.AdaDelta;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.learning.Adam;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.Nesterovs;
import org.nd4j.linalg.learning.NoOpUpdater;
import org.nd4j.linalg.learning.RmsProp;
import org.nd4j.linalg.learning.Sgd;
import org.nd4j.linalg.ops.transforms.Transforms;

public class LayerUpdater
implements Updater {
    protected Map<String, GradientUpdater> updaterForVariable = new LinkedHashMap<String, GradientUpdater>();
    protected INDArray viewArray;

    @Override
    public void setStateViewArray(Layer layer, INDArray viewArray, boolean initialize) {
        Map<String, INDArray> params = layer.paramTable();
        int count = 0;
        for (Map.Entry<String, INDArray> entry : params.entrySet()) {
            INDArray paramsArray = entry.getValue();
            GradientUpdater gu = this.init(entry.getKey(), layer);
            int thisSize = gu.stateSizeForInputSize(entry.getValue().length());
            if (thisSize == 0) continue;
            INDArray subset = viewArray.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)count, (int)(count + thisSize))});
            gu.setStateViewArray(subset, paramsArray.shape(), paramsArray.ordering(), initialize);
            count += thisSize;
        }
    }

    public Map<String, GradientUpdater> getUpdaterForVariable() {
        return this.updaterForVariable;
    }

    @Override
    public INDArray getStateViewArray() {
        return this.viewArray;
    }

    @Override
    public int stateSizeForLayer(Layer layer) {
        Map<String, INDArray> params = layer.paramTable();
        int count = 0;
        for (Map.Entry<String, INDArray> entry : params.entrySet()) {
            GradientUpdater gu = this.init(entry.getKey(), layer);
            count += gu.stateSizeForInputSize(entry.getValue().length());
        }
        return count;
    }

    @Override
    public void update(Layer layer, Gradient gradient, int iteration, int miniBatchSize) {
        this.preApply(layer, gradient, iteration);
        for (Map.Entry<String, INDArray> gradientPair : gradient.gradientForVariable().entrySet()) {
            String paramName = gradientPair.getKey();
            if (!layer.conf().isPretrain() && "vb".equals(paramName.split("_")[0])) continue;
            INDArray gradientOrig = gradientPair.getValue();
            LearningRatePolicy decay = layer.conf().getLearningRatePolicy();
            if (decay != LearningRatePolicy.None || layer.conf().getLayer().getUpdater() == org.deeplearning4j.nn.conf.Updater.NESTEROVS) {
                this.applyLrDecayPolicy(decay, layer, iteration, paramName);
            }
            GradientUpdater updater = this.init(paramName, layer);
            INDArray gradient2 = updater.getGradient(gradientOrig, iteration);
            this.postApply(layer, gradient2, paramName, miniBatchSize);
            gradient.setGradientFor(paramName, gradient2);
        }
    }

    public void postApply(Layer layer, INDArray gradient, String param, int miniBatchSize) {
        NeuralNetConfiguration conf = layer.conf();
        INDArray params = layer.getParam(param);
        if (conf.isUseRegularization() && conf.getL2ByParam(param) > 0.0) {
            gradient.addi(params.mul((Number)conf.getL2ByParam(param)));
        }
        if (conf.isUseRegularization() && conf.getL1ByParam(param) > 0.0) {
            gradient.addi(Transforms.sign((INDArray)params).muli((Number)conf.getL1ByParam(param)));
        }
        if (conf.isMiniBatch()) {
            gradient.divi((Number)miniBatchSize);
        }
    }

    public void applyMomentumDecayPolicy(Layer layer, int iteration, String variable) {
        NeuralNetConfiguration conf = layer.conf();
        if (conf.getLayer().getMomentumSchedule().containsKey(iteration)) {
            conf.getLayer().setMomentum(conf.getLayer().getMomentumSchedule().get(iteration));
            if (this.updaterForVariable.get(variable) != null) {
                this.updaterForVariable.get(variable).update(new Object[]{conf.getLearningRateByParam(variable), conf.getLayer().getMomentumSchedule().get(iteration)});
            }
        } else if (this.updaterForVariable.get(variable) != null) {
            this.updaterForVariable.get(variable).update(new Object[]{conf.getLearningRateByParam(variable), conf.getLayer().getMomentum()});
        }
    }

    public void applyLrDecayPolicy(LearningRatePolicy decay, Layer layer, int iteration, String variable) {
        NeuralNetConfiguration conf = layer.conf();
        double decayRate = layer.conf().getLrPolicyDecayRate();
        double lr = conf.getLearningRateByParam(variable);
        switch (decay) {
            case Exponential: {
                conf.setLearningRateByParam(variable, lr * Math.pow(decayRate, iteration));
                break;
            }
            case Inverse: {
                conf.setLearningRateByParam(variable, lr / Math.pow(1.0 + decayRate * (double)iteration, conf.getLrPolicyPower()));
                break;
            }
            case Step: {
                conf.setLearningRateByParam(variable, lr * Math.pow(decayRate, Math.floor((double)iteration / conf.getLrPolicySteps())));
                break;
            }
            case TorchStep: {
                if (iteration <= 1 || conf.getLrPolicySteps() % (double)iteration != 0.0) break;
                conf.setLearningRateByParam(variable, lr * decayRate);
                break;
            }
            case Poly: {
                conf.setLearningRateByParam(variable, lr * Math.pow(1.0 - (double)iteration / (double)conf.getNumIterations(), conf.getLrPolicyPower()));
                break;
            }
            case Sigmoid: {
                conf.setLearningRateByParam(variable, lr / (1.0 + Math.exp(-decayRate * ((double)iteration - conf.getLrPolicySteps()))));
                break;
            }
            case Schedule: {
                if (!conf.getLayer().getLearningRateSchedule().containsKey(iteration)) break;
                conf.setLearningRateByParam(variable, conf.getLayer().getLearningRateSchedule().get(iteration));
            }
        }
        if (layer.conf().getLayer().getUpdater() == org.deeplearning4j.nn.conf.Updater.NESTEROVS) {
            this.applyMomentumDecayPolicy(layer, iteration, variable);
        } else if (this.updaterForVariable.get(variable) != null) {
            this.updaterForVariable.get(variable).update(new Object[]{conf.getLearningRateByParam(variable)});
        }
    }

    public void preApply(Layer layer, Gradient gradient, int iteration) {
        GradientNormalization normalization = layer.conf().getLayer().getGradientNormalization();
        if (normalization == null || normalization == GradientNormalization.None || layer.conf().isPretrain()) {
            return;
        }
        double threshold = layer.conf().getLayer().getGradientNormalizationThreshold();
        switch (normalization) {
            case RenormalizeL2PerLayer: {
                double sumSquares = 0.0;
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = g.norm2Number().doubleValue();
                    sumSquares += l2 * l2;
                }
                double layerL2 = FastMath.sqrt((double)sumSquares);
                for (INDArray g : gradient.gradientForVariable().values()) {
                    g.divi((Number)layerL2);
                }
                break;
            }
            case RenormalizeL2PerParamType: {
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = Nd4j.getExecutioner().execAndReturn((Accumulation)new Norm2(g)).getFinalResult().doubleValue();
                    g.divi((Number)l2);
                }
                break;
            }
            case ClipElementWiseAbsoluteValue: {
                for (INDArray g : gradient.gradientForVariable().values()) {
                    BooleanIndexing.replaceWhere((INDArray)g, (Number)threshold, (Condition)Conditions.greaterThan((Number)threshold));
                    BooleanIndexing.replaceWhere((INDArray)g, (Number)(-threshold), (Condition)Conditions.lessThan((Number)(-threshold)));
                }
                break;
            }
            case ClipL2PerLayer: {
                double sumSquares2 = 0.0;
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = Nd4j.getExecutioner().execAndReturn((Accumulation)new Norm2(g)).getFinalResult().doubleValue();
                    sumSquares2 += l2 * l2;
                }
                double layerL22 = FastMath.sqrt((double)sumSquares2);
                if (!(layerL22 > threshold)) break;
                double scalingFactor = threshold / layerL22;
                for (INDArray g : gradient.gradientForVariable().values()) {
                    g.muli((Number)scalingFactor);
                }
                break;
            }
            case ClipL2PerParamType: {
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = g.norm2Number().doubleValue();
                    if (!(l2 > threshold)) continue;
                    double scalingFactor = l2 / threshold;
                    g.divi((Number)scalingFactor);
                }
                break;
            }
            default: {
                throw new RuntimeException("Unknown (or not implemented) gradient normalization strategy: " + (Object)((Object)normalization));
            }
        }
    }

    public void init() {
    }

    public GradientUpdater init(String variable, Layer layer) {
        GradientUpdater updater = this.updaterForVariable.get(variable);
        if (updater == null) {
            org.deeplearning4j.nn.conf.Updater u = layer.conf().getLayer().getUpdaterByParam(variable);
            switch (u) {
                case SGD: {
                    updater = new Sgd(layer.conf().getLearningRateByParam(variable));
                    break;
                }
                case ADAM: {
                    updater = new Adam(layer.conf().getLearningRateByParam(variable), layer.conf().getLayer().getAdamMeanDecay(), layer.conf().getLayer().getAdamVarDecay());
                    break;
                }
                case ADADELTA: {
                    updater = new AdaDelta(layer.conf().getLayer().getRho(), layer.conf().getLayer().getEpsilon());
                    break;
                }
                case NESTEROVS: {
                    updater = new Nesterovs(layer.conf().getLayer().getMomentum(), layer.conf().getLearningRateByParam(variable));
                    break;
                }
                case ADAGRAD: {
                    updater = new AdaGrad(layer.conf().getLearningRateByParam(variable), layer.conf().getLayer().getEpsilon());
                    break;
                }
                case RMSPROP: {
                    updater = new RmsProp(layer.conf().getLearningRateByParam(variable), layer.conf().getLayer().getRmsDecay());
                    break;
                }
                case NONE: {
                    updater = new NoOpUpdater();
                    break;
                }
                case CUSTOM: {
                    throw new UnsupportedOperationException("Custom updaters: not yet implemented");
                }
                default: {
                    throw new IllegalArgumentException("Unknown updater: " + (Object)((Object)u));
                }
            }
            this.updaterForVariable.put(variable, updater);
        }
        return updater;
    }

    public boolean equals(Object other) {
        if (!(other instanceof LayerUpdater)) {
            return false;
        }
        return this.updaterForVariable.equals(((LayerUpdater)other).updaterForVariable);
    }

    public int hashCode() {
        int result = 19;
        result = 31 * result + (this.updaterForVariable == null ? 0 : this.updaterForVariable.hashCode());
        return result;
    }

    @Override
    public Updater clone() {
        LayerUpdater updater;
        HashMap<String, GradientUpdater> newMap = new HashMap<String, GradientUpdater>();
        for (Map.Entry<String, GradientUpdater> entry : this.updaterForVariable.entrySet()) {
            newMap.put(entry.getKey(), entry.getValue().getAggregator(true).getUpdater());
        }
        try {
            updater = (LayerUpdater)this.getClass().getConstructor(new Class[0]).newInstance(new Object[0]);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        updater.updaterForVariable = newMap;
        return updater;
    }
}

