/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A {@link ParameterServer} that aggregates gradients in-process.
 *
 * <p>Gradients pushed for a parameter are held until the matching {@link #pull} call. That call
 * sums them onto the device of the first gradient, applies the configured {@link Optimizer} to
 * every supplied weight, and then closes the gradients.
 */
public class LocalParameterServer implements ParameterServer {
    private final Optimizer optimizer;
    /** Gradients pushed but not yet pulled, keyed by parameter id. */
    private final Map<String, NDArray[]> gradMap;

    /**
     * Creates a parameter server that updates weights with the given optimizer.
     *
     * @param optimizer the optimizer applied to each weight in {@link #pull}
     */
    public LocalParameterServer(Optimizer optimizer) {
        this.optimizer = optimizer;
        this.gradMap = new ConcurrentHashMap<>();
    }

    /**
     * {@inheritDoc}
     *
     * <p>No-op: this server keeps no copy of the initial parameter values.
     */
    @Override
    public void init(String parameterId, NDArray[] value) {
    }

    /**
     * Stores the gradients for {@code parameterId} until the next {@link #pull}.
     *
     * <p>If gradients are already pending for this parameter (pushed but never pulled), they are
     * closed and replaced. {@code priority} is ignored.
     *
     * @param parameterId the parameter the gradients belong to
     * @param grads the gradients, one per device; ownership passes to this server
     * @param priority ignored
     */
    @Override
    public void push(String parameterId, NDArray[] grads, int priority) {
        NDArray[] oldGrads = gradMap.put(parameterId, grads);
        closeAll(oldGrads);
    }

    /**
     * Sums the pending gradients for {@code parameterId} and applies the optimizer to each weight.
     *
     * <p>All gradients are accumulated in place into the first gradient, on its device. Each
     * weight on that device is updated with the sum directly. Every other weight is updated with
     * a temporary copy of the sum moved to the weight's device. The pushed gradients are always
     * closed afterwards, even if an update fails. {@code priority} is ignored.
     *
     * @param parameterId the parameter to update
     * @param weights the weights to update, possibly on different devices
     * @param priority ignored
     * @throws IllegalStateException if no non-empty gradients were pushed for {@code parameterId}
     *     since the last pull
     */
    @Override
    public void pull(String parameterId, NDArray[] weights, int priority) {
        // remove (not get): the gradients are closed below, so they must not stay in the map
        // where a later push would close them again or a later pull would read closed arrays.
        NDArray[] grads = gradMap.remove(parameterId);
        if (grads == null || grads.length == 0) {
            throw new IllegalStateException(
                    "No gradients have been pushed for parameter: " + parameterId);
        }
        try {
            NDArray gradSum = grads[0];
            Device firstDevice = gradSum.getDevice();
            // Reduce every other device's gradient into gradSum (in place).
            for (int i = 1; i < grads.length; ++i) {
                try (NDArray gradCopy = grads[i].toDevice(firstDevice, true)) {
                    gradSum.addi(gradCopy);
                }
            }
            // NOTE(review): gradSum itself is handed to the optimizer for weights on firstDevice.
            // If an Optimizer ever modifies its gradient argument in place, weights on other
            // devices that come later in the array would see the modified value. Confirm
            // Optimizer.update leaves the gradient untouched.
            for (NDArray weight : weights) {
                Device weightDevice = weight.getDevice();
                if (weightDevice.equals(firstDevice)) {
                    optimizer.update(parameterId, weight, gradSum);
                } else {
                    try (NDArray gradSumCopy = gradSum.toDevice(weightDevice, true)) {
                        optimizer.update(parameterId, weight, gradSumCopy);
                    }
                }
            }
        } finally {
            closeAll(grads);
        }
    }

    /** Releases any gradients that were pushed but never pulled. */
    @Override
    public void close() {
        // Iterating a ConcurrentHashMap key set while removing is safe (weakly consistent).
        for (String parameterId : gradMap.keySet()) {
            closeAll(gradMap.remove(parameterId));
        }
    }

    /** Closes every non-null element of {@code arrays}; a {@code null} array is ignored. */
    private static void closeAll(NDArray[] arrays) {
        if (arrays == null) {
            return;
        }
        for (NDArray array : arrays) {
            if (array != null) {
                array.close();
            }
        }
    }
}

