/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.mxnet.engine.MxNDArray;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.mxnet.jna.MxnetLibrary;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.util.NativeResource;
import com.sun.jna.Pointer;
import java.util.Arrays;

/**
 * {@link ParameterServer} implementation backed by a native MXNet KVStore.
 *
 * <p>The store is created with type {@code "device"}. An updater callback is registered on it
 * that forwards each parameter update to the {@link Optimizer} supplied at construction.
 */
public class MxParameterServer extends NativeResource<Pointer> implements ParameterServer {

    // Held in a field so the callback object stays strongly reachable for as long as this
    // store exists: native code keeps the callback registered after setUpdater, and a JNA
    // callback that has been garbage-collected must not be invoked from native code.
    private final OptimizerCallback callback;

    // Number of update() calls so far; its negation is used as the push/pull priority.
    private int priority;

    /**
     * Creates a parameter server whose updates are applied by the given optimizer.
     *
     * @param optimizer the optimizer invoked by the native updater callback
     */
    public MxParameterServer(Optimizer optimizer) {
        super(createKvStore());
        callback = new OptimizerCallback(optimizer);
        JnaUtils.parameterStoreSetUpdater(getHandle(), null, callback, null);
        priority = 0;
    }

    /**
     * Initializes the native store with {@code values}, registering every array under the same
     * key {@code parameterId}.
     *
     * @param parameterId the key for all given values
     * @param values the initial values to register
     */
    @Override
    public void init(String parameterId, NDArray[] values) {
        String[] keys = new String[values.length];
        Arrays.fill(keys, parameterId);
        JnaUtils.parameterStoreInit(getHandle(), values.length, keys, new NDList(values));
    }

    /**
     * Pushes {@code grads} to the native store and pulls the result back into {@code params}.
     * All arrays are keyed by {@code parameterId}.
     *
     * <p>The push/pull is issued with priority {@code -n}, where {@code n} is the number of
     * previous calls to this method, so each call passes a lower priority than the one before.
     * MXNet documents higher-priority operations as likely to run before others, so updates
     * issued earlier are presumably favored.
     *
     * @param parameterId the key for all gradients and parameters
     * @param grads the gradients to push
     * @param params the arrays that receive the pulled values
     */
    @Override
    public void update(String parameterId, NDArray[] grads, NDArray[] params) {
        String[] gradKeys = new String[grads.length];
        String[] paramKeys = new String[params.length];
        Arrays.fill(gradKeys, parameterId);
        Arrays.fill(paramKeys, parameterId);
        JnaUtils.parameterStorePushPull(
                getHandle(),
                grads.length,
                gradKeys,
                params.length,
                paramKeys,
                new NDList(grads),
                new NDList(params),
                -priority);
        priority++;
    }

    /** Creates the native KVStore handle, using store type {@code "device"}. */
    private static Pointer createKvStore() {
        return JnaUtils.parameterStoreCreate("device");
    }

    /**
     * Releases the native KVStore.
     *
     * <p>The handle is atomically swapped for {@code null}, so repeated calls are no-ops.
     */
    @Override
    public void close() {
        Pointer pointer = handle.getAndSet(null);
        if (pointer != null) {
            JnaUtils.parameterStoreClose(pointer);
        }
    }

    /** Native updater callback that delegates each update to a DJL {@link Optimizer}. */
    private static final class OptimizerCallback implements MxnetLibrary.MXKVStoreStrUpdater {

        private final Optimizer optimizer;

        OptimizerCallback(Optimizer optimizer) {
            this.optimizer = optimizer;
        }

        /**
         * Called from native code for one key.
         *
         * <p>{@code recv} is wrapped as the gradient and {@code local} as the weight. Both are
         * created in a temporary sub-manager that is closed when the update returns.
         */
        @Override
        public void apply(String parameterId, Pointer recv, Pointer local, Pointer handle) {
            try (NDManager manager = MxNDManager.getSystemManager().newSubManager()) {
                MxNDManager m = (MxNDManager) manager;
                MxNDArray grad = m.create(recv);
                MxNDArray weight = m.create(local);
                optimizer.update(parameterId, weight, grad);
            }
        }
    }
}

