/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterServer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Holds the per-device arrays of model parameters and, when a {@link ParameterServer} is set,
 * routes gradient pushes and value pulls through it.
 *
 * <p>The parameter and device maps are {@link ConcurrentHashMap}s and each parameter's arrays are
 * kept in a synchronized list; first-time placement in {@link #getValue} is additionally guarded
 * so that concurrent callers initialize each parameter exactly once.
 */
public class ParameterStore {
    private final NDManager manager;
    private final Map<String, ParameterData> parameterMap;
    /** Registered devices mapped to their index in each parameter's array list. */
    private final Map<Device, Integer> deviceMap;
    private final boolean copy;
    private ParameterServer parameterServer;

    /**
     * Creates a store whose only registered device is the device of {@code manager}, at index 0.
     *
     * @param manager the manager that device copies created by this store are attached to
     * @param copy if {@code true} and no parameter server is set, {@link #getValue} copies a
     *     parameter even when it already lives on the requested device
     */
    public ParameterStore(NDManager manager, boolean copy) {
        this.manager = manager;
        this.copy = copy;
        this.parameterMap = new ConcurrentHashMap<>();
        this.deviceMap = new ConcurrentHashMap<>();
        this.deviceMap.put(manager.getDevice(), 0);
    }

    /**
     * Sets the parameter server and replaces the registered devices.
     *
     * <p>{@code devices[i]} is registered at index {@code i}. The devices are validated before any
     * state is modified, so a rejected call leaves the store unchanged.
     *
     * @param parameterServer the parameter server used by {@link #updateAllParameters}
     * @param devices the devices to keep parameter arrays on, in index order
     * @throws IllegalArgumentException if {@code devices} contains the same device twice
     */
    public void setParameterServer(ParameterServer parameterServer, Device[] devices) {
        // Build into a ConcurrentHashMap so null devices are rejected here (as deviceMap would),
        // before the live state is touched.
        Map<Device, Integer> indices = new ConcurrentHashMap<>();
        for (int i = 0; i < devices.length; ++i) {
            if (indices.put(devices[i], i) != null) {
                throw new IllegalArgumentException("Duplicated devices are not allowed.");
            }
        }
        // NOTE(review): parameterMap is not cleared, so parameters already placed keep the previous
        // device layout — confirm callers only invoke this before the first getValue.
        this.parameterServer = parameterServer;
        deviceMap.clear();
        deviceMap.putAll(indices);
    }

    /**
     * Pushes the gradients of every parameter that requires a gradient to the parameter server,
     * then pulls the updated values back into each parameter's device arrays.
     *
     * <p>All pushes are issued before any pull. Within each pass, parameters receive priorities
     * 0, -1, -2, ... in map iteration order. Requires a parameter server set via {@link
     * #setParameterServer}.
     */
    public void updateAllParameters() {
        ParameterServer server = parameterServer;
        int priority = 0;
        for (Map.Entry<String, ParameterData> entry : parameterMap.entrySet()) {
            ParameterData data = entry.getValue();
            if (data.requireGradient()) {
                NDArray[] grads =
                        data.getNDArrays().stream()
                                .map(NDArray::getGradient)
                                .toArray(NDArray[]::new);
                server.push(entry.getKey(), grads, -priority);
                ++priority;
            }
        }
        priority = 0;
        for (Map.Entry<String, ParameterData> entry : parameterMap.entrySet()) {
            ParameterData data = entry.getValue();
            if (data.requireGradient()) {
                server.pull(entry.getKey(), data.toArray(), -priority);
                ++priority;
            }
        }
    }

    /**
     * Returns the array of {@code parameter} for {@code device}, creating the per-device arrays
     * the first time the parameter is requested.
     *
     * <p>With a parameter server, the parameter is registered with the server and one array is
     * created per registered device, stored at that device's index. The original array is reused
     * only for the requesting device, and only when it already lives there. Without a parameter
     * server, a single array is kept. It is copied to {@code device} when {@code copy} is set or
     * when the parameter lives on another device.
     *
     * @param parameter the parameter to look up
     * @param device a device registered with this store
     * @return the array for {@code device}
     * @throws IllegalArgumentException if {@code device} is not registered with this store
     */
    public NDArray getValue(Parameter parameter, Device device) {
        String parameterId = parameter.getId();
        Integer registered = deviceMap.get(device);
        if (registered == null) {
            throw new IllegalArgumentException(
                    "Device " + device + " is not registered in the ParameterStore.");
        }
        int index = registered;
        ParameterData data =
                parameterMap.computeIfAbsent(parameterId, k -> new ParameterData(parameter));
        // Several threads may request the same parameter for the first time at once; guard the
        // check-then-populate step so it runs exactly once. The list is published with a single
        // addAll, so a failure while copying leaves it empty rather than half-filled.
        synchronized (data) {
            if (data.isEmpty()) {
                data.addAll(createArrays(parameter, parameterId, index, device));
            }
        }
        return data.get(index);
    }

    /**
     * Copies every parameter's first array back into the parameter's own array when that array
     * lives on a device that is not registered with this store.
     */
    public void sync() {
        for (ParameterData data : parameterMap.values()) {
            data.sync();
        }
    }

    /** Builds the per-device arrays for a parameter, ordered by device index. */
    private List<NDArray> createArrays(
            Parameter parameter, String parameterId, int index, Device device) {
        NDArray array = parameter.getArray();
        ParameterServer server = parameterServer;
        if (server == null) {
            if (copy || !array.getDevice().equals(device)) {
                array = copyToDevice(array, device, parameter);
            }
            return Collections.singletonList(array);
        }
        server.init(parameterId, new NDArray[] {array});
        NDArray[] arrays = new NDArray[deviceMap.size()];
        for (Map.Entry<Device, Integer> entry : deviceMap.entrySet()) {
            Device dev = entry.getKey();
            int i = entry.getValue();
            // NOTE(review): the original array is reused only if the *first* request came from its
            // own device; otherwise that device gets a copy, and sync() will not write updates back
            // to the original — confirm that is intended.
            if (i == index && array.getDevice().equals(dev)) {
                arrays[i] = array;
            } else {
                arrays[i] = copyToDevice(array, dev, parameter);
            }
        }
        // The map iterates in hash order, so collect from the index-ordered array rather than
        // appending inside the loop; data.get(index) relies on list position == device index.
        List<NDArray> ordered = new ArrayList<>(arrays.length);
        Collections.addAll(ordered, arrays);
        return ordered;
    }

    /** Copies {@code array} to {@code target}, attaches it to this store's manager and, if the
     * parameter requires one, a gradient. */
    private NDArray copyToDevice(NDArray array, Device target, Parameter parameter) {
        NDArray copied = array.toDevice(target, true);
        copied.attach(manager);
        if (parameter.requireGradient()) {
            copied.attachGradient();
        }
        return copied;
    }

    /** The arrays of a single parameter, one per registered device, in device-index order. */
    private final class ParameterData {
        private final Parameter parameter;
        private final List<NDArray> list;

        private ParameterData(Parameter parameter) {
            this.parameter = parameter;
            this.list = Collections.synchronizedList(new ArrayList<>());
        }

        private List<NDArray> getNDArrays() {
            return list;
        }

        private boolean isEmpty() {
            return list.isEmpty();
        }

        /** Appends all arrays under the list's lock in one step. */
        private void addAll(List<NDArray> arrays) {
            list.addAll(arrays);
        }

        private NDArray get(int index) {
            return list.get(index);
        }

        private NDArray[] toArray() {
            return list.toArray(new NDArray[0]);
        }

        private boolean requireGradient() {
            return parameter.requireGradient();
        }

        /**
         * Copies the first stored array into the parameter's own array when the latter lives on an
         * unregistered device. Does nothing if the parameter was never successfully placed.
         */
        private void sync() {
            NDArray array = parameter.getArray();
            Device device = array.getDevice();
            if (!deviceMap.containsKey(device) && !list.isEmpty()) {
                list.get(0).copyTo(array);
            }
        }
    }
}

