/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import java.util.Arrays;

/**
 * A {@link ParameterServer} that keeps every parameter copy in the local process.
 *
 * <p>On each {@link #update} call, the gradients from all devices are summed onto the device of
 * the first parameter copy. The summed gradient is then handed to the {@link Optimizer} once for
 * every parameter copy, after being copied to that copy's device when it differs.
 */
public class LocalParameterServer implements ParameterServer {
    private final Optimizer optimizer;

    /**
     * Creates a local parameter server.
     *
     * @param optimizer the optimizer used to apply the aggregated gradient to each parameter copy
     */
    public LocalParameterServer(Optimizer optimizer) {
        this.optimizer = optimizer;
    }

    /**
     * Does nothing: this server stores no per-parameter state, so there is nothing to initialize.
     *
     * @param parameterId ignored
     * @param value ignored
     */
    @Override
    public void init(String parameterId, NDArray[] value) {
    }

    /**
     * Aggregates the per-device gradients of one parameter and updates every copy of it.
     *
     * <p>Each of {@code grads[1..]} is copied to the device of {@code params[0]} and added in place
     * into {@code grads[0]}. A duplicate of the resulting sum is then passed to the optimizer for
     * every element of {@code params}, copied to that element's device when it differs from the
     * first device.
     *
     * <p>This method takes ownership of {@code grads}: every entry is closed before the method
     * returns, including when an exception is thrown.
     *
     * @param parameterId the id of the parameter being updated
     * @param grads the gradient of the parameter on each device; must be non-empty.
     *     {@code grads[0]} is modified in place, and all entries are closed by this method
     * @param params the copies of the parameter on each device; must be non-empty
     */
    @Override
    public void update(String parameterId, NDArray[] grads, NDArray[] params) {
        try {
            Device firstDevice = params[0].getDevice();
            // Reduce: sum every other device's gradient into grads[0] on the first device.
            for (int i = 1; i < grads.length; ++i) {
                try (NDArray gradCopy = grads[i].toDevice(firstDevice, true)) {
                    grads[0].addi(gradCopy);
                }
            }
            // NOTE(review): the optimizer receives a duplicate rather than grads[0] itself,
            // presumably to guard against optimizers that modify the gradient in place — confirm.
            try (NDArray aggregatedGrad = grads[0].duplicate()) {
                for (NDArray param : params) {
                    if (param.getDevice().equals(firstDevice)) {
                        optimizer.update(parameterId, param, aggregatedGrad);
                    } else {
                        // The parameter lives elsewhere: hand it a temporary copy on its device.
                        try (NDArray gradSumCopy =
                                aggregatedGrad.toDevice(param.getDevice(), true)) {
                            optimizer.update(parameterId, param, gradSumCopy);
                        }
                    }
                }
            }
        } finally {
            // Release the per-device gradients even when aggregation or the update fails,
            // otherwise they would leak on the exception path.
            Arrays.stream(grads).forEach(NDArray::close);
        }
    }

    /** Does nothing: this server holds no resources of its own. */
    @Override
    public void close() {
    }
}

