/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.parameterserver.distributed.v2;

import io.reactivex.Flowable;
import io.reactivex.disposables.Disposable;
import io.reactivex.functions.Consumer;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.nd4j.common.primitives.Atomic;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.v2.enums.PropagationMode;
import org.nd4j.parameterserver.distributed.v2.messages.impl.GradientsUpdateMessage;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.handshake.HandshakeResponse;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.params.ModelParametersMessage;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.params.ModelParametersRequest;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.params.UpdaterParametersMessage;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.params.UpdaterParametersRequest;
import org.nd4j.parameterserver.distributed.v2.transport.RestartCallback;
import org.nd4j.parameterserver.distributed.v2.transport.Transport;
import org.nd4j.parameterserver.distributed.v2.transport.UpdaterParametersProvider;
import org.nd4j.parameterserver.distributed.v2.transport.UpdatesHandler;
import org.nd4j.parameterserver.distributed.v2.transport.impl.StaticPortSupplier;
import org.nd4j.parameterserver.distributed.v2.util.AbstractSubscriber;
import org.nd4j.parameterserver.distributed.v2.util.UpdaterParametersHolder;
import org.reactivestreams.Subscriber;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Coordinator between a {@link Transport} and the local training code.
 *
 * <p>Responsibilities implemented here:
 * <ul>
 *   <li>relays {@link GradientsUpdateMessage} payloads received from the transport to the registered
 *       {@link UpdatesHandler}s, or queues them when none is registered (see {@link #getUpdates()});</li>
 *   <li>answers {@link ModelParametersRequest}s and {@link UpdaterParametersRequest}s from other nodes;</li>
 *   <li>when the transport restarts, re-fetches model and updater parameters from the root node and pushes
 *       them to the registered parameter subscribers.</li>
 * </ul>
 *
 * <p>{@link #launch()} and {@link #shutdown()} are {@code synchronized}. The subscriber lists are
 * {@link CopyOnWriteArrayList}s, and the master node's stored updater state is guarded by
 * {@link #updaterParamsLock}.
 */
public final class ModelParameterServer {
    private static final Logger log = LoggerFactory.getLogger(ModelParameterServer.class);

    /** Shared instance returned by {@link #getInstance()}; must be configured before {@link #launch()}. */
    protected static final ModelParameterServer INSTANCE = new ModelParameterServer();

    private Transport transport;
    // NOTE(review): never assigned anywhere in this class, so the matching getters always return null.
    private INDArray masterModelParams;
    private INDArray masterUpdaterParams;
    /** Supplies updater state when a worker node is asked for it; may be null. */
    private UpdaterParametersProvider updaterParametersProvider;
    /** Gradient payloads received while no {@link UpdatesHandler} is registered (bounded to 4096). */
    private final BlockingQueue<INDArray> updatesQueue = new LinkedBlockingQueue<INDArray>(4096);
    protected final List<UpdatesHandler> updatesSubscribers = new CopyOnWriteArrayList<UpdatesHandler>();
    // The misspelled name ("Subsribers") is kept: the field is visible outside this class.
    protected final List<Subscriber<INDArray>> modelParamsSubsribers = new CopyOnWriteArrayList<Subscriber<INDArray>>();
    protected final List<Subscriber<INDArray>> updaterParamsSubscribers = new CopyOnWriteArrayList<Subscriber<INDArray>>();
    private boolean masterMode;
    protected VoidConfiguration configuration;
    /** Set once {@link #launch()} completes; cleared by {@link #shutdown()}. */
    private final AtomicBoolean launchLock = new AtomicBoolean(false);
    /** Set once {@link #shutdown()} completes; cleared by {@link #launch()}. */
    private final AtomicBoolean stopLock = new AtomicBoolean(false);
    // NOTE(review): not read or written by any code in this class.
    protected BlockingQueue<INDArray> updatesBacklog = new LinkedBlockingQueue<INDArray>();
    /** Latest updater state stored by the master node; holds null until the first store. */
    protected final Atomic<UpdaterParametersHolder> updaterParameters = new Atomic<>();
    /** Guards reads and writes of {@link #updaterParameters} contents. */
    protected final ReentrantReadWriteLock updaterParamsLock = new ReentrantReadWriteLock();
    /** When true, the master stops refreshing its stored updater state (never set to true in this class). */
    protected final AtomicBoolean gotFinalState = new AtomicBoolean(false);
    /** Subscription to the transport's incoming messages; null until {@link #launch()}. */
    private Disposable disposable;
    /** Highest iteration number observed (from gradient updates or a restart resync). */
    private final AtomicInteger iterationNumber = new AtomicInteger(0);
    /** Highest epoch number observed (from gradient updates or a restart resync). */
    private final AtomicInteger epochNumber = new AtomicInteger(0);

    /**
     * Creates an unconfigured instance. One of the {@code configure(...)} overloads must be called
     * before {@link #launch()}.
     */
    protected ModelParameterServer() {
    }

    /**
     * Returns the shared instance.
     *
     * @return {@link #INSTANCE}
     */
    public static ModelParameterServer getInstance() {
        return INSTANCE;
    }

    /**
     * Creates a non-master server over {@code transport} with the default configuration
     * (static port 40123, stream id 119).
     *
     * @param transport transport to use; must not be null
     * @throws NullPointerException if {@code transport} is null
     */
    protected ModelParameterServer(@NonNull Transport transport) {
        // null-checked by the delegated constructor
        this(transport, false);
    }

    /**
     * Creates a server over {@code transport} with the default configuration
     * (static port 40123, stream id 119).
     *
     * @param transport    transport to use; must not be null
     * @param isMasterNode whether this node acts as master
     * @throws NullPointerException if {@code transport} is null
     */
    protected ModelParameterServer(@NonNull Transport transport, boolean isMasterNode) {
        // null-checked by the delegated constructor
        this(VoidConfiguration.builder().portSupplier(new StaticPortSupplier(40123)).streamId(119).build(), transport, isMasterNode);
    }

    /**
     * Creates and configures a server.
     *
     * @param configuration configuration to use; must not be null
     * @param transport     transport to use; must not be null
     * @param isMasterNode  whether this node acts as master
     * @throws NullPointerException if {@code configuration} or {@code transport} is null
     */
    public ModelParameterServer(@NonNull VoidConfiguration configuration, @NonNull Transport transport, boolean isMasterNode) {
        this();
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        if (transport == null) {
            throw new NullPointerException("transport is marked non-null but is null");
        }
        this.configure(configuration, transport, isMasterNode);
    }

    /**
     * Sets configuration, transport and master/worker role. Takes effect on the next {@link #launch()}.
     *
     * @param configuration configuration to use; must not be null
     * @param transport     transport to use; must not be null
     * @param isMasterNode  whether this node acts as master
     * @throws NullPointerException if {@code configuration} or {@code transport} is null
     */
    public void configure(@NonNull VoidConfiguration configuration, @NonNull Transport transport, boolean isMasterNode) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        if (transport == null) {
            throw new NullPointerException("transport is marked non-null but is null");
        }
        this.transport = transport;
        this.masterMode = isMasterNode;
        this.configuration = configuration;
    }

    /**
     * Configures this server as a non-master node that answers updater-state requests via
     * {@code updaterProvider}.
     *
     * @param configuration   configuration to use; must not be null
     * @param transport       transport to use; must not be null
     * @param updaterProvider source of updater parameters; must not be null
     * @throws NullPointerException if any argument is null
     */
    public void configure(@NonNull VoidConfiguration configuration, @NonNull Transport transport, @NonNull UpdaterParametersProvider updaterProvider) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        if (transport == null) {
            throw new NullPointerException("transport is marked non-null but is null");
        }
        if (updaterProvider == null) {
            throw new NullPointerException("updaterProvider is marked non-null but is null");
        }
        this.transport = transport;
        this.masterMode = false;
        this.configuration = configuration;
        this.updaterParametersProvider = updaterProvider;
    }

    /**
     * Registers a handler that receives incoming gradient payloads. The first registered handler also
     * supplies the model parameters served to peers.
     *
     * @param s handler; must not be null
     * @throws NullPointerException if {@code s} is null
     */
    public void addUpdatesSubscriber(@NonNull UpdatesHandler s) {
        if (s == null) {
            throw new NullPointerException("s is marked non-null but is null");
        }
        this.updatesSubscribers.add(s);
    }

    /**
     * Registers a subscriber that receives model parameters fetched from the root on transport restart.
     *
     * @param s subscriber; must not be null
     * @throws NullPointerException if {@code s} is null
     */
    public void addModelParamsSubscriber(@NonNull Subscriber<INDArray> s) {
        if (s == null) {
            throw new NullPointerException("s is marked non-null but is null");
        }
        this.modelParamsSubsribers.add(s);
    }

    /**
     * Registers a subscriber that receives updater parameters fetched from the root on transport restart.
     *
     * @param s subscriber; must not be null
     * @throws NullPointerException if {@code s} is null
     */
    public void addUpdaterParamsSubscriber(@NonNull Subscriber<INDArray> s) {
        if (s == null) {
            throw new NullPointerException("s is marked non-null but is null");
        }
        this.updaterParamsSubscribers.add(s);
    }

    /**
     * Reports whether {@link #launch()} has completed and {@link #shutdown()} has not been called since.
     *
     * @return true if the server is launched
     */
    public boolean isInitialized() {
        return this.launchLock.get();
    }

    /**
     * Returns the training position last observed by this server.
     *
     * @return pair of (iteration number, epoch number)
     */
    public Pair<Integer, Integer> getStartPosition() {
        return Pair.makePair(this.iterationNumber.get(), this.epochNumber.get());
    }

    /**
     * Registers request handlers on the transport, subscribes to incoming messages and launches the
     * transport (as master or worker). Does nothing if already launched.
     *
     * @throws IllegalStateException if no transport/configuration has been set via {@code configure(...)}
     */
    public synchronized void launch() {
        log.info("ModelParameterServer starting");
        if (this.launchLock.get()) {
            return;
        }
        if (this.transport == null || this.configuration == null) {
            throw new IllegalStateException("ModelParameterServer must be configured before launch()");
        }
        this.configuration.setUnicastControllerPort(this.configuration.getPortSupplier().getPort());

        this.transport.setRestartCallback(new RestartCallback() {
            @Override
            public void call(HandshakeResponse response) {
                ModelParameterServer.this.resyncFromRoot();
            }
        });
        this.transport.addRequestConsumer(ModelParametersRequest.class, new Consumer<ModelParametersRequest>() {
            @Override
            public void accept(ModelParametersRequest request) throws Exception {
                ModelParameterServer.this.answerModelParametersRequest(request);
            }
        });

        if (this.masterMode) {
            // the master keeps its own copy of the latest updater state
            this.addUpdaterParamsSubscriber(new AbstractSubscriber<INDArray>() {
                @Override
                public void onNext(INDArray array) {
                    ModelParameterServer.this.storeMasterUpdaterParameters(array);
                }
            });
            this.transport.addRequestConsumer(UpdaterParametersRequest.class, new Consumer<UpdaterParametersRequest>() {
                @Override
                public void accept(UpdaterParametersRequest request) throws Exception {
                    ModelParameterServer.this.answerMasterUpdaterParametersRequest(request);
                }
            });
        } else {
            this.transport.addRequestConsumer(UpdaterParametersRequest.class, new Consumer<UpdaterParametersRequest>() {
                @Override
                public void accept(UpdaterParametersRequest request) throws Exception {
                    ModelParameterServer.this.answerWorkerUpdaterParametersRequest(request);
                }
            });
        }

        this.disposable = Flowable.fromPublisher(this.transport.incomingPublisher()).subscribe(message -> {
            if (!(message instanceof GradientsUpdateMessage)) {
                throw new UnsupportedOperationException("Unknown message received: [" + message.getClass().getCanonicalName() + "]");
            }
            GradientsUpdateMessage gum = (GradientsUpdateMessage) message;
            // atomic max: the position never moves backwards, without a check-then-set race
            this.iterationNumber.accumulateAndGet(gum.getIteration(), Math::max);
            this.epochNumber.accumulateAndGet(gum.getEpoch(), Math::max);

            INDArray payload = message.getPayload();
            if (this.updatesSubscribers.isEmpty()) {
                // NOTE(review): add() throws IllegalStateException once 4096 payloads are queued,
                // which terminates this subscription.
                this.updatesQueue.add(payload);
            } else {
                this.updatesSubscribers.forEach(s -> s.onNext(payload));
            }
        }, error -> log.error("Incoming message processing failed; subscription terminated", error));

        if (this.masterMode) {
            this.transport.launchAsMaster();
        } else {
            this.transport.launch();
        }
        this.stopLock.set(false);
        this.launchLock.set(true);
    }

    /**
     * Restart-callback body: fetches model parameters, training position and updater parameters
     * from the root node and pushes them to the registered subscribers.
     *
     * @throws RuntimeException wrapping any failure
     */
    private void resyncFromRoot() {
        try {
            log.info("Restart callback started...");
            String rootId = this.transport.getRootId();

            ModelParametersMessage modelParams = (ModelParametersMessage) this.transport.sendMessageBlocking(new ModelParametersRequest(), rootId);
            INDArray mParams = modelParams.getPayload();
            this.modelParamsSubsribers.forEach(s -> s.onNext(mParams));
            this.iterationNumber.set(modelParams.getIterationNumber());
            this.epochNumber.set(modelParams.getEpochNumber());

            UpdaterParametersMessage updaterParams = (UpdaterParametersMessage) this.transport.sendMessageBlocking(new UpdaterParametersRequest(), rootId);
            INDArray uParams = updaterParams.getPayload();
            if (uParams != null) {
                this.updaterParamsSubscribers.forEach(s -> s.onNext(uParams));
                log.debug("Updater parameters propagated...");
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            log.error("RestartCallback processing exception", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Replies to a {@link ModelParametersRequest} with the first updates handler's parameters and the
     * current training position.
     *
     * @param request incoming request
     * @throws IllegalStateException if no {@link UpdatesHandler} is registered
     * @throws Exception             propagated from the transport
     */
    private void answerModelParametersRequest(ModelParametersRequest request) throws Exception {
        UpdatesHandler handler = this.firstUpdatesSubscriber();
        if (handler == null) {
            throw new IllegalStateException("No UpdatesHandler registered: cannot serve model parameters to [" + request.getOriginatorId() + "]");
        }
        ModelParametersMessage msg = new ModelParametersMessage(UUID.randomUUID().toString(), handler.getParametersArray());
        msg.setRequestId(request.getRequestId());
        msg.setIterationNumber(this.iterationNumber.get());
        msg.setEpochNumber(this.epochNumber.get());
        this.transport.sendMessage(msg, request.getOriginatorId());
    }

    /** Returns the first registered updates handler, or null when none is registered. */
    private UpdatesHandler firstUpdatesSubscriber() {
        // CopyOnWriteArrayList iteration works on a snapshot, so this cannot race with clear()
        for (UpdatesHandler handler : this.updatesSubscribers) {
            return handler;
        }
        return null;
    }

    /**
     * Master-side subscriber body: stores {@code array} as the latest updater state and stamps the
     * receive time, unless the final state has already been reached.
     *
     * @param array updater parameters to store
     */
    private void storeMasterUpdaterParameters(INDArray array) {
        if (this.gotFinalState.get()) {
            return;
        }
        this.updaterParamsLock.writeLock().lock();
        try {
            long now = System.currentTimeMillis();
            UpdaterParametersHolder holder = this.updaterParameters.get();
            // the Atomic starts out empty, so the first store has to create the holder
            if (holder == null) {
                this.updaterParameters.set(new UpdaterParametersHolder(array, now, false));
            } else {
                holder.setParameters(array);
                holder.setTimeReceived(now);
            }
        } finally {
            this.updaterParamsLock.writeLock().unlock();
        }
    }

    /**
     * Master-side handler for {@link UpdaterParametersRequest}. Unless the final state has been reached,
     * it first refreshes the stored updater state from a random downstream node other than the requester.
     * It then replies with the stored state (a null payload if nothing was ever stored).
     *
     * @param request incoming request
     * @throws Exception propagated from the transport
     */
    private void answerMasterUpdaterParametersRequest(UpdaterParametersRequest request) throws Exception {
        if (!this.gotFinalState.get()) {
            String downstreamId = this.transport.getRandomDownstreamFrom(this.transport.getRootId(), request.getOriginatorId());
            log.debug("Sending UpdaterParameters request to [{}]", downstreamId);
            UpdaterParametersMessage reply = (UpdaterParametersMessage) this.transport.sendMessageBlocking(new UpdaterParametersRequest(), downstreamId);
            INDArray fetched = reply.getPayload();

            this.updaterParamsLock.writeLock().lock();
            try {
                UpdaterParametersHolder holder = this.updaterParameters.get();
                if (holder == null) {
                    this.updaterParameters.set(new UpdaterParametersHolder(fetched, System.currentTimeMillis(), false));
                } else {
                    holder.setParameters(fetched);
                }
            } finally {
                this.updaterParamsLock.writeLock().unlock();
            }
        }

        this.updaterParamsLock.readLock().lock();
        try {
            log.debug("Trying to send back Updater parameters...");
            UpdaterParametersHolder holder = this.updaterParameters.get();
            // no state stored yet -> reply with a null payload (the restart path skips null updater params)
            INDArray params = holder == null ? null : holder.getParameters();
            UpdaterParametersMessage msg = new UpdaterParametersMessage(UUID.randomUUID().toString(), params);
            msg.setRequestId(request.getRequestId());
            this.transport.sendMessage(msg, request.getOriginatorId());
        } finally {
            this.updaterParamsLock.readLock().unlock();
        }
    }

    /**
     * Worker-side handler for {@link UpdaterParametersRequest}: replies with the provider's updater
     * parameters, or with a null payload if no provider was configured.
     *
     * @param request incoming request
     * @throws Exception propagated from the transport
     */
    private void answerWorkerUpdaterParametersRequest(UpdaterParametersRequest request) throws Exception {
        log.debug("Trying to send back Updater parameters...");
        INDArray params;
        if (this.updaterParametersProvider == null) {
            log.warn("UpdaterParametersProvider wasn't set!");
            params = null;
        } else {
            params = this.updaterParametersProvider.getUpdaterParameters();
        }
        UpdaterParametersMessage msg = new UpdaterParametersMessage(UUID.randomUUID().toString(), params);
        msg.setRequestId(request.getRequestId());
        this.transport.sendMessage(msg, request.getOriginatorId());
    }

    /**
     * Shuts down the transport, disposes the incoming-message subscription and clears all subscribers
     * and queued updates. Does nothing if already shut down. Safe to call before {@link #launch()}.
     */
    public synchronized void shutdown() {
        if (this.stopLock.get()) {
            return;
        }
        if (this.transport != null) {
            this.transport.shutdown();
        }
        if (this.disposable != null) {
            this.disposable.dispose();
            this.disposable = null;
        }
        this.updaterParamsSubscribers.clear();
        this.modelParamsSubsribers.clear();
        this.updatesSubscribers.clear();
        this.updatesQueue.clear();
        this.launchLock.set(false);
        this.stopLock.set(true);
    }

    /**
     * Propagates a gradient update in both directions over the transport.
     *
     * @param array     update payload; must not be null
     * @param iteration iteration number attached to the message
     * @param epoch     epoch number attached to the message
     * @throws NullPointerException if {@code array} is null
     * @throws RuntimeException     wrapping any transport failure
     */
    public void sendUpdate(@NonNull INDArray array, int iteration, int epoch) {
        if (array == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        try {
            GradientsUpdateMessage msg = new GradientsUpdateMessage(UUID.randomUUID().toString(), array);
            msg.setOriginatorId(this.transport.id());
            msg.setIteration(iteration);
            msg.setEpoch(epoch);
            this.transport.propagateMessage(msg, PropagationMode.BOTH_WAYS);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Same as {@link #sendUpdate(INDArray, int, int)} with iteration and epoch set to 0.
     *
     * @param array update payload; must not be null
     * @throws NullPointerException if {@code array} is null
     */
    public void sendUpdate(@NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        this.sendUpdate(array, 0, 0);
    }

    /**
     * Drains and returns the gradient payloads queued while no {@link UpdatesHandler} was registered.
     *
     * @return a new mutable list, possibly empty
     */
    public Collection<INDArray> getUpdates() {
        ArrayList<INDArray> list = new ArrayList<INDArray>();
        this.updatesQueue.drainTo(list);
        return list;
    }

    /**
     * Returns the configured transport.
     *
     * @return the configured transport, or null if not configured
     */
    public Transport getTransport() {
        return this.transport;
    }

    /**
     * Returns the master model parameters.
     *
     * @return always null at present: the backing field is never assigned in this class
     */
    public INDArray getMasterModelParams() {
        return this.masterModelParams;
    }

    /**
     * Returns the master updater parameters.
     *
     * @return always null at present: the backing field is never assigned in this class
     */
    public INDArray getMasterUpdaterParams() {
        return this.masterUpdaterParams;
    }
}

