/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.learning;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * RMSProp-style gradient updater.
 *
 * <p>Keeps an element-wise running average of squared gradients and scales each incoming
 * gradient by {@code lr / sqrt(average + Nd4j.EPS_THRESHOLD)}.
 *
 * <p>The running average is allocated lazily (as zeros) on the first call to
 * {@link #getGradient(INDArray, int)}. It uses the shape of that first gradient.
 */
public class RmsPropUpdater implements GradientUpdater {
    /** Running average of squared gradients; {@code null} until the first update. */
    private INDArray lastGradient;
    /** Weight given to the previous running average on each update. */
    private double rmsDecay = 0.5;
    /** Learning rate multiplied into the returned step. */
    private double lr = 0.1;

    /** Creates an updater with the default {@code lr} (0.1) and {@code rmsDecay} (0.5). */
    public RmsPropUpdater() {
    }

    /**
     * Creates an updater with explicit hyper-parameters.
     *
     * @param lr learning rate applied to the scaled gradient
     * @param rmsDecay weight of the previous squared-gradient average in the running update
     */
    public RmsPropUpdater(double lr, double rmsDecay) {
        this.lr = lr;
        this.rmsDecay = rmsDecay;
    }

    /** @param rmsDecay new weight of the previous squared-gradient average */
    public void setRmsDecay(double rmsDecay) {
        this.rmsDecay = rmsDecay;
    }

    /** @return the current weight of the previous squared-gradient average */
    public double getRmsDecay() {
        return rmsDecay;
    }

    /** @param lr new learning rate */
    public void setLR(double lr) {
        this.lr = lr;
    }

    /** @return the current learning rate */
    public double getLR() {
        return lr;
    }

    /**
     * Updates the squared-gradient running average and returns the scaled step.
     *
     * <p>First, {@code avg = rmsDecay * avg + (1 - rmsDecay) * gradient^2} is computed in place.
     * Then {@code lr * gradient / sqrt(avg + Nd4j.EPS_THRESHOLD)} is returned.
     *
     * @param gradient the raw gradient; only copying ops ({@code mul}) are applied to it here
     * @param iteration ignored by this implementation
     * @return a newly computed array holding the scaled update
     */
    @Override
    public INDArray getGradient(INDArray gradient, int iteration) {
        if (lastGradient == null) {
            lastGradient = Nd4j.zeros(gradient.shape());
        }

        // Decay the old average first, then fold in the weighted squared gradient.
        lastGradient.muli(rmsDecay);
        INDArray weightedSquare = gradient.mul(gradient);
        weightedSquare.muli(1.0 - rmsDecay);
        lastGradient.addi(weightedSquare);

        // Epsilon sits inside the square root to keep the denominator away from zero.
        INDArray step = gradient.mul(lr);
        INDArray rootMeanSquare = Transforms.sqrt(lastGradient.add(Nd4j.EPS_THRESHOLD));
        return step.divi(rootMeanSquare);
    }
}

