/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.learning;

import java.io.Serializable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.Shape;

/**
 * Per-element adaptive step-size scaling in the style of AdaGrad.
 *
 * <p>Each gradient element {@code g} is rescaled to
 * {@code g * masterStepSize / (sqrt(h) + Nd4j.EPS_THRESHOLD)}, where {@code h} is the
 * element's accumulated history. The history is created lazily as an all-ones array and,
 * after each call, has the square of the <em>raw</em> (unscaled) gradient added to it. The
 * scale factor for a call is computed from the history as it was before that call's update.
 *
 * <p>Instances are stateful (the history is mutated on every call) and perform no
 * synchronization.
 */
public class AdaGrad implements Serializable {
    protected static final long serialVersionUID = -4754127927704099888L;

    /** Accumulated squared gradients; {@code null} until first use, then initialised to ones. */
    public INDArray historicalGradient;
    /** Shape used to create the history lazily in {@link #createSubset(int)}. */
    public int[] shape;
    /** Global step size that the per-element factors are derived from. */
    protected double masterStepSize = 0.1;
    /** Number of gradient updates performed by this instance (not read within this class). */
    protected int numIterations = 0;
    /** Decay flag; stored and propagated by {@link #createSubset(int)} but not read within this class. */
    protected boolean decayLr;

    /**
     * Creates an updater for a {@code rows x cols} parameter with the given step size.
     *
     * @param rows  number of rows of the parameter
     * @param cols  number of columns of the parameter
     * @param gamma master step size
     */
    public AdaGrad(int rows, int cols, double gamma) {
        this.shape = new int[]{rows, cols};
        this.masterStepSize = gamma;
        this.decayLr = false;
    }

    /**
     * Creates an updater for a parameter of the given shape with master step size 0.1.
     *
     * @param shape parameter shape (stored by reference)
     */
    public AdaGrad(int[] shape) {
        this.shape = shape;
        this.masterStepSize = 0.1;
        this.decayLr = false;
    }

    /**
     * Creates an updater for a {@code rows x cols} parameter with master step size 0.1.
     *
     * @param rows number of rows of the parameter
     * @param cols number of columns of the parameter
     */
    public AdaGrad(int rows, int cols) {
        this(rows, cols, 0.1);
    }

    /**
     * Rescales a single scalar gradient component and records it in the history.
     *
     * @param gradient raw gradient value
     * @param column   linear index of the component in the history
     * @param shape    shape used to create the history if it does not exist yet
     * @return the gradient multiplied by its adaptive step size
     */
    public double getGradient(double gradient, int column, int[] shape) {
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.ones(shape);
        }
        double history = this.historicalGradient.getDouble(column);
        // sqrt(1.0) == 1.0 exactly, so a freshly created all-ones history needs no special case.
        double learningRate = this.masterStepSize / (Math.sqrt(history) + Nd4j.EPS_THRESHOLD);
        this.historicalGradient.putScalar(column, history + gradient * gradient);
        ++this.numIterations;
        return gradient * learningRate;
    }

    /**
     * Creates a new updater holding a copy of one slice (matrix shape) or one element
     * (otherwise) of this updater's history, with the same step size and decay flag.
     *
     * @param index slice index (matrix shape) or linear element index (otherwise)
     * @return a new, independent updater for the selected subset
     */
    public AdaGrad createSubset(int index) {
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.ones(this.shape);
        }
        AdaGrad subset;
        if (Shape.isMatrix(this.shape)) {
            subset = new AdaGrad(1, this.historicalGradient.columns());
            subset.historicalGradient = this.historicalGradient.slice(index).dup();
        } else {
            subset = new AdaGrad(1, 1);
            subset.historicalGradient = Nd4j.scalar(this.historicalGradient.getDouble(index));
        }
        subset.setMasterStepSize(this.masterStepSize);
        subset.setDecayLr(this.decayLr);
        return subset;
    }

    /**
     * Rescales {@code gradient} in place using one slice of the history (or the whole history
     * when it is a vector) and adds the raw squared gradient to that slice.
     *
     * @param gradient raw gradient; modified in place and returned
     * @param slice    slice index into the history
     * @param shape    shape used to create the history if it does not exist yet
     * @return {@code gradient}, now scaled by the adaptive step sizes
     * @throws IllegalArgumentException if an existing non-vector history's slice length differs
     *                                  from {@code gradient.length()}
     */
    public INDArray getGradient(INDArray gradient, int slice, int[] shape) {
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.ones(shape);
        } else if (!this.historicalGradient.isVector()
                && this.historicalGradient.slice(slice).length() != gradient.length()) {
            throw new IllegalArgumentException("Illegal gradient");
        }
        // Always use the requested slice for non-vector histories (previously a freshly
        // created history was used whole, giving the wrong shape).
        INDArray history = this.historicalGradient.isVector()
                ? this.historicalGradient
                : this.historicalGradient.slice(slice);
        // Capture the raw squared gradient before gradient is scaled in place below.
        INDArray squaredGradient = gradient.mul(gradient);
        INDArray learningRates = Transforms.sqrt(history).add(Nd4j.EPS_THRESHOLD).rdivi(this.masterStepSize);
        gradient.muli(learningRates);
        this.historicalGradient.slice(slice).addi(squaredGradient);
        ++this.numIterations;
        return gradient;
    }

    /**
     * Rescales {@code gradient} in place using the full history and adds the raw squared
     * gradient to the history.
     *
     * @param gradient raw gradient; modified in place and returned
     * @return {@code gradient}, now scaled by the adaptive step sizes
     * @throws IllegalArgumentException if an existing history's length differs from
     *                                  {@code gradient.length()}
     */
    public INDArray getGradient(INDArray gradient) {
        if (this.historicalGradient == null) {
            // Match the gradient's full shape so non-rank-2 gradients are supported.
            this.historicalGradient = Nd4j.ones(gradient.shape());
        } else if (this.historicalGradient.length() != gradient.length()) {
            throw new IllegalArgumentException("Illegal gradient");
        }
        // Capture the raw squared gradient before gradient is scaled in place below.
        INDArray squaredGradient = gradient.mul(gradient);
        INDArray learningRates = Transforms.sqrt(this.historicalGradient).add(Nd4j.EPS_THRESHOLD).rdivi(this.masterStepSize);
        gradient.muli(learningRates);
        this.historicalGradient.addi(squaredGradient);
        ++this.numIterations;
        return gradient;
    }

    /** @return the accumulated history, or {@code null} if not yet created */
    public INDArray getHistoricalGradient() {
        return this.historicalGradient;
    }

    /** @param historicalGradient the history to use (stored by reference) */
    public void setHistoricalGradient(INDArray historicalGradient) {
        this.historicalGradient = historicalGradient;
    }

    /** @return the master step size */
    public double getMasterStepSize() {
        return this.masterStepSize;
    }

    /** @param masterStepSize the new master step size */
    public void setMasterStepSize(double masterStepSize) {
        this.masterStepSize = masterStepSize;
    }

    /** @return the decay flag */
    public boolean isDecayLr() {
        return this.decayLr;
    }

    /** @param decayLr the new decay flag */
    public void setDecayLr(boolean decayLr) {
        this.decayLr = decayLr;
    }
}

