/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.learning;

import java.io.Serializable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * AdaDelta gradient updater.
 *
 * <p>Keeps two element-wise running averages that are lazily allocated with the shape of the
 * first gradient seen:
 * <ul>
 *   <li>{@code msg}: the decayed mean of squared gradients, E[g^2];</li>
 *   <li>{@code msdx}: the decayed mean of squared updates, E[dx^2].</li>
 * </ul>
 * Each call to {@link #getGradient(INDArray, int)} advances both averages by one step using the
 * decay rate {@code rho}.
 *
 * <p>This class holds mutable state and performs in-place updates on it. No synchronization is
 * present in this class.
 */
public class AdaDelta implements Serializable, GradientUpdater {
    /** Running average of squared gradients, E[g^2]; {@code null} until the first update. */
    private INDArray msg;
    /** Running average of squared updates, E[dx^2]; {@code null} until the first update. */
    private INDArray msdx;
    /** Decay rate applied to both running averages. */
    private double rho = 0.95;

    /**
     * Creates an updater with the given decay rate.
     *
     * @param rho decay rate for the running averages (the previous average is weighted by
     *     {@code rho}, the new sample by {@code 1 - rho})
     */
    public AdaDelta(double rho) {
        this.rho = rho;
    }

    /** Creates an updater with the default decay rate of {@code 0.95}. */
    public AdaDelta() {
        this(0.95);
    }

    /** @return the running average of squared gradients, or {@code null} before the first update */
    public INDArray getMsg() {
        return this.msg;
    }

    /** @param msg replacement running average of squared gradients (stored by reference) */
    public void setMsg(INDArray msg) {
        this.msg = msg;
    }

    /** @return the running average of squared updates, or {@code null} before the first update */
    public INDArray getMsdx() {
        return this.msdx;
    }

    /** @param msdx replacement running average of squared updates (stored by reference) */
    public void setMsdx(INDArray msdx) {
        this.msdx = msdx;
    }

    /** @return the decay rate of the running averages */
    public double getRho() {
        return this.rho;
    }

    /** @param rho new decay rate of the running averages */
    public void setRho(double rho) {
        this.rho = rho;
    }

    /**
     * Computes the AdaDelta-scaled update for {@code gradient} and advances the running averages.
     *
     * <p>With {@code eps = Nd4j.EPS_THRESHOLD}:
     * <pre>
     *   E[g^2]  = rho * E[g^2]  + (1 - rho) * g^2
     *   update  = sqrt(E[dx^2] + eps) / sqrt(E[g^2] + eps) * g
     *   E[dx^2] = rho * E[dx^2] + (1 - rho) * update^2
     * </pre>
     * The returned array has the same sign as {@code gradient}. Presumably the caller subtracts
     * it from the parameters; confirm against the {@code GradientUpdater} contract.
     *
     * @param gradient  the raw gradient; it is read but not modified
     * @param iteration unused by this updater
     * @return a newly allocated array holding the scaled update
     */
    @Override
    public INDArray getGradient(INDArray gradient, int iteration) {
        if (this.msg == null) {
            this.msg = Nd4j.zeros(gradient.shape());
        }
        if (this.msdx == null) {
            this.msdx = Nd4j.zeros(gradient.shape());
        }

        // E[g^2] = rho * E[g^2] + (1 - rho) * g^2.
        // The previous code added the scalar (1 - rho) and then multiplied the whole
        // accumulator by g^2, which discarded the history on every step.
        this.msg.muli(this.rho);
        this.msg.addi(gradient.mul(gradient).muli(1.0 - this.rho));

        // update = RMS[dx]_{t-1} / RMS[g]_t * g. E[dx^2] has not been advanced yet, so this
        // uses the previous step's value, as AdaDelta requires.
        INDArray rmsDx = Transforms.sqrt(this.msdx.add(Nd4j.EPS_THRESHOLD));
        INDArray rmsG = Transforms.sqrt(this.msg.add(Nd4j.EPS_THRESHOLD));
        INDArray ret = rmsDx.divi(rmsG).muli(gradient);

        // E[dx^2] = rho * E[dx^2] + (1 - rho) * update^2.
        this.msdx.muli(this.rho);
        this.msdx.addi(ret.mul(ret).muli(1.0 - this.rho));
        return ret;
    }
}

