/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.learning;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.config.AdaDelta;

/**
 * {@link GradientUpdater} implementation for the {@link AdaDelta} configuration.
 *
 * <p>Holds two state arrays, exposed through {@link #getState()} under the keys
 * {@value #MSG_STATE} and {@value #MSDX_STATE}. NOTE(review): going by the names and the
 * AdaDelta algorithm these are presumably the running averages of squared gradients and of
 * squared parameter updates; the arithmetic itself lives in the native
 * {@code org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater} op and is not visible here.
 *
 * <p>This class performs no synchronization of its mutable state.
 */
public class AdaDeltaUpdater
implements GradientUpdater<AdaDelta> {
    /** State-map key for the {@code msg} array. */
    public static final String MSG_STATE = "msg";
    /** State-map key for the {@code msdx} array. */
    public static final String MSDX_STATE = "msdx";

    private final AdaDelta config;
    private INDArray msg;
    private INDArray msdx;

    /**
     * Creates an updater for the given configuration. The state arrays start out {@code null}
     * and must be supplied via {@link #setState} or {@link #setStateViewArray} before
     * {@link #applyUpdater} is called.
     *
     * @param config AdaDelta hyper-parameters (rho, epsilon) read on every update
     */
    public AdaDeltaUpdater(AdaDelta config) {
        this.config = config;
    }

    /**
     * Replaces the updater state with the arrays held in {@code stateMap}. The arrays are stored
     * by reference, not copied.
     *
     * @param stateMap map containing exactly the keys {@value #MSG_STATE} and {@value #MSDX_STATE}
     * @param initialize ignored by this implementation
     * @throws IllegalStateException if {@code stateMap} lacks either key or contains any other key
     */
    @Override
    public void setState(Map<String, INDArray> stateMap, boolean initialize) {
        if (!stateMap.containsKey(MSG_STATE) || !stateMap.containsKey(MSDX_STATE) || stateMap.size() != 2) {
            throw new IllegalStateException("State map should contain only keys [msg,msdx] but has keys " + stateMap.keySet());
        }
        this.msg = stateMap.get(MSG_STATE);
        this.msdx = stateMap.get(MSDX_STATE);
    }

    /**
     * Returns a new, mutable map holding the current state arrays by reference. Values are
     * {@code null} if the state has not been set yet.
     *
     * @return map with keys {@value #MSG_STATE} and {@value #MSDX_STATE}
     */
    @Override
    public Map<String, INDArray> getState() {
        Map<String, INDArray> state = new HashMap<>();
        state.put(MSG_STATE, this.msg);
        state.put(MSDX_STATE, this.msdx);
        return state;
    }

    /**
     * Binds both state arrays to slices of a single flattened state buffer.
     *
     * <p>The first half of {@code viewArray} becomes {@code msg} and the second half
     * {@code msdx}. Each half is then reshaped to {@code gradientShape} with
     * {@code Shape.newShapeNoCopy}, so the state is intended to stay a view into
     * {@code viewArray} rather than a copy.
     *
     * @param viewArray row vector holding both state arrays back to back; NOTE(review): its
     *     length is presumably twice the gradient length — confirm against callers
     * @param gradientShape shape each state array is reshaped to
     * @param gradientOrder {@code 'f'} requests Fortran ordering for the reshape; any other value
     *     is passed as non-Fortran
     * @param initialize if {@code true}, {@code viewArray} is zero-filled before slicing
     * @throws IllegalArgumentException if {@code viewArray} is not a row vector
     * @throws IllegalStateException if either half cannot be reshaped without copying
     */
    @Override
    public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
        if (!viewArray.isRowVector()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        if (initialize) {
            viewArray.assign(0);
        }
        long length = viewArray.length();
        long half = length / 2L;
        boolean fortranOrder = gradientOrder == 'f';
        this.msg = viewArray.get(NDArrayIndex.point(0L), NDArrayIndex.interval(0L, half));
        this.msdx = viewArray.get(NDArrayIndex.point(0L), NDArrayIndex.interval(half, length));
        // newShapeNoCopy signals "cannot reshape without a copy" by returning null.
        this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, fortranOrder);
        this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, fortranOrder);
        if (this.msg == null || this.msdx == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view arrays");
        }
    }

    /**
     * Applies one AdaDelta step. It executes the native AdaDelta updater op with
     * {@code gradient}, the {@code msg}/{@code msdx} state, and the configured rho and epsilon.
     *
     * <p>NOTE(review): the op presumably overwrites {@code gradient} with the parameter update
     * and advances the state arrays in place — confirm against the op implementation.
     *
     * @param gradient gradient array handed to the op
     * @param iteration ignored by this implementation
     * @param epoch ignored by this implementation
     * @throws IllegalStateException if the state arrays have not been set
     */
    @Override
    public void applyUpdater(INDArray gradient, int iteration, int epoch) {
        if (this.msg == null || this.msdx == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }
        double rho = this.config.getRho();
        double epsilon = this.config.getEpsilon();
        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, this.msg, this.msdx, rho, epsilon));
    }

    /** @return the configuration passed to the constructor */
    @Override
    public AdaDelta getConfig() {
        return this.config;
    }

    /** @return the current {@value #MSG_STATE} state array, or {@code null} if unset */
    public INDArray getMsg() {
        return this.msg;
    }

    /** @return the current {@value #MSDX_STATE} state array, or {@code null} if unset */
    public INDArray getMsdx() {
        return this.msdx;
    }

    /** @param msg replacement {@value #MSG_STATE} state array, stored by reference */
    public void setMsg(INDArray msg) {
        this.msg = msg;
    }

    /** @param msdx replacement {@value #MSDX_STATE} state array, stored by reference */
    public void setMsdx(INDArray msdx) {
        this.msdx = msdx;
    }

    /**
     * Two updaters are equal when their configs and both state arrays are equal (null-safe).
     * Values are read through the getters so that subclasses which override them are honoured.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AdaDeltaUpdater)) {
            return false;
        }
        AdaDeltaUpdater other = (AdaDeltaUpdater) o;
        // canEqual lets a subclass refuse equality with this class, keeping equals symmetric.
        return other.canEqual(this)
                && Objects.equals(this.getConfig(), other.getConfig())
                && Objects.equals(this.getMsg(), other.getMsg())
                && Objects.equals(this.getMsdx(), other.getMsdx());
    }

    /**
     * Returns whether {@code other} may be considered equal to this instance. Subclasses that
     * add state should override this.
     *
     * @param other candidate object
     * @return {@code true} if {@code other} is an {@code AdaDeltaUpdater}
     */
    protected boolean canEqual(Object other) {
        return other instanceof AdaDeltaUpdater;
    }

    /** Hash over config, msg and msdx; consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + hashOrDefault(this.getConfig());
        result = result * prime + hashOrDefault(this.getMsg());
        result = result * prime + hashOrDefault(this.getMsdx());
        return result;
    }

    /** Returns {@code value.hashCode()}, or 43 for {@code null}. */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    /**
     * Returns a description including the config and both state arrays. Each array is rendered
     * with its own {@code toString}.
     */
    @Override
    public String toString() {
        return "AdaDeltaUpdater(config=" + this.getConfig() + ", msg=" + this.getMsg() + ", msdx=" + this.getMsdx() + ")";
    }
}

