/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.parameterserver.pw;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Loader;
import org.deeplearning4j.core.storage.Persistable;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.core.storage.StorageMetaData;
import org.deeplearning4j.core.storage.listener.RoutingIterationListener;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.BaseMultiLayerUpdater;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.SleepyTrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.MessageHandler;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.spark.parameterserver.conf.SharedTrainingConfiguration;
import org.deeplearning4j.spark.parameterserver.iterators.VirtualDataSetIterator;
import org.deeplearning4j.spark.parameterserver.iterators.VirtualIterator;
import org.deeplearning4j.spark.parameterserver.iterators.VirtualMultiDataSetIterator;
import org.deeplearning4j.spark.parameterserver.networking.v2.ModelParamsConsumer;
import org.deeplearning4j.spark.parameterserver.networking.v2.UpdaterParamsConsumer;
import org.deeplearning4j.spark.parameterserver.networking.v2.UpdatesConsumer;
import org.deeplearning4j.spark.parameterserver.networking.v2.WiredEncodingHandler;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingResult;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingWorker;
import org.deeplearning4j.spark.parameterserver.util.BlockingObserver;
import org.deeplearning4j.spark.parameterserver.util.CountingIterator;
import org.deeplearning4j.spark.util.SparkUtils;
import org.nd4j.common.function.Supplier;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.enums.TransportType;
import org.nd4j.parameterserver.distributed.util.NetworkOrganizer;
import org.nd4j.parameterserver.distributed.v2.ModelParameterServer;
import org.nd4j.parameterserver.distributed.v2.transport.Transport;
import org.nd4j.parameterserver.distributed.v2.transport.UpdaterParametersProvider;
import org.nd4j.parameterserver.distributed.v2.transport.UpdatesHandler;
import org.nd4j.parameterserver.distributed.v2.transport.impl.AeronUdpTransport;
import org.reactivestreams.Subscriber;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * JVM-wide coordinator for shared-gradient training on a Spark executor.
 *
 * <p>Executor threads attach their data via {@link #attachDS(Iterator)} or {@link #attachMDS(Iterator)}
 * and then call {@link #run(SharedTrainingWorker)}. The first thread to enter {@code run} becomes the
 * master. On first use it creates the gradients accumulator, the {@link ModelParameterServer} connection
 * and (for more than one worker) a {@link ParallelWrapper}. It then fits on all attached iterators and
 * returns the full {@link SharedTrainingResult}. Every other thread blocks on the {@link BlockingObserver}
 * registered for its own iterator and returns only its minibatch count.
 */
public class SharedTrainingWrapper {
    private static final Logger log = LoggerFactory.getLogger(SharedTrainingWrapper.class);
    private static final SharedTrainingWrapper INSTANCE = new SharedTrainingWrapper();
    // Long.MIN_VALUE marks "no training session id seen yet"
    private static final AtomicLong LAST_INSTANCE_ID = new AtomicLong(Long.MIN_VALUE);

    protected ParallelWrapper wrapper;
    protected VirtualDataSetIterator iteratorDS;
    protected VirtualMultiDataSetIterator iteratorMDS;
    protected List<Iterator<DataSet>> iteratorsDS;
    protected List<Iterator<MultiDataSet>> iteratorsMDS;
    // Set while some thread is acting as the master inside run()
    protected AtomicBoolean isFirst = new AtomicBoolean(false);
    protected AtomicBoolean exceptionEncountered = new AtomicBoolean(false);
    protected Throwable exception;
    // Per-thread counter handed to CountingIterator; reported as minibatchesPerExecutor
    protected ThreadLocal<AtomicInteger> iteratorDataSetCount = new ThreadLocal<>();
    protected ThreadLocal<BlockingObserver> observer = new ThreadLocal<>();
    protected EncodedGradientsAccumulator accumulator;
    protected Model originalModel;
    protected UpdatesConsumer consumer;

    protected SharedTrainingWrapper() {
        this.init();
    }

    /**
     * Replaces the attached-iterator lists with fresh empty ones and rebuilds the virtual iterators over them.
     */
    protected void init() {
        this.iteratorsDS = new CopyOnWriteArrayList<Iterator<DataSet>>();
        this.iteratorsMDS = new CopyOnWriteArrayList<Iterator<MultiDataSet>>();
        this.iteratorDS = new VirtualDataSetIterator(this.iteratorsDS);
        this.iteratorMDS = new VirtualMultiDataSetIterator(this.iteratorsMDS);
    }

    /**
     * Returns the shared wrapper instance.
     *
     * <p>When {@code id} differs from the previously seen id, this first shuts down any existing
     * ParallelWrapper and clears per-session state: the attached iterators, the exception flag, the
     * counters, the accumulator, the model and the consumer.
     *
     * @param id identifier of the training session requesting the instance
     * @return the singleton wrapper
     */
    public static synchronized SharedTrainingWrapper getInstance(long id) {
        long previousId = LAST_INSTANCE_ID.get();
        if (previousId == Long.MIN_VALUE) {
            LAST_INSTANCE_ID.set(id);
        } else if (previousId != id) {
            log.debug("Shutting down existing SharedTrainingWrapper instances; resetting state - previous instance ID {}, new instance ID {}", previousId, id);
            if (INSTANCE.wrapper != null) {
                INSTANCE.wrapper.shutdown();
                INSTANCE.wrapper = null;
            }
            INSTANCE.iteratorsDS.clear();
            INSTANCE.iteratorsMDS.clear();
            INSTANCE.exceptionEncountered.set(false);
            INSTANCE.iteratorDataSetCount = new ThreadLocal<>();
            INSTANCE.accumulator = null;
            INSTANCE.originalModel = null;
            INSTANCE.consumer = null;
            LAST_INSTANCE_ID.set(id);
        }
        return INSTANCE;
    }

    /**
     * Attaches the calling thread's {@link DataSet} iterator to the shared virtual iterator.
     *
     * @param iterator the thread's data partition
     */
    public void attachDS(Iterator<DataSet> iterator) {
        attach(iterator, this.iteratorsDS);
    }

    /**
     * Attaches the calling thread's {@link MultiDataSet} iterator to the shared virtual iterator.
     *
     * @param iterator the thread's data partition
     */
    public void attachMDS(Iterator<MultiDataSet> iterator) {
        attach(iterator, this.iteratorsMDS);
    }

    /**
     * Wraps {@code iterator} in counting and observable wrappers and registers it for the calling thread.
     */
    private <T> void attach(Iterator<T> iterator, List<Iterator<T>> target) {
        log.debug("Attaching thread...");
        AtomicInteger count = this.iteratorDataSetCount.get();
        if (count == null) {
            count = new AtomicInteger(0);
            this.iteratorDataSetCount.set(count);
        }
        count.set(0);
        VirtualIterator<T> wrapped = new VirtualIterator<T>(new CountingIterator<T>(iterator, count));
        BlockingObserver obs = new BlockingObserver(this.exceptionEncountered);
        wrapped.addObserver(obs);
        target.add(wrapped);
        this.observer.set(obs);
    }

    /**
     * Runs training for the calling thread.
     *
     * <p>The first thread to arrive becomes the master: it trains on every attached iterator and returns the
     * full result (score, updater state, threshold algorithm, minibatch count). Other threads wait for their
     * own iterator's observer and return only their minibatch count.
     *
     * @param worker the Spark-side worker carrying the broadcast configuration and initial model
     * @return the training result for this thread
     * @throws DL4JInvalidConfigException if no model, transport or iterators are available
     * @throws RuntimeException on a feeder thread if the master's fit failed, or if interrupted
     */
    public SharedTrainingResult run(SharedTrainingWorker worker) {
        if (!this.isFirst.compareAndSet(false, true)) {
            return awaitMasterCompletion();
        }

        // Clear any failure left over from a previous fit so a subsequent correct fit can proceed
        this.exceptionEncountered.set(false);
        this.exception = null;

        SharedTrainingConfiguration trainingConfiguration =
                (SharedTrainingConfiguration) worker.getBroadcastConfiguration().getValue();
        int numWorkers = resolveNumWorkers(trainingConfiguration);

        if (this.wrapper == null) {
            initializeTraining(worker, trainingConfiguration, numWorkers);
        }

        if (this.consumer != null) {
            this.consumer.bypassMode(false);
        }
        if (this.iteratorDS == null && this.iteratorMDS == null) {
            throw new DL4JInvalidConfigException("No iterators were defined for training");
        }

        try {
            fitAttachedIterators();
        } catch (Throwable t) {
            log.warn("Exception encountered during fit operation", t);
            this.exceptionEncountered.set(true);
            this.exception = t;
        }

        // Capture the accumulator before the ParallelWrapper may be shut down below
        EncodedGradientsAccumulator accum = this.wrapper != null
                ? (EncodedGradientsAccumulator) this.wrapper.getGradientsAccumulator()
                : this.accumulator;
        // In standalone (single worker) mode there is no ParallelWrapper to shut down
        if (trainingConfiguration.isEpochReset() && this.wrapper != null) {
            this.wrapper.shutdown();
            this.wrapper = null;
        }

        this.init();
        this.accumulator.reset();
        if (this.consumer != null) {
            this.consumer.bypassMode(true);
        }
        this.isFirst.set(false);
        log.info("Master thread done...");

        INDArray updaterState = currentUpdaterState();
        EncodingHandler mh = (EncodingHandler) accum.getHandler();
        ThresholdAlgorithm taAveraged = mh.getAverageThresholdAlgorithm();
        return SharedTrainingResult.builder()
                .aggregationsCount(1)
                .scoreSum(this.originalModel.score())
                .updaterStateArray(updaterState)
                .listenerMetaData(new ArrayList<StorageMetaData>())
                .listenerStaticInfo(new ArrayList<Persistable>())
                .listenerUpdates(new ArrayList<Persistable>())
                .minibatchesPerExecutor(Collections.singletonMap(SparkUtils.getSparkExecutorId(), this.iteratorDataSetCount.get().get()))
                .thresholdAlgorithm(taAveraged)
                .build();
    }

    /**
     * Feeder-thread path of {@link #run}: waits for this thread's observer and reports its minibatch count.
     */
    private SharedTrainingResult awaitMasterCompletion() {
        try {
            this.observer.get().waitTillDone();
            log.info("Feeder [{}] thread done...", Thread.currentThread().getName());
            if (this.exceptionEncountered.get()) {
                Throwable t = this.wrapper == null || this.exception != null ? this.exception : this.wrapper.getException();
                throw new RuntimeException("Training failed due to exception in ParallelWrapper fit operation", t);
            }
            return SharedTrainingResult.builder()
                    .minibatchesPerExecutor(Collections.singletonMap(SparkUtils.getSparkExecutorId(), this.iteratorDataSetCount.get().get()))
                    .build();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Picks the worker count.
     *
     * <p>The explicitly configured value wins. Otherwise the device count is used on multi-device hosts,
     * else {@code cores / 4} clamped to [1, 6].
     */
    private static int resolveNumWorkers(SharedTrainingConfiguration trainingConfiguration) {
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        int numCores = Loader.totalCores();
        int numWorkers;
        if (trainingConfiguration.getNumberOfWorkersPerNode() > 0) {
            numWorkers = trainingConfiguration.getNumberOfWorkersPerNode();
        } else if (numDevices > 1) {
            numWorkers = numDevices;
        } else {
            numWorkers = Math.min(6, Math.max(1, numCores / 4));
        }
        if (numDevices > 1 && numWorkers > numDevices) {
            log.warn("WARNING! Using more workers then number of available computational devices!");
        }
        return numWorkers;
    }

    /**
     * Prepares the model, accumulator/parameter server (once) and either a ParallelWrapper or standalone model.
     */
    private void initializeTraining(SharedTrainingWorker worker, SharedTrainingConfiguration trainingConfiguration, int numWorkers) {
        log.debug("Starting ParallelWrapper at thread {}", Thread.currentThread().getId());
        Model model = worker.getInitialModel();
        if (model == null) {
            model = worker.getInitialModelGraph();
        }
        if (model == null) {
            throw new DL4JInvalidConfigException("No model was defined for training");
        }

        List<TrainingListener> listeners = worker.getListeners();
        if (listeners != null) {
            model.setListeners(listeners);
            StatsStorageRouter router = worker.getRouter();
            if (router != null) {
                for (TrainingListener listener : listeners) {
                    if (listener instanceof RoutingIterationListener) {
                        ((RoutingIterationListener) listener).setStorageRouter(router);
                    }
                }
            }
        }

        WiredEncodingHandler handler = new WiredEncodingHandler(trainingConfiguration.getThresholdAlgorithm(), trainingConfiguration.getResidualPostProcessor(), null, trainingConfiguration.isEncodingDebugMode());
        ModelParamsConsumer modelParamsSupplier = new ModelParamsConsumer();
        UpdaterParamsConsumer updateParamsSupplier = new UpdaterParamsConsumer();

        // The accumulator and the parameter-server connection are created only once per session
        if (this.accumulator == null) {
            startParameterServer(model, trainingConfiguration, handler, modelParamsSupplier, updateParamsSupplier, numWorkers);
        }

        // NOTE(review): start position, listeners and the sleepy listener are applied to `model`, which on
        // later runs may be a fresh broadcast copy rather than `originalModel` - confirm this is intended.
        ModelParameterServer server = ModelParameterServer.getInstance();
        if (this.originalModel instanceof MultiLayerNetwork) {
            MultiLayerNetwork network = (MultiLayerNetwork) model;
            network.setIterationCount((Integer) server.getStartPosition().getFirst());
            network.setEpochCount((Integer) server.getStartPosition().getSecond());
        } else if (this.originalModel instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) model;
            graph.getConfiguration().setIterationCount((Integer) server.getStartPosition().getFirst());
            graph.getConfiguration().setEpochCount((Integer) server.getStartPosition().getSecond());
        }

        if (trainingConfiguration.getDebugLongerIterations() > 0L) {
            log.warn("Adding SleepyListener: {} ms", trainingConfiguration.getDebugLongerIterations());
            model.addListeners(SleepyTrainingListener.builder().timerIteration(trainingConfiguration.getDebugLongerIterations()).build());
        }

        this.accumulator.markExternalUpdates(true);

        if (numWorkers > 1) {
            this.wrapper = new ParallelWrapper.Builder<>(this.originalModel)
                    .workers(numWorkers)
                    .workspaceMode(trainingConfiguration.getWorkspaceMode())
                    .trainingMode(ParallelWrapper.TrainingMode.CUSTOM)
                    .gradientsAccumulator((GradientsAccumulator) this.accumulator)
                    .prefetchBuffer(trainingConfiguration.getPrefetchSize())
                    .modelParamsSupplier((Supplier) modelParamsSupplier)
                    .updaterParamsSupplier((Supplier) updateParamsSupplier)
                    .thresholdAlgorithm(trainingConfiguration.getThresholdAlgorithm())
                    .residualPostProcessor(trainingConfiguration.getResidualPostProcessor())
                    .build();
            this.wrapper.setExceptionEncountered(this.exceptionEncountered);
        } else {
            configureStandaloneModel(trainingConfiguration, modelParamsSupplier);
        }
    }

    /**
     * Builds the gradients accumulator and configures, subscribes to and launches the ModelParameterServer.
     */
    private void startParameterServer(Model model, SharedTrainingConfiguration trainingConfiguration,
                                      WiredEncodingHandler handler, ModelParamsConsumer modelParamsSupplier,
                                      UpdaterParamsConsumer updateParamsSupplier, int numWorkers) {
        VoidConfiguration voidConfiguration = trainingConfiguration.getVoidConfiguration();
        int queueSize = numWorkers * 2;
        long bufferSize = trainingConfiguration.getBufferSize() > 0
                ? (long) trainingConfiguration.getBufferSize()
                : EncodedGradientsAccumulator.getOptimalBufferSize(model, numWorkers, 2);
        this.accumulator = new EncodedGradientsAccumulator.Builder(numWorkers)
                .messageHandler((MessageHandler) handler)
                .thresholdAlgorithm(trainingConfiguration.getThresholdAlgorithm())
                .residualPostProcessor(trainingConfiguration.getResidualPostProcessor())
                .memoryParameters(bufferSize, queueSize)
                .encodingDebugMode(trainingConfiguration.isEncodingDebugMode())
                .build();

        String localIP = resolveLocalIp(voidConfiguration);
        log.debug("Checking for ModelParameterServer existence");
        this.originalModel = model;

        ModelParameterServer server = ModelParameterServer.getInstance();
        if (!server.isInitialized()) {
            log.info("Initializing transport [{}:{}] with root as [{}:{}]...", localIP, voidConfiguration.getPortSupplier().getPort(), voidConfiguration.getControllerAddress(), voidConfiguration.getUnicastControllerPort());
            AeronUdpTransport transport = voidConfiguration.getTransportType() == TransportType.ROUTED_UDP
                    ? new AeronUdpTransport(localIP, voidConfiguration.getPortSupplier().getPort(), voidConfiguration.getControllerAddress(), voidConfiguration.getUnicastControllerPort(), voidConfiguration)
                    : null;
            if (transport == null) {
                throw new DL4JInvalidConfigException("No Transport implementation was defined for this training session!");
            }
            this.consumer = UpdatesConsumer.builder()
                    .numWorkers(numWorkers)
                    .accumulator((GradientsAccumulator) this.accumulator)
                    .params(model.params())
                    .build();
            this.accumulator.setExternalSource(this.consumer.getUpdatesQueue());
            log.debug("Configuring transport...");
            server.configure(voidConfiguration, (Transport) transport, new UpdaterParametersProvider() {
                @Override
                public INDArray getUpdaterParameters() {
                    return serveUpdaterParameters();
                }
            });
            server.addUpdatesSubscriber((UpdatesHandler) this.consumer);
            server.addModelParamsSubscriber((Subscriber) modelParamsSupplier);
            server.addUpdaterParamsSubscriber((Subscriber) updateParamsSupplier);
        }

        log.debug("Starting ModelParameterServer...");
        server.launch();
        awaitTransportIntroduction(server);
    }

    /**
     * Chooses the local IP.
     *
     * <p>Uses a network-mask match first, then the DL4J_VOID_IP environment variable, falling back to
     * 127.0.0.1.
     */
    private static String resolveLocalIp(VoidConfiguration voidConfiguration) {
        String localIP = null;
        if (voidConfiguration.getNetworkMask() != null) {
            localIP = new NetworkOrganizer(voidConfiguration.getNetworkMask()).getMatchingAddress();
        }
        if (localIP == null) {
            localIP = System.getenv("DL4J_VOID_IP");
        }
        if (localIP == null) {
            localIP = "127.0.0.1";
            log.warn("Can't get IP address to start VoidParameterServer client. Using localhost instead");
        }
        return localIP;
    }

    /**
     * Polls every 100 ms until the server's transport reports it has been introduced.
     */
    private static void awaitTransportIntroduction(ModelParameterServer server) {
        while (!server.getTransport().isIntroduced()) {
            try {
                Thread.sleep(100L);
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers further up can still observe it
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Returns a copy of the original model's updater state.
     *
     * <p>This is the value served to the parameter server. It is {@code null} when no updater or no copy
     * is available.
     */
    private INDArray serveUpdaterParameters() {
        log.info("Serving updater parameters...");
        Updater updater = null;
        if (this.originalModel instanceof MultiLayerNetwork) {
            updater = ((MultiLayerNetwork) this.originalModel).getUpdater();
        } else if (this.originalModel instanceof ComputationGraph) {
            updater = ((ComputationGraph) this.originalModel).getUpdater();
        }
        if (updater == null) {
            log.warn("No Updater in the model");
            return null;
        }
        if (updater instanceof BaseMultiLayerUpdater) {
            return ((BaseMultiLayerUpdater) updater).getStateViewArrayCopy();
        }
        log.error("Updater doesn't implement getStateViewArrayCopy()");
        return null;
    }

    /**
     * Single-worker setup.
     *
     * <p>Switches the accumulator to single-consumer mode, applies any received params, and attaches the
     * accumulator to the original model.
     */
    private void configureStandaloneModel(SharedTrainingConfiguration trainingConfiguration, ModelParamsConsumer modelParamsSupplier) {
        log.debug("Using standalone model instead...");
        this.accumulator.fallbackToSingleConsumerMode(true);
        this.accumulator.touch();
        INDArray mParams = modelParamsSupplier.get();
        if (mParams != null) {
            log.info("Updating model params to the most recent ones...");
            this.originalModel.params().assign(mParams);
        }
        if (this.originalModel instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) this.originalModel;
            graph.getConfiguration().setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode());
            graph.setGradientsAccumulator((GradientsAccumulator) this.accumulator);
        } else if (this.originalModel instanceof MultiLayerNetwork) {
            MultiLayerNetwork network = (MultiLayerNetwork) this.originalModel;
            network.getLayerWiseConfigurations().setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode());
            network.setGradientsAccumulator((GradientsAccumulator) this.accumulator);
        }
    }

    /**
     * Fits on the virtual iterators until neither has data.
     *
     * <p>Training goes through the ParallelWrapper when one exists, otherwise through the original model
     * directly.
     */
    private void fitAttachedIterators() {
        boolean dsNext;
        // Loop re-checks hasNext() as a guard against iterators attached concurrently during a fit
        while ((dsNext = this.iteratorDS != null && this.iteratorDS.hasNext())
                || (this.iteratorMDS != null && this.iteratorMDS.hasNext())) {
            if (this.wrapper != null) {
                if (dsNext) {
                    this.wrapper.fit((DataSetIterator) this.iteratorDS);
                } else {
                    this.wrapper.fit((MultiDataSetIterator) this.iteratorMDS);
                }
            } else if (this.originalModel instanceof ComputationGraph) {
                ComputationGraph graph = (ComputationGraph) this.originalModel;
                if (dsNext) {
                    graph.fit((DataSetIterator) this.iteratorDS);
                } else {
                    graph.fit((MultiDataSetIterator) this.iteratorMDS);
                }
            } else if (this.originalModel instanceof MultiLayerNetwork) {
                MultiLayerNetwork network = (MultiLayerNetwork) this.originalModel;
                if (dsNext) {
                    network.fit((DataSetIterator) this.iteratorDS);
                } else {
                    network.fit((MultiDataSetIterator) this.iteratorMDS);
                }
            }
            if (this.consumer != null) {
                this.consumer.getUpdatesQueue().purge();
            }
        }
    }

    /**
     * Returns the updater state of the model that was actually trained.
     *
     * <p>The original model is used rather than the per-call local, which is unset on runs that reuse an
     * existing ParallelWrapper.
     */
    private INDArray currentUpdaterState() {
        if (this.originalModel instanceof ComputationGraph) {
            return ((ComputationGraph) this.originalModel).getUpdater().getUpdaterStateViewArray();
        }
        if (this.originalModel instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) this.originalModel).getUpdater().getStateViewArray();
        }
        return null;
    }

    /** Currently a no-op. */
    public void passDataSet(DataSet dataSet) {
    }

    /** Currently a no-op. */
    public void passDataSet(MultiDataSet dataSet) {
    }

    /**
     * Blocks the calling thread until its attached iterator's observer reports completion.
     *
     * @throws IllegalStateException if this thread has not attached an iterator yet
     * @throws InterruptedException if interrupted while waiting
     */
    public void blockUntilFinished() throws InterruptedException {
        BlockingObserver obs = this.observer.get();
        if (obs == null) {
            throw new IllegalStateException("This method can't be called before iterators initialization");
        }
        // Object.wait() without owning the monitor always threw IllegalMonitorStateException;
        // use the observer's own blocking wait, as run() does.
        obs.waitTillDone();
    }
}

