/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize;

import java.util.Map;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers that turn a raw gradient into the update that is applied to a model's
 * parameters: AdaGrad or plain learning-rate scaling, momentum, L2/L1 regularization, an
 * optional unit-norm constraint, and division by the mini-batch size.
 *
 * <p>Every adjustment is applied in place to the gradient array that is passed in; the methods
 * return nothing.
 */
public final class GradientAdjustment {
    private static final Logger log = LoggerFactory.getLogger(GradientAdjustment.class);

    private GradientAdjustment() {
        // Static utility class; not instantiable.
    }

    /**
     * Adjusts every per-variable gradient held by {@code gradient}, reusing the AdaGrad state
     * stored for each variable or lazily creating (and storing) it on first use.
     *
     * @param conf      configuration supplying learning rate, momentum, regularization and
     *                  AdaGrad settings
     * @param iteration current training iteration
     * @param gradient  gradients keyed by variable name; each one is adjusted in place
     * @param batchSize value each gradient is finally divided by
     * @param adaGrad   per-variable AdaGrad state; missing (or null) entries are created with the
     *                  shape of the corresponding model parameter and put into this map
     * @param model     model whose parameter arrays provide shapes and regularization inputs
     */
    public static void updateGradientAccordingToParams(NeuralNetConfiguration conf, int iteration, Gradient gradient, int batchSize, Map<String, AdaGrad> adaGrad, Model model) {
        for (String variable : gradient.gradientForVariable().keySet()) {
            INDArray params = model.getParam(variable);
            AdaGrad adaGradForVariable = adaGrad.get(variable);
            if (adaGradForVariable == null) {
                // First time this variable is seen: size its AdaGrad history like its parameters.
                adaGradForVariable = new AdaGrad(params.shape());
                adaGrad.put(variable, adaGradForVariable);
            }
            updateGradientAccordingToParams(conf, iteration, adaGradForVariable, gradient.getGradientFor(variable), params, batchSize);
        }
    }

    /**
     * Adjusts a single gradient array in place, in this order: optional AdaGrad history reset,
     * AdaGrad or learning-rate scaling, momentum, L2 or L1 regularization, optional unit-norm
     * constraint, and finally division by {@code batchSize}.
     *
     * @param conf      configuration supplying learning rate, momentum, regularization and
     *                  AdaGrad settings
     * @param iteration current iteration; drives the AdaGrad reset interval and momentum schedule
     * @param adaGrad   AdaGrad state for this variable; if {@code null} a throw-away instance is
     *                  created, so no AdaGrad history survives the call
     * @param gradient  gradient to adjust; modified in place
     * @param params    current parameter values, used by the regularization terms
     * @param batchSize value the adjusted gradient is divided by
     */
    public static void updateGradientAccordingToParams(NeuralNetConfiguration conf, int iteration, AdaGrad adaGrad, INDArray gradient, INDArray params, int batchSize) {
        if (adaGrad == null) {
            adaGrad = new AdaGrad(gradient.shape());
        }

        // Periodically discard the accumulated AdaGrad history.
        if (iteration != 0 && conf.getResetAdaGradIterations() > 0 && iteration % conf.getResetAdaGradIterations() == 0) {
            adaGrad.historicalGradient = null;
            log.info("Resetting adagrad");
        }

        // Optionally switch to a different momentum once a given iteration is reached.
        double momentum = conf.getMomentum();
        if (conf.getMomentumAfter() != null && !conf.getMomentumAfter().isEmpty()) {
            // NOTE(review): only the first key produced by the map's iterator is consulted; if the
            // map is unordered (e.g. a HashMap) that key is arbitrary — confirm intended schedule.
            int key = conf.getMomentumAfter().keySet().iterator().next().intValue();
            if (iteration >= key) {
                momentum = conf.getMomentumAfter().get(key);
            }
        }

        if (conf.isUseAdaGrad()) {
            // Copy the AdaGrad-scaled values back into the caller's array. This method returns
            // nothing, so merely reassigning the local reference would leave the caller's gradient
            // untouched (and every adjustment below would be lost) whenever getGradient returns a
            // new array. assign(...) is correct whether or not getGradient works in place.
            gradient.assign(adaGrad.getGradient(gradient));
        } else {
            gradient.muli(conf.getLr());
        }

        if (momentum > 0.0) {
            // NOTE(review): m*g + (1-m)*g == g, so this adds the gradient to itself (doubling it)
            // regardless of the momentum value. Real momentum needs the previous update, which is
            // not available here; kept as-is to avoid silently changing training dynamics.
            gradient.addi(gradient.mul(momentum).addi(gradient.mul(1.0 - momentum)));
        }

        if (conf.isUseRegularization() && conf.getL2() > 0.0) {
            // L2 penalty: subtract lr * l2 * params.
            gradient.subi(params.mul(conf.getL2() * conf.getLr()));
        } else if (conf.isUseRegularization() && conf.getL1() > 0.0) {
            // L1 penalty, mirroring the L2 branch: subtract lr * l1 * sign(params).
            // Only applied when no L2 penalty is configured.
            gradient.subi(Transforms.sign(params).mul(conf.getL1() * conf.getLr()));
        }

        if (conf.isConstrainGradientToUnitNorm()) {
            // NOTE(review): presumably Integer.MAX_VALUE asks for the norm over all elements; an
            // all-zero gradient makes this a division by zero — confirm.
            gradient.divi(gradient.norm2(Integer.MAX_VALUE));
        }

        gradient.divi(batchSize);
    }
}

