/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.listener.RoutingIterationListener;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.DummyBlockDataSetIterator;
import org.deeplearning4j.datasets.iterator.DummyBlockMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.callbacks.DataSetCallback;
import org.deeplearning4j.datasets.iterator.callbacks.InterleavedDataSetCallback;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.SharedGradient;
import org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.Registerable;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.parallelism.factory.DefaultTrainerContext;
import org.deeplearning4j.parallelism.factory.SymmetricTrainerContext;
import org.deeplearning4j.parallelism.factory.TrainerContext;
import org.deeplearning4j.parallelism.trainer.Trainer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ParallelWrapper
implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(ParallelWrapper.class);
    /** Optional source of model parameters; when it yields non-null, fit() pushes it to every trainer before each round. */
    protected Supplier<INDArray> modelParamsSupplier;
    /** Optional source of updater state; when it yields non-null, fit() pushes it to every trainer before each round. */
    protected Supplier<INDArray> updaterParamsSupplier;
    /** Flag raised by the uncaught-exception handler; when this is null the handler records nothing. */
    protected AtomicBoolean exceptionEncountered;
    /** Exception captured by the uncaught-exception handler (only recorded when exceptionEncountered is non-null). */
    protected Throwable exception;
    /** Random per-instance identifier, passed to every trainer this wrapper creates. */
    protected final String uuid = UUID.randomUUID().toString();
    /** The model being trained; trainers are created from it and averaged scores are stored on it. */
    protected Model model;
    /** Number of trainers (and training threads). */
    protected int workers = 2;
    /** Prefetch buffer size; values > 0 enable asynchronous prefetching when the source iterator supports it. */
    protected int prefetchSize = 2;
    /** Averaging runs on rounds where iterationsCounter is a multiple of this value (and the trainers require it). */
    protected int averagingFrequency = 1;
    /** One trainer per worker; created lazily by createZooIfNeccessary() and set to null by close(). */
    protected Trainer[] zoo;
    /** Factory used to initialise and create the trainers. */
    protected TrainerContext trainerContext;
    /** Number of training rounds dispatched; incremented once per round and not reset by fit(). */
    protected AtomicLong iterationsCounter = new AtomicLong(0L);
    /** When true, the averaged score (and, in DataSet training, the averaging time) is logged. */
    protected boolean reportScore = false;
    /** When true, updater state is averaged across workers together with the parameters. */
    protected boolean averageUpdaters = true;
    /** Not read anywhere in the visible part of this class. NOTE(review): confirm whether it is still used. */
    protected boolean legacyAveraging = false;
    /** Set to true by getScore() whenever an averaging pass runs; checked at the end of fit(MultiDataSetIterator). */
    protected boolean wasAveraged = false;
    /** Cooperative stop flag checked before each training round; set by stopFit(), reset to false when fit() starts. */
    protected AtomicBoolean stopFit = new AtomicBoolean(false);
    /** Training listeners registered through setListeners(...). */
    protected List<TrainingListener> listeners = new ArrayList<TrainingListener>();
    /** Router assigned to replicated RoutingIterationListeners that do not carry a router of their own. */
    protected StatsStorageRouter storageRouter;
    /** When true (and prefetching is active), fit() uses an interleaved prefetcher with a queue of prefetchSize * workers. */
    protected boolean isMQ;
    /** Workspace mode handed to every trainer. */
    protected WorkspaceMode workspaceMode;
    /** Extra arguments forwarded to trainerContext.init(...). */
    protected Object[] trainerContextArgs;
    /** Enables additional info-level progress logging in fit(). */
    protected boolean debug = false;
    /** Pool that runs the trainers; created by init(), shut down and cleared by close(). */
    protected ThreadPoolExecutor executorService;
    /** Sequence counter used to name training threads and pick their device; reset by init(). */
    protected final AtomicInteger workerCounter = new AtomicInteger(0);
    /** Optional accumulator; if Registerable it is told the consumer count each round, and it is reset by close(). */
    protected GradientsAccumulator gradientsAccumulator;
    /**
     * Uncaught-exception handler installed on every training thread created by {@code init()}.
     *
     * <p>Always logs the failure. When {@code exceptionEncountered} is non-null it also records the
     * exception and raises the flag. The exception is stored BEFORE the flag is set, so any thread
     * that observes the flag as {@code true} (an atomic read) is guaranteed to see the stored exception.
     */
    Thread.UncaughtExceptionHandler handler = new Thread.UncaughtExceptionHandler(){

        @Override
        public void uncaughtException(Thread th, Throwable ex) {
            // Throwable as the last argument: SLF4J logs the full stack trace through the configured
            // appender instead of printing it to stderr.
            log.error("Uncaught exception in thread {}", th.getName(), ex);
            AtomicBoolean flag = ParallelWrapper.this.exceptionEncountered;
            if (flag != null) {
                // Publish the payload first, then the flag (atomic write => happens-before for readers of the flag).
                ParallelWrapper.this.exception = ex;
                flag.set(true);
            }
        }
    };

    /**
     * Creates a wrapper that trains {@code model} with {@code workers} trainers.
     *
     * @param model        the network to train; MultiLayerNetwork and ComputationGraph get special handling below
     * @param workers      number of trainers / training threads
     * @param prefetchSize asynchronous prefetch buffer size; values <= 0 disable prefetching
     */
    protected ParallelWrapper(Model model, int workers, int prefetchSize) {
        this.workers = workers;
        this.prefetchSize = prefetchSize;
        this.model = model;

        // The returned updater is deliberately discarded; the call is made only for its side effect.
        // NOTE(review): presumably this forces lazy creation of the master model's updater - confirm.
        if (model instanceof MultiLayerNetwork) {
            MultiLayerNetwork network = (MultiLayerNetwork) model;
            network.getUpdater();
        } else if (model instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) model;
            graph.getUpdater();
        }
    }

    /**
     * (Re)creates the fixed-size executor that runs the trainers and restarts thread numbering at 0.
     *
     * <p>Each thread is a daemon named "ParallelWrapper training thread N", uses {@code handler} for
     * uncaught exceptions, and is attached to device {@code N % numberOfDevices} (round-robin).
     */
    protected void init() {
        workerCounter.set(0);
        ThreadFactory trainingThreadFactory = new ThreadFactory() {

            @Override
            public Thread newThread(@NonNull Runnable r) {
                if (r == null) {
                    throw new NullPointerException("r is marked @NonNull but is null");
                }
                Thread worker = Executors.defaultThreadFactory().newThread(r);
                int threadIndex = workerCounter.getAndIncrement();
                worker.setName("ParallelWrapper training thread " + threadIndex);
                worker.setDaemon(true);
                worker.setUncaughtExceptionHandler(handler);
                int deviceCount = Nd4j.getAffinityManager().getNumberOfDevices();
                Nd4j.getAffinityManager().attachThreadToDevice(worker, Integer.valueOf(threadIndex % deviceCount));
                return worker;
            }
        };
        executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(workers, trainingThreadFactory);
    }

    /**
     * Releases everything created for training.
     *
     * <p>Shuts down every existing trainer and drops the trainer array. Initiates shutdown of the executor
     * (without waiting for running tasks) and drops it. Finally resets the gradients accumulator if one is
     * configured. Because both references are cleared, the next fit() rebuilds trainers and executor.
     */
    @Override
    public void close() throws Exception {
        Trainer[] trainers = this.zoo;
        if (trainers != null) {
            for (Trainer trainer : trainers) {
                if (trainer != null) {
                    trainer.shutdown();
                }
            }
            this.zoo = null;
        }

        ThreadPoolExecutor executor = this.executorService;
        if (executor != null) {
            executor.shutdown();
            this.executorService = null;
        }

        GradientsAccumulator accumulator = this.gradientsAccumulator;
        if (accumulator != null) {
            accumulator.reset();
        }
    }

    /**
     * Synchronized wrapper around {@link #close()}: it runs while holding this wrapper's monitor (the same
     * monitor the fit methods hold), and any failure is rethrown as an unchecked {@link RuntimeException}.
     */
    public synchronized void shutdown() {
        try {
            close();
        } catch (Exception failure) {
            throw new RuntimeException(failure);
        }
    }

    /**
     * Asks the running fit() to stop. The flag is checked before each training round, so the round in
     * progress completes first; fit() then performs its normal cleanup.
     */
    public void stopFit() {
        stopFit.set(true);
    }

    /**
     * Trains the wrapped model on {@code source}, dispatching one MultiDataSet per trainer per round.
     *
     * <p>Each round:
     * <ul>
     *   <li>pushes externally supplied model/updater parameters to all trainers (if any);</li>
     *   <li>feeds up to {@code workers} MultiDataSets, one per trainer, and waits for the fed trainers;</li>
     *   <li>averages parameters (and optionally updater state) when the trainers require averaging and the
     *       round number is a multiple of {@code averagingFrequency}.</li>
     * </ul>
     * Training ends when the source is exhausted or {@link #stopFit()} is called. Afterwards the prefetcher
     * (if any) is shut down and the trainers/executor are closed.
     *
     * @param source training data; reset first if it is exhausted and supports reset
     * @throws ND4JIllegalStateException if the block iterator returns {@code null}
     * @throws IllegalStateException if the trainers were released (e.g. by close()) while training
     * @throws RuntimeException wrapping any exception thrown by {@link #close()}
     */
    public synchronized void fit(@NonNull MultiDataSetIterator source) {
        if (source == null) {
            throw new NullPointerException("source is marked @NonNull but is null");
        }
        this.stopFit.set(false);
        // The end-of-fit warning refers to *this* fit() call, so forget what earlier calls did.
        this.wasAveraged = false;
        this.createZooIfNeccessary(true);
        if (!source.hasNext() && source.resetSupported()) {
            source.reset();
        }
        MultiDataSetIterator iterator = source;
        if (this.prefetchSize > 0 && source.asyncSupported()) {
            if (this.isMQ) {
                if (this.workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0) {
                    log.warn("Number of workers [{}] isn't optimal for available devices [{}]", (Object)this.workers, (Object)Nd4j.getAffinityManager().getNumberOfDevices());
                }
                iterator = new AsyncMultiDataSetIterator(source, this.prefetchSize, new LinkedBlockingQueue<>(this.prefetchSize * this.workers), true, (DataSetCallback)new InterleavedDataSetCallback(this.prefetchSize * 2));
            } else {
                iterator = new AsyncMultiDataSetIterator(source, this.prefetchSize);
            }
        }
        AtomicInteger locker = new AtomicInteger(0);
        // Whether the trainers ever reported that averaging is required. If they never do, averaging is
        // disabled by design and the "never averaged" warning below would be a false alarm.
        boolean averagingApplicable = false;
        DummyBlockMultiDataSetIterator blockWrapper = new DummyBlockMultiDataSetIterator(iterator);
        long time1 = System.currentTimeMillis();
        while (blockWrapper.hasAnything() && !this.stopFit.get()) {
            INDArray params;
            if (this.modelParamsSupplier != null && (params = this.modelParamsSupplier.get()) != null && this.zoo != null) {
                for (Trainer z : this.zoo) {
                    z.updateModelParams(params);
                }
            }
            if (this.updaterParamsSupplier != null && (params = this.updaterParamsSupplier.get()) != null && this.zoo != null) {
                for (Trainer z : this.zoo) {
                    z.updateUpdaterParams(params);
                }
            }
            MultiDataSet[] dataSets = blockWrapper.next(this.workers);
            long time2 = System.currentTimeMillis();
            if (dataSets == null) {
                throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");
            }
            // close() is not synchronized; fail with a clear message (as fit(DataSetIterator) does) instead of an NPE.
            if (this.zoo == null) {
                throw new IllegalStateException("ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
            }
            locker.set(dataSets.length);
            if (this.gradientsAccumulator instanceof Registerable) {
                ((Registerable)this.gradientsAccumulator).registerConsumers(dataSets.length);
            }
            for (int pos = 0; pos < dataSets.length; ++pos) {
                this.zoo[pos].feedMultiDataSet(dataSets[pos], time2 - time1);
            }
            this.iterationsCounter.incrementAndGet();
            for (int pos = 0; pos < dataSets.length; ++pos) {
                this.zoo[pos].waitTillRunning();
            }
            boolean averagingRequired = this.zoo[0].averagingRequired();
            averagingApplicable |= averagingRequired;
            if (averagingRequired && this.iterationsCounter.get() % (long)this.averagingFrequency == 0L) {
                double score = this.getScore(locker);
                this.averageUpdatersState(locker, score);
            }
            locker.set(0);
            time1 = System.currentTimeMillis();
        }
        if (this.debug) {
            log.info("Stopping everyone...");
        }
        if (this.debug) {
            log.info("Shutting down iterator...");
        }
        if (this.prefetchSize > 0 && source.asyncSupported()) {
            ((AsyncMultiDataSetIterator)iterator).shutdown();
        }
        try {
            this.close();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        if (averagingApplicable && !this.wasAveraged) {
            log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
        }
        log.debug("Iterations passed: {}", (Object)this.iterationsCounter.get());
    }

    /**
     * Averages the parameters of the trainers that took part in the current round and returns their mean score.
     *
     * <p>Side effects: sets {@code wasAveraged}, and averages the participating trainers' parameter arrays
     * via {@code Nd4j.averageAndPropagate}.
     *
     * @param locker holds the number of trainers that received data in the current round
     * @return mean score over the first {@code min(workers, locker.get())} trainers
     */
    private double getScore(AtomicInteger locker) {
        this.wasAveraged = true;
        // Only trainers fed this round participate (the last round may feed fewer than `workers`).
        int participants = Math.min(this.workers, locker.get());
        double score = 0.0;
        ArrayList<INDArray> params = new ArrayList<INDArray>();
        for (int cnt = 0; cnt < participants; ++cnt) {
            Model workerModel = this.zoo[cnt].getModel();
            params.add(workerModel.params());
            score += workerModel.score();
        }
        Nd4j.averageAndPropagate(null, params);
        score /= (double) participants;
        if (this.reportScore) {
            log.info("Averaged score: {}", score);
        }
        return score;
    }

    /**
     * Optionally averages updater state across the trainers of the current round, then records
     * {@code score} on the master model.
     *
     * <p>Updater state is averaged only when {@code averageUpdaters} is true and the master model's updater
     * exists and has a state view array. Only MultiLayerNetwork and ComputationGraph models are handled;
     * other model types are left untouched.
     *
     * @param locker holds the number of trainers that received data in the current round
     * @param score  averaged score to store on the master model
     */
    private void averageUpdatersState(AtomicInteger locker, double score) {
        // Same participant bound as the original `cnt < workers && cnt < locker.get()` loop condition.
        int participants = Math.min(this.workers, locker.get());
        if (this.model instanceof MultiLayerNetwork) {
            MultiLayerNetwork network = (MultiLayerNetwork) this.model;
            if (this.averageUpdaters) {
                Updater updater = network.getUpdater();
                if (updater != null && updater.getStateViewArray() != null) {
                    ArrayList<INDArray> updaters = new ArrayList<INDArray>();
                    for (int cnt = 0; cnt < participants; ++cnt) {
                        MultiLayerNetwork workerModel = (MultiLayerNetwork) this.zoo[cnt].getModel();
                        updaters.add(workerModel.getUpdater().getStateViewArray());
                    }
                    Nd4j.averageAndPropagate(null, updaters);
                }
            }
            network.setScore(score);
        } else if (this.model instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) this.model;
            if (this.averageUpdaters) {
                ComputationGraphUpdater updater = graph.getUpdater();
                if (updater != null && updater.getStateViewArray() != null) {
                    ArrayList<INDArray> updaters = new ArrayList<INDArray>();
                    for (int cnt = 0; cnt < participants; ++cnt) {
                        ComputationGraph workerModel = (ComputationGraph) this.zoo[cnt].getModel();
                        updaters.add(workerModel.getUpdater().getStateViewArray());
                    }
                    Nd4j.averageAndPropagate(null, updaters);
                }
            }
            graph.setScore(score);
        }
    }

    /**
     * Registers {@code listeners} without a stats storage router.
     * Note: this also replaces any previously configured router with {@code null}.
     *
     * @throws NullPointerException if {@code listeners} is null
     */
    public void setListeners(@NonNull Collection<TrainingListener> listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        StatsStorageRouter noRouter = null;
        this.setListeners(noRouter, listeners);
    }

    /**
     * Varargs convenience overload: registers the given listeners without a stats storage router.
     *
     * @throws NullPointerException if the listener array itself is null
     */
    public void setListeners(TrainingListener ... listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        List<TrainingListener> listenerList = Arrays.asList(listeners);
        this.setListeners(listenerList);
    }

    /**
     * Registers {@code listeners} and sets {@code statsStorage} as the default router.
     *
     * <p>A {@code null} listener array is handled like a {@code null} collection in
     * {@link #setListeners(StatsStorageRouter, Collection)}: the existing listeners are cleared.
     * Previously, {@code Arrays.asList(null)} threw a NullPointerException here instead.
     */
    public void setListeners(StatsStorageRouter statsStorage, TrainingListener ... listeners) {
        Collection<? extends TrainingListener> listenerCollection = listeners == null ? null : Arrays.asList(listeners);
        this.setListeners(statsStorage, listenerCollection);
    }

    /**
     * Appends {@code listeners} to the registered listeners (or clears them all when {@code listeners} is null)
     * and makes {@code statsStorage} the default router.
     *
     * <p>Logs a warning for each RoutingIterationListener that would end up without a router: neither
     * {@code statsStorage} nor the listener itself provides one.
     */
    public void setListeners(StatsStorageRouter statsStorage, Collection<? extends TrainingListener> listeners) {
        if (listeners == null) {
            this.listeners.clear();
        } else {
            for (TrainingListener candidate : listeners) {
                boolean missingRouter = candidate instanceof RoutingIterationListener
                        && statsStorage == null
                        && ((RoutingIterationListener) candidate).getStorageRouter() == null;
                if (missingRouter) {
                    log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Iterator may not function without one. Listener: {}", (Object) candidate);
                }
            }
            this.listeners.addAll(listeners);
        }
        this.storageRouter = statsStorage;
    }

    /**
     * No-op in this class: the supplied gradients are ignored.
     * NOTE(review): presumably kept as a hook for API compatibility or for subclasses to override - confirm.
     *
     * @param gradients ignored
     */
    public void broadcastGradients(SharedGradient gradients) {
    }

    /**
     * Trains the wrapped model on {@code source}, dispatching one DataSet per trainer per round.
     *
     * <p>Each round:
     * <ul>
     *   <li>pushes externally supplied model/updater parameters to all trainers (if any);</li>
     *   <li>feeds up to {@code workers} DataSets, one per trainer, and waits for the fed trainers;</li>
     *   <li>averages parameters (and optionally updater state) when the round number is a multiple of
     *       {@code averagingFrequency} and the trainers require averaging.</li>
     * </ul>
     * Training ends when the source is exhausted or {@link #stopFit()} is called. Afterwards all trainers are
     * awaited, the prefetcher (if any) is shut down and the trainers/executor are closed.
     *
     * @param source training data; reset first if it is exhausted and supports reset
     * @throws ND4JIllegalStateException if the block iterator returns {@code null}
     * @throws IllegalStateException if the trainers were released (e.g. by close()) while training
     * @throws RuntimeException wrapping failures while waiting for trainers or from {@link #close()}
     */
    public synchronized void fit(@NonNull DataSetIterator source) {
        if (source == null) {
            throw new NullPointerException("source is marked @NonNull but is null");
        }
        // Log the enum itself rather than name(): same text, but no NPE when no workspace mode was configured.
        log.info("Using workspaceMode {} for training", (Object)this.workspaceMode);
        this.stopFit.set(false);
        this.createZooIfNeccessary(false);
        if (!source.hasNext() && source.resetSupported()) {
            source.reset();
        }
        DataSetIterator iterator = source;
        if (this.prefetchSize > 0 && source.asyncSupported()) {
            log.info("Creating asynchronous prefetcher...");
            if (this.isMQ) {
                if (this.workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0) {
                    log.warn("Number of workers [{}] isn't optimal for available devices [{}]", (Object)this.workers, (Object)Nd4j.getAffinityManager().getNumberOfDevices());
                }
                iterator = new AsyncDataSetIterator(source, this.prefetchSize, new LinkedBlockingQueue<>(this.prefetchSize * this.workers), true, (DataSetCallback)new InterleavedDataSetCallback(this.prefetchSize * 2));
            } else {
                iterator = new AsyncDataSetIterator(source, this.prefetchSize);
            }
        }
        AtomicInteger locker = new AtomicInteger(0);
        long time1 = System.currentTimeMillis();
        log.info("Starting ParallelWrapper training round...");
        long intcnt = 0L;
        DummyBlockDataSetIterator blockWrapper = new DummyBlockDataSetIterator(iterator);
        while (blockWrapper.hasAnything() && !this.stopFit.get()) {
            INDArray params;
            if (this.modelParamsSupplier != null && (params = this.modelParamsSupplier.get()) != null && this.zoo != null) {
                log.info("Updating model parameters...");
                for (Trainer z : this.zoo) {
                    z.updateModelParams(params);
                }
            }
            if (this.updaterParamsSupplier != null && (params = this.updaterParamsSupplier.get()) != null && this.zoo != null) {
                log.info("Updating updater parameters...");
                for (Trainer z : this.zoo) {
                    z.updateUpdaterParams(params);
                }
            }
            ++intcnt;
            DataSet[] dataSets = blockWrapper.next(this.workers);
            long time2 = System.currentTimeMillis();
            long lastEtlTime = time2 - time1;
            if (dataSets == null) {
                throw new ND4JIllegalStateException("You can't have NULL as DataSet");
            }
            if (this.zoo == null) {
                throw new IllegalStateException("ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
            }
            locker.set(dataSets.length);
            if (this.gradientsAccumulator instanceof Registerable) {
                ((Registerable)this.gradientsAccumulator).registerConsumers(dataSets.length);
            }
            for (int pos = 0; pos < dataSets.length; ++pos) {
                if (this.debug) {
                    log.info("Feeding dataset {} to worker {}", (Object)intcnt, (Object)pos);
                }
                this.zoo[pos].feedDataSet(dataSets[pos], lastEtlTime);
            }
            this.iterationsCounter.incrementAndGet();
            for (int pos = 0; pos < dataSets.length; ++pos) {
                try {
                    this.zoo[pos].waitTillRunning();
                }
                catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            if (this.iterationsCounter.get() % (long)this.averagingFrequency == 0L && this.zoo[0].averagingRequired()) {
                long timeA1 = System.currentTimeMillis();
                double score = this.getScore(locker);
                this.averageUpdatersState(locker, score);
                long timeA2 = System.currentTimeMillis();
                if (this.reportScore) {
                    log.info("Averaging time: {} ms", (Object)(timeA2 - timeA1));
                }
            }
            time1 = System.currentTimeMillis();
            locker.set(0);
        }
        if (this.debug) {
            log.info("Stopping everyone...");
        }
        // close() is not synchronized and may clear the field concurrently: work on a snapshot. If the trainers
        // are already gone there is nothing left to wait for, but the prefetcher must still be shut down below.
        Trainer[] trainers = this.zoo;
        if (trainers != null) {
            // Wait for every trainer, not only the ones fed in the last round.
            for (Trainer trainer : trainers) {
                try {
                    trainer.waitTillRunning();
                }
                catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        }
        if (this.debug) {
            log.info("Shutting down iterator...");
        }
        if (this.prefetchSize > 0 && source.asyncSupported()) {
            ((AsyncDataSetIterator)iterator).shutdown();
        }
        try {
            this.close();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        if (this.debug) {
            log.info("Iterations passed: {}", (Object)this.iterationsCounter.get());
        }
    }

    /**
     * Lazily builds the trainers. Does nothing if they already exist.
     *
     * <p>Otherwise it initialises the trainer context and creates one trainer per worker. Each trainer is
     * submitted to the executor, which is created via {@link #init()} on first use.
     *
     * @param useMDS forwarded to the trainer context; fit(MultiDataSetIterator) passes {@code true},
     *               fit(DataSetIterator) passes {@code false}
     */
    private void createZooIfNeccessary(boolean useMDS) {
        if (this.zoo != null) {
            return;
        }
        this.trainerContext.init(this.model, this.trainerContextArgs);
        this.zoo = new Trainer[this.workers];
        for (int cnt = 0; cnt < this.workers; ++cnt) {
            // NOTE(review): every trainer receives the *calling* thread's device id - confirm trainers re-pin themselves.
            this.zoo[cnt] = this.trainerContext.create(this.uuid, cnt, this.model, Nd4j.getAffinityManager().getDeviceForCurrentThread(), useMDS, this, this.workspaceMode, this.averagingFrequency);
            if (this.executorService == null) {
                this.init();
            }
            this.executorService.execute(this.zoo[cnt]);
        }
    }

    /**
     * Returns a clone of {@code original} if it is a RoutingIterationListener; any other listener is
     * returned as-is (shared, not copied).
     */
    private static TrainingListener cloneListener(TrainingListener original) {
        return original instanceof RoutingIterationListener
                ? ((RoutingIterationListener) original).clone()
                : original;
    }

    /**
     * Copies every listener of {@code oldListeners} into {@code replicatedListeners} for one worker.
     *
     * <p>RoutingIterationListeners are cloned. Each clone keeps the source's session id, gets
     * {@code workerUUID} as its worker id, and uses the source's own router, falling back to
     * {@code storageRouter} when the source has none. Other listeners are added unchanged.
     */
    private void configureListeners(String workerUUID, Collection<TrainingListener> oldListeners, Collection<TrainingListener> replicatedListeners) {
        for (TrainingListener sourceListener : oldListeners) {
            TrainingListener copy = ParallelWrapper.cloneListener(sourceListener);
            if (!(copy instanceof RoutingIterationListener)) {
                replicatedListeners.add(copy);
                continue;
            }
            // cloneListener only returns a RoutingIterationListener when the source was one.
            RoutingIterationListener routingCopy = (RoutingIterationListener) copy;
            RoutingIterationListener routingSource = (RoutingIterationListener) sourceListener;
            routingCopy.setSessionID(routingSource.getSessionID());
            routingCopy.setWorkerID(workerUUID);
            StatsStorageRouter ownRouter = routingSource.getStorageRouter();
            routingCopy.setStorageRouter(ownRouter != null ? ownRouter : this.storageRouter);
            replicatedListeners.add(copy);
        }
    }

    /** @return supplier polled before each round for model parameters to push to the trainers; may be {@code null} */
    public Supplier<INDArray> getModelParamsSupplier() {
        return modelParamsSupplier;
    }

    /** @return supplier polled before each round for updater state to push to the trainers; may be {@code null} */
    public Supplier<INDArray> getUpdaterParamsSupplier() {
        return updaterParamsSupplier;
    }

    /** @return flag raised by the uncaught-exception handler; when {@code null} nothing is recorded */
    public AtomicBoolean getExceptionEncountered() {
        return exceptionEncountered;
    }

    /** @return the exception recorded by the uncaught-exception handler, or {@code null} */
    public Throwable getException() {
        return exception;
    }

    /** @return the random per-instance identifier of this wrapper */
    public String getUuid() {
        return uuid;
    }

    /** @return the model being trained */
    public Model getModel() {
        return model;
    }

    /** @return the number of trainers / training threads */
    public int getWorkers() {
        return workers;
    }

    /** @return the asynchronous prefetch buffer size */
    public int getPrefetchSize() {
        return prefetchSize;
    }

    /** @return how many rounds pass between averaging steps */
    public int getAveragingFrequency() {
        return averagingFrequency;
    }

    /** @return the live trainer array (not a copy), or {@code null} when no trainers exist */
    public Trainer[] getZoo() {
        return zoo;
    }

    /** @return the factory used to create trainers */
    public TrainerContext getTrainerContext() {
        return trainerContext;
    }

    /** @return the live counter of dispatched training rounds */
    public AtomicLong getIterationsCounter() {
        return iterationsCounter;
    }

    /** @return whether averaged scores are logged */
    public boolean isReportScore() {
        return reportScore;
    }

    /** @return whether updater state is averaged along with parameters */
    public boolean isAverageUpdaters() {
        return averageUpdaters;
    }

    /** @return the legacy-averaging flag */
    public boolean isLegacyAveraging() {
        return legacyAveraging;
    }

    /** @return whether getScore() has run an averaging pass */
    public boolean isWasAveraged() {
        return wasAveraged;
    }

    /** @return the live stop flag checked before each training round */
    public AtomicBoolean getStopFit() {
        return stopFit;
    }

    /** @return the live, mutable list of registered training listeners */
    public List<TrainingListener> getListeners() {
        return listeners;
    }

    /** @return the default stats storage router, or {@code null} */
    public StatsStorageRouter getStorageRouter() {
        return storageRouter;
    }

    /** @return whether the interleaved prefetcher is used */
    public boolean isMQ() {
        return isMQ;
    }

    /** @return the workspace mode handed to trainers */
    public WorkspaceMode getWorkspaceMode() {
        return workspaceMode;
    }

    /** @return extra arguments forwarded to the trainer context (live array, not a copy) */
    public Object[] getTrainerContextArgs() {
        return trainerContextArgs;
    }

    /** @return whether extra progress logging is enabled */
    public boolean isDebug() {
        return debug;
    }

    /** @return the executor running the trainers, or {@code null} when none is active */
    public ThreadPoolExecutor getExecutorService() {
        return executorService;
    }

    /** @return the counter used to number training threads */
    public AtomicInteger getWorkerCounter() {
        return workerCounter;
    }

    /** @return the handler installed on newly created training threads */
    public Thread.UncaughtExceptionHandler getHandler() {
        return handler;
    }

    /** Sets the supplier polled before each training round for model parameters; {@code null} disables it. */
    public void setModelParamsSupplier(Supplier<INDArray> modelParamsSupplier) {
        this.modelParamsSupplier = modelParamsSupplier;
    }

    /** Sets the supplier polled before each training round for updater state; {@code null} disables it. */
    public void setUpdaterParamsSupplier(Supplier<INDArray> updaterParamsSupplier) {
        this.updaterParamsSupplier = updaterParamsSupplier;
    }

    /** Sets the flag the uncaught-exception handler raises; while it is {@code null} no exception is recorded. */
    public void setExceptionEncountered(AtomicBoolean exceptionEncountered) {
        this.exceptionEncountered = exceptionEncountered;
    }

    /** Overwrites the recorded exception. */
    public void setException(Throwable exception) {
        this.exception = exception;
    }

    /** Replaces the model; existing trainers are not rebuilt until the trainer array is released. */
    public void setModel(Model model) {
        this.model = model;
    }

    /**
     * Sets the number of workers.
     * NOTE(review): createZooIfNeccessary() reads this only when building trainers, while fit() reads it every
     * round; changing it while trainers exist may mismatch the two - confirm intended usage.
     */
    public void setWorkers(int workers) {
        this.workers = workers;
    }

    /** Sets the asynchronous prefetch buffer size; values <= 0 disable prefetching. */
    public void setPrefetchSize(int prefetchSize) {
        this.prefetchSize = prefetchSize;
    }

    /** Sets how many rounds pass between averaging steps (used as a modulus, so it must not be 0). */
    public void setAveragingFrequency(int averagingFrequency) {
        this.averagingFrequency = averagingFrequency;
    }

    /** Replaces the trainer array; {@code null} makes the next fit() create fresh trainers. */
    public void setZoo(Trainer[] zoo) {
        this.zoo = zoo;
    }

    /** Sets the factory used to create trainers. */
    public void setTrainerContext(TrainerContext trainerContext) {
        this.trainerContext = trainerContext;
    }

    /** Replaces the round counter object. */
    public void setIterationsCounter(AtomicLong iterationsCounter) {
        this.iterationsCounter = iterationsCounter;
    }

    /** Enables or disables logging of averaged scores. */
    public void setReportScore(boolean reportScore) {
        this.reportScore = reportScore;
    }

    /** Enables or disables averaging of updater state. */
    public void setAverageUpdaters(boolean averageUpdaters) {
        this.averageUpdaters = averageUpdaters;
    }

    /** Sets the legacy-averaging flag. */
    public void setLegacyAveraging(boolean legacyAveraging) {
        this.legacyAveraging = legacyAveraging;
    }

    /** Overwrites the "an averaging pass has run" marker. */
    public void setWasAveraged(boolean wasAveraged) {
        this.wasAveraged = wasAveraged;
    }

    /** Replaces the stop flag object; fit() resets whichever flag is installed to false when it starts. */
    public void setStopFit(AtomicBoolean stopFit) {
        this.stopFit = stopFit;
    }

    /** Sets the default router used for replicated RoutingIterationListeners that have none of their own. */
    public void setStorageRouter(StatsStorageRouter storageRouter) {
        this.storageRouter = storageRouter;
    }

    /** Enables or disables the interleaved prefetcher. */
    public void setMQ(boolean isMQ) {
        this.isMQ = isMQ;
    }

    /** Sets the workspace mode handed to trainers created afterwards. */
    public void setWorkspaceMode(WorkspaceMode workspaceMode) {
        this.workspaceMode = workspaceMode;
    }

    /** Sets the extra arguments forwarded to trainerContext.init(...). */
    public void setTrainerContextArgs(Object[] trainerContextArgs) {
        this.trainerContextArgs = trainerContextArgs;
    }

    /** Enables or disables extra progress logging in fit(). */
    public void setDebug(boolean debug) {
        this.debug = debug;
    }

    /** Replaces the executor that runs the trainers. */
    public void setExecutorService(ThreadPoolExecutor executorService) {
        this.executorService = executorService;
    }

    /** Sets the uncaught-exception handler; only threads created afterwards by init()'s factory pick it up. */
    public void setHandler(Thread.UncaughtExceptionHandler handler) {
        this.handler = handler;
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ParallelWrapper)) {
            return false;
        }
        ParallelWrapper other = (ParallelWrapper)o;
        if (!other.canEqual(this)) {
            return false;
        }
        Supplier<INDArray> this$modelParamsSupplier = this.getModelParamsSupplier();
        Supplier<INDArray> other$modelParamsSupplier = other.getModelParamsSupplier();
        if (this$modelParamsSupplier == null ? other$modelParamsSupplier != null : !this$modelParamsSupplier.equals(other$modelParamsSupplier)) {
            return false;
        }
        Supplier<INDArray> this$updaterParamsSupplier = this.getUpdaterParamsSupplier();
        Supplier<INDArray> other$updaterParamsSupplier = other.getUpdaterParamsSupplier();
        if (this$updaterParamsSupplier == null ? other$updaterParamsSupplier != null : !this$updaterParamsSupplier.equals(other$updaterParamsSupplier)) {
            return false;
        }
        AtomicBoolean this$exceptionEncountered = this.getExceptionEncountered();
        AtomicBoolean other$exceptionEncountered = other.getExceptionEncountered();
        if (this$exceptionEncountered == null ? other$exceptionEncountered != null : !this$exceptionEncountered.equals(other$exceptionEncountered)) {
            return false;
        }
        Throwable this$exception = this.getException();
        Throwable other$exception = other.getException();
        if (this$exception == null ? other$exception != null : !this$exception.equals(other$exception)) {
            return false;
        }
        String this$uuid = this.getUuid();
        String other$uuid = other.getUuid();
        if (this$uuid == null ? other$uuid != null : !this$uuid.equals(other$uuid)) {
            return false;
        }
        Model this$model = this.getModel();
        Model other$model = other.getModel();
        if (this$model == null ? other$model != null : !this$model.equals(other$model)) {
            return false;
        }
        if (this.getWorkers() != other.getWorkers()) {
            return false;
        }
        if (this.getPrefetchSize() != other.getPrefetchSize()) {
            return false;
        }
        if (this.getAveragingFrequency() != other.getAveragingFrequency()) {
            return false;
        }
        if (!Arrays.deepEquals(this.getZoo(), other.getZoo())) {
            return false;
        }
        TrainerContext this$trainerContext = this.getTrainerContext();
        TrainerContext other$trainerContext = other.getTrainerContext();
        if (this$trainerContext == null ? other$trainerContext != null : !this$trainerContext.equals(other$trainerContext)) {
            return false;
        }
        AtomicLong this$iterationsCounter = this.getIterationsCounter();
        AtomicLong other$iterationsCounter = other.getIterationsCounter();
        if (this$iterationsCounter == null ? other$iterationsCounter != null : !this$iterationsCounter.equals(other$iterationsCounter)) {
            return false;
        }
        if (this.isReportScore() != other.isReportScore()) {
            return false;
        }
        if (this.isAverageUpdaters() != other.isAverageUpdaters()) {
            return false;
        }
        if (this.isLegacyAveraging() != other.isLegacyAveraging()) {
            return false;
        }
        if (this.isWasAveraged() != other.isWasAveraged()) {
            return false;
        }
        AtomicBoolean this$stopFit = this.getStopFit();
        AtomicBoolean other$stopFit = other.getStopFit();
        if (this$stopFit == null ? other$stopFit != null : !this$stopFit.equals(other$stopFit)) {
            return false;
        }
        List<TrainingListener> this$listeners = this.getListeners();
        List<TrainingListener> other$listeners = other.getListeners();
        if (this$listeners == null ? other$listeners != null : !((Object)this$listeners).equals(other$listeners)) {
            return false;
        }
        StatsStorageRouter this$storageRouter = this.getStorageRouter();
        StatsStorageRouter other$storageRouter = other.getStorageRouter();
        if (this$storageRouter == null ? other$storageRouter != null : !this$storageRouter.equals(other$storageRouter)) {
            return false;
        }
        if (this.isMQ() != other.isMQ()) {
            return false;
        }
        WorkspaceMode this$workspaceMode = this.getWorkspaceMode();
        WorkspaceMode other$workspaceMode = other.getWorkspaceMode();
        if (this$workspaceMode == null ? other$workspaceMode != null : !this$workspaceMode.equals(other$workspaceMode)) {
            return false;
        }
        if (!Arrays.deepEquals(this.getTrainerContextArgs(), other.getTrainerContextArgs())) {
            return false;
        }
        if (this.isDebug() != other.isDebug()) {
            return false;
        }
        ThreadPoolExecutor this$executorService = this.getExecutorService();
        ThreadPoolExecutor other$executorService = other.getExecutorService();
        if (this$executorService == null ? other$executorService != null : !this$executorService.equals(other$executorService)) {
            return false;
        }
        AtomicInteger this$workerCounter = this.getWorkerCounter();
        AtomicInteger other$workerCounter = other.getWorkerCounter();
        if (this$workerCounter == null ? other$workerCounter != null : !this$workerCounter.equals(other$workerCounter)) {
            return false;
        }
        GradientsAccumulator this$gradientsAccumulator = this.getGradientsAccumulator();
        GradientsAccumulator other$gradientsAccumulator = other.getGradientsAccumulator();
        if (this$gradientsAccumulator == null ? other$gradientsAccumulator != null : !this$gradientsAccumulator.equals(other$gradientsAccumulator)) {
            return false;
        }
        Thread.UncaughtExceptionHandler this$handler = this.getHandler();
        Thread.UncaughtExceptionHandler other$handler = other.getHandler();
        return !(this$handler == null ? other$handler != null : !this$handler.equals(other$handler));
    }

    /**
     * Symmetry hook for {@code equals} (presumably Lombok-generated): reports whether
     * {@code other} belongs to a type this class may compare itself with.
     *
     * @param other candidate object, may be {@code null}
     * @return {@code true} iff {@code other} is a non-null {@code ParallelWrapper}
     */
    protected boolean canEqual(Object other) {
        final boolean comparable = other instanceof ParallelWrapper;
        return comparable;
    }

    /**
     * Computes a hash over the wrapper's fields, folding each one in with the multiplier 59
     * (Lombok-style generated hash). The visible tail of {@code equals} compares the same
     * fields in the same order, so the two are presumably consistent.
     *
     * <p>Conventions used below: a {@code null} reference contributes 43, a boolean contributes
     * 79 ({@code true}) or 97 ({@code false}), an int contributes its value, and arrays
     * ({@code zoo}, {@code trainerContextArgs}) are hashed with {@link Arrays#deepHashCode}.
     *
     * <p>NOTE(review): several contributors look like runtime state (e.g. {@code exception},
     * {@code wasAveraged}, the atomic counters). JDK types such as {@code AtomicBoolean},
     * {@code AtomicLong} and {@code ThreadPoolExecutor} do not override {@code hashCode}, so they
     * contribute identity hashes. If those fields are reassigned, the hash changes; avoid using
     * instances as keys in hash-based collections while training. Confirm against usage.
     *
     * @return the combined hash value
     */
    public int hashCode() {
        // Multiplier constant; the generated code inlines the literal 59 below instead of using it.
        int PRIME = 59;
        int result = 1;
        // Reference-typed fields: 43 stands in for null.
        Supplier<INDArray> $modelParamsSupplier = this.getModelParamsSupplier();
        result = result * 59 + ($modelParamsSupplier == null ? 43 : $modelParamsSupplier.hashCode());
        Supplier<INDArray> $updaterParamsSupplier = this.getUpdaterParamsSupplier();
        result = result * 59 + ($updaterParamsSupplier == null ? 43 : $updaterParamsSupplier.hashCode());
        AtomicBoolean $exceptionEncountered = this.getExceptionEncountered();
        result = result * 59 + ($exceptionEncountered == null ? 43 : $exceptionEncountered.hashCode());
        Throwable $exception = this.getException();
        result = result * 59 + ($exception == null ? 43 : $exception.hashCode());
        String $uuid = this.getUuid();
        result = result * 59 + ($uuid == null ? 43 : $uuid.hashCode());
        Model $model = this.getModel();
        result = result * 59 + ($model == null ? 43 : $model.hashCode());
        // int fields are folded in directly.
        result = result * 59 + this.getWorkers();
        result = result * 59 + this.getPrefetchSize();
        result = result * 59 + this.getAveragingFrequency();
        // Array field: content-based (deep) hash rather than array identity.
        result = result * 59 + Arrays.deepHashCode(this.getZoo());
        TrainerContext $trainerContext = this.getTrainerContext();
        result = result * 59 + ($trainerContext == null ? 43 : $trainerContext.hashCode());
        AtomicLong $iterationsCounter = this.getIterationsCounter();
        result = result * 59 + ($iterationsCounter == null ? 43 : $iterationsCounter.hashCode());
        // boolean fields: 79 for true, 97 for false.
        result = result * 59 + (this.isReportScore() ? 79 : 97);
        result = result * 59 + (this.isAverageUpdaters() ? 79 : 97);
        result = result * 59 + (this.isLegacyAveraging() ? 79 : 97);
        result = result * 59 + (this.isWasAveraged() ? 79 : 97);
        AtomicBoolean $stopFit = this.getStopFit();
        result = result * 59 + ($stopFit == null ? 43 : $stopFit.hashCode());
        List<TrainingListener> $listeners = this.getListeners();
        result = result * 59 + ($listeners == null ? 43 : ((Object)$listeners).hashCode());
        StatsStorageRouter $storageRouter = this.getStorageRouter();
        result = result * 59 + ($storageRouter == null ? 43 : $storageRouter.hashCode());
        result = result * 59 + (this.isMQ() ? 79 : 97);
        WorkspaceMode $workspaceMode = this.getWorkspaceMode();
        result = result * 59 + ($workspaceMode == null ? 43 : $workspaceMode.hashCode());
        result = result * 59 + Arrays.deepHashCode(this.getTrainerContextArgs());
        result = result * 59 + (this.isDebug() ? 79 : 97);
        ThreadPoolExecutor $executorService = this.getExecutorService();
        result = result * 59 + ($executorService == null ? 43 : $executorService.hashCode());
        AtomicInteger $workerCounter = this.getWorkerCounter();
        result = result * 59 + ($workerCounter == null ? 43 : $workerCounter.hashCode());
        GradientsAccumulator $gradientsAccumulator = this.getGradientsAccumulator();
        result = result * 59 + ($gradientsAccumulator == null ? 43 : $gradientsAccumulator.hashCode());
        Thread.UncaughtExceptionHandler $handler = this.getHandler();
        result = result * 59 + ($handler == null ? 43 : $handler.hashCode());
        return result;
    }

    /**
     * Renders every field as {@code ParallelWrapper(name=value, ...)}, using the same field order
     * as {@code hashCode()}. Array fields ({@code zoo}, {@code trainerContextArgs}) are printed
     * with {@link Arrays#deepToString}; {@code null} values print as {@code "null"}.
     *
     * @return a debug representation of this wrapper
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("ParallelWrapper(");
        text.append("modelParamsSupplier=").append(this.getModelParamsSupplier());
        text.append(", updaterParamsSupplier=").append(this.getUpdaterParamsSupplier());
        text.append(", exceptionEncountered=").append(this.getExceptionEncountered());
        text.append(", exception=").append(this.getException());
        text.append(", uuid=").append(this.getUuid());
        text.append(", model=").append(this.getModel());
        text.append(", workers=").append(this.getWorkers());
        text.append(", prefetchSize=").append(this.getPrefetchSize());
        text.append(", averagingFrequency=").append(this.getAveragingFrequency());
        text.append(", zoo=").append(Arrays.deepToString(this.getZoo()));
        text.append(", trainerContext=").append(this.getTrainerContext());
        text.append(", iterationsCounter=").append(this.getIterationsCounter());
        text.append(", reportScore=").append(this.isReportScore());
        text.append(", averageUpdaters=").append(this.isAverageUpdaters());
        text.append(", legacyAveraging=").append(this.isLegacyAveraging());
        text.append(", wasAveraged=").append(this.isWasAveraged());
        text.append(", stopFit=").append(this.getStopFit());
        text.append(", listeners=").append(this.getListeners());
        text.append(", storageRouter=").append(this.getStorageRouter());
        text.append(", isMQ=").append(this.isMQ());
        text.append(", workspaceMode=").append(this.getWorkspaceMode());
        text.append(", trainerContextArgs=").append(Arrays.deepToString(this.getTrainerContextArgs()));
        text.append(", debug=").append(this.isDebug());
        text.append(", executorService=").append(this.getExecutorService());
        text.append(", workerCounter=").append(this.getWorkerCounter());
        text.append(", gradientsAccumulator=").append(this.getGradientsAccumulator());
        text.append(", handler=").append(this.getHandler());
        return text.append(')').toString();
    }

    /**
     * Returns the gradients accumulator shared by this wrapper's workers.
     *
     * @return the current accumulator; may be {@code null} (the builder leaves it unset in
     *     AVERAGING mode)
     */
    public GradientsAccumulator getGradientsAccumulator() {
        final GradientsAccumulator current = gradientsAccumulator;
        return current;
    }

    /**
     * Replaces the gradients accumulator used by this wrapper.
     *
     * @param gradientsAccumulator the new accumulator; stored as-is, {@code null} is accepted
     */
    public void setGradientsAccumulator(GradientsAccumulator gradientsAccumulator) {
        final GradientsAccumulator replacement = gradientsAccumulator;
        this.gradientsAccumulator = replacement;
    }

    /**
     * Fluent builder for {@link ParallelWrapper}.
     *
     * <p>Defaults (from the field initialisers below):
     * <ul>
     *   <li>{@link TrainingMode#AVERAGING};</li>
     *   <li>one worker per device reported by {@code Nd4j.getAffinityManager()};</li>
     *   <li>a prefetch buffer of 16, averaging every iteration, and updater averaging enabled;</li>
     *   <li>score reporting disabled and legacy averaging enabled;</li>
     *   <li>{@code isMQ} enabled only when more than one device is present;</li>
     *   <li>{@link WorkspaceMode#ENABLED}.</li>
     * </ul>
     *
     * @param <T> concrete model type being wrapped
     */
    public static class Builder<T extends Model> {
        protected TrainingMode trainingMode = TrainingMode.AVERAGING;
        protected T model;
        protected int workers = Nd4j.getAffinityManager().getNumberOfDevices();
        protected int prefetchSize = 16;
        protected int averagingFrequency = 1;
        protected boolean reportScore = false;
        protected boolean averageUpdaters = true;
        protected boolean legacyAveraging = true;
        protected boolean isMQ = Nd4j.getAffinityManager().getNumberOfDevices() > 1;
        protected TrainerContext trainerContext = new DefaultTrainerContext();
        protected Object[] trainerContextArgs;
        protected WorkspaceMode workspaceMode = WorkspaceMode.ENABLED;
        protected Supplier<INDArray> modelParamsSupplier;
        protected Supplier<INDArray> updaterParamsSupplier;
        protected ThresholdAlgorithm thresholdAlgorithm;
        protected ResidualPostProcessor residualPostProcessor;
        protected GradientsAccumulator accumulator;

        /**
         * Creates a builder for the given model.
         *
         * @param model model to train in parallel; must not be {@code null}
         * @throws NullPointerException if {@code model} is {@code null}
         */
        public Builder(@NonNull T model) {
            if (model == null) {
                throw new NullPointerException("model is marked @NonNull but is null");
            }
            this.model = model;
        }

        /**
         * Stores extra arguments intended for the trainer context. The array is kept by reference.
         *
         * <p>NOTE(review): {@link #build()} does not currently read this value, so it never reaches
         * the wrapper. Confirm whether it should be forwarded.
         *
         * @param trainerContextArgs arbitrary trainer-context arguments
         * @return this builder
         */
        public Builder<T> trainerContextArgs(Object... trainerContextArgs) {
            this.trainerContextArgs = trainerContextArgs;
            return this;
        }

        /**
         * Stores a trainer context.
         *
         * <p>NOTE(review): {@link #build()} always selects the trainer context from the training
         * mode ({@code DefaultTrainerContext} for AVERAGING, {@code SymmetricTrainerContext}
         * otherwise), so the value set here is not used. Confirm whether that is intended.
         *
         * @param trainerContext trainer context; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code trainerContext} is {@code null}
         */
        public Builder<T> trainerFactory(@NonNull TrainerContext trainerContext) {
            if (trainerContext == null) {
                throw new NullPointerException("trainerContext is marked @NonNull but is null");
            }
            this.trainerContext = trainerContext;
            return this;
        }

        /**
         * Sets the workspace mode copied to the wrapper.
         *
         * @param mode workspace mode; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code mode} is {@code null}
         */
        public Builder<T> workspaceMode(@NonNull WorkspaceMode mode) {
            if (mode == null) {
                throw new NullPointerException("mode is marked @NonNull but is null");
            }
            this.workspaceMode = mode;
            return this;
        }

        /**
         * Sets the model-parameters supplier copied to the wrapper.
         *
         * @param supplier supplier of model parameters; may be {@code null}
         * @return this builder
         */
        public Builder<T> modelParamsSupplier(Supplier<INDArray> supplier) {
            this.modelParamsSupplier = supplier;
            return this;
        }

        /**
         * Sets the updater-parameters supplier copied to the wrapper.
         *
         * @param supplier supplier of updater state; may be {@code null}
         * @return this builder
         */
        public Builder<T> updaterParamsSupplier(Supplier<INDArray> supplier) {
            this.updaterParamsSupplier = supplier;
            return this;
        }

        /**
         * Sets the number of parallel workers.
         *
         * @param num worker count, at least 2
         * @return this builder
         * @throws IllegalArgumentException if {@code num < 2}
         */
        public Builder<T> workers(int num) {
            if (num < 2) {
                // IllegalArgumentException is a RuntimeException, so existing catch sites still work.
                throw new IllegalArgumentException("Number of workers can't be lower than 2! Got: " + num);
            }
            this.workers = num;
            return this;
        }

        /**
         * Sets how often (in iterations) model averaging happens. Negative values are clamped to 0.
         *
         * @param freq averaging frequency
         * @return this builder
         */
        public Builder<T> averagingFrequency(int freq) {
            this.averagingFrequency = Math.max(freq, 0);
            return this;
        }

        /**
         * Enables or disables averaging of updater state along with parameters.
         *
         * @param reallyAverage {@code true} to average updater state
         * @return this builder
         */
        public Builder<T> averageUpdaters(boolean reallyAverage) {
            this.averageUpdaters = reallyAverage;
            return this;
        }

        /**
         * Sets the prefetch buffer size. Negative values are clamped to 0.
         *
         * @param size prefetch buffer size
         * @return this builder
         */
        public Builder<T> prefetchBuffer(int size) {
            this.prefetchSize = Math.max(size, 0);
            return this;
        }

        /**
         * Selects the training mode, which determines the trainer context and accumulator chosen in
         * {@link #build()}.
         *
         * @param mode training mode; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code mode} is {@code null}
         */
        public Builder<T> trainingMode(@NonNull TrainingMode mode) {
            if (mode == null) {
                throw new NullPointerException("mode is marked @NonNull but is null");
            }
            this.trainingMode = mode;
            return this;
        }

        /**
         * Supplies a gradients accumulator.
         *
         * <p>It is required in CUSTOM mode and optional in SHARED_GRADIENTS mode, where one is
         * created automatically when absent. It is ignored in AVERAGING mode.
         *
         * @param accumulator accumulator; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code accumulator} is {@code null}
         */
        public Builder<T> gradientsAccumulator(@NonNull GradientsAccumulator accumulator) {
            if (accumulator == null) {
                throw new NullPointerException("accumulator is marked @NonNull but is null");
            }
            this.accumulator = accumulator;
            return this;
        }

        /**
         * Enables or disables score reporting after averaging.
         *
         * @param reallyReport {@code true} to report the score
         * @return this builder
         */
        public Builder<T> reportScoreAfterAveraging(boolean reallyReport) {
            this.reportScore = reallyReport;
            return this;
        }

        /**
         * Sets the threshold algorithm. It is required in SHARED_GRADIENTS mode and used when the
         * builder creates its own {@code EncodedGradientsAccumulator}.
         *
         * @param thresholdAlgorithm threshold algorithm
         * @return this builder
         */
        public Builder<T> thresholdAlgorithm(ThresholdAlgorithm thresholdAlgorithm) {
            this.thresholdAlgorithm = thresholdAlgorithm;
            return this;
        }

        /**
         * Sets the residual post-processor. It is passed to the {@code EncodedGradientsAccumulator}
         * that is created automatically in SHARED_GRADIENTS mode.
         *
         * @param residualPostProcessor residual post-processor; may be {@code null}
         * @return this builder
         */
        public Builder<T> residualPostProcessor(ResidualPostProcessor residualPostProcessor) {
            this.residualPostProcessor = residualPostProcessor;
            return this;
        }

        /**
         * Creates and initialises a {@link ParallelWrapper} from the current configuration.
         *
         * <p>The training mode determines the trainer context and accumulator:
         * <ul>
         *   <li>AVERAGING uses a {@code DefaultTrainerContext} and no accumulator.</li>
         *   <li>SHARED_GRADIENTS uses a {@code SymmetricTrainerContext} with the configured
         *       accumulator, or a new {@code EncodedGradientsAccumulator}.</li>
         *   <li>CUSTOM uses a {@code SymmetricTrainerContext} and requires a configured
         *       accumulator.</li>
         * </ul>
         * The choice is made without modifying this builder. Calling {@code build()} repeatedly
         * therefore never shares an auto-created accumulator between wrappers.
         *
         * <p>Side effect: if the model is a {@code MultiLayerNetwork} or {@code ComputationGraph},
         * its listeners are copied onto the wrapper (when non-empty) and cleared on the model.
         *
         * <p>SHARED_GRADIENTS without a threshold algorithm fails the
         * {@code Preconditions.checkState} check.
         *
         * @return the initialised wrapper
         * @throws DL4JInvalidConfigException if CUSTOM mode is selected without an accumulator
         * @throws UnsupportedOperationException for an unrecognised training mode
         */
        public ParallelWrapper build() {
            ParallelWrapper wrapper = new ParallelWrapper(this.model, this.workers, this.prefetchSize);
            wrapper.averagingFrequency = this.averagingFrequency;
            wrapper.reportScore = this.reportScore;
            wrapper.averageUpdaters = this.averageUpdaters;
            wrapper.legacyAveraging = this.legacyAveraging;
            wrapper.isMQ = this.isMQ;
            wrapper.workspaceMode = this.workspaceMode;
            wrapper.modelParamsSupplier = this.modelParamsSupplier;
            wrapper.updaterParamsSupplier = this.updaterParamsSupplier;

            // Resolve into locals instead of builder fields so build() has no effect on this builder.
            TrainerContext resolvedContext;
            GradientsAccumulator resolvedAccumulator;
            switch (this.trainingMode) {
                case AVERAGING:
                    resolvedContext = new DefaultTrainerContext();
                    resolvedAccumulator = null;
                    log.info("Creating new AveragingTraining instance");
                    break;
                case SHARED_GRADIENTS:
                    Preconditions.checkState(this.thresholdAlgorithm != null,
                            "Cannot use SHARED_GRADIENTS training mode without setting a threshold algorithm");
                    resolvedContext = new SymmetricTrainerContext();
                    if (this.accumulator != null) {
                        resolvedAccumulator = this.accumulator;
                    } else {
                        log.info("Creating new GradientsAccumulator instance with threshold algorithm ["
                                + this.thresholdAlgorithm + "]");
                        resolvedAccumulator = new EncodedGradientsAccumulator(this.workers,
                                this.thresholdAlgorithm, this.residualPostProcessor, false);
                    }
                    break;
                case CUSTOM:
                    if (this.accumulator == null) {
                        throw new DL4JInvalidConfigException("Please specify GradientsAccumulator for encoded gradients mode");
                    }
                    resolvedContext = new SymmetricTrainerContext();
                    resolvedAccumulator = this.accumulator;
                    break;
                default:
                    throw new UnsupportedOperationException("Unknown trainingMode: [" + this.trainingMode + "]");
            }
            wrapper.trainerContext = resolvedContext;
            wrapper.gradientsAccumulator = resolvedAccumulator;
            wrapper.init();

            // Transfer listeners from the model to the wrapper, then clear them on the model.
            Collection<? extends TrainingListener> existingListeners = null;
            if (this.model instanceof MultiLayerNetwork) {
                existingListeners = ((MultiLayerNetwork) this.model).getListeners();
            } else if (this.model instanceof ComputationGraph) {
                existingListeners = ((ComputationGraph) this.model).getListeners();
            }
            if (existingListeners != null) {
                // Copy before clearing, in case getListeners() returns the model's live collection.
                List<TrainingListener> modelListeners = new ArrayList<>(existingListeners);
                this.model.setListeners(Collections.emptyList());
                if (!modelListeners.isEmpty()) {
                    wrapper.setListeners(modelListeners);
                }
            }
            return wrapper;
        }
    }

    /**
     * Parallel training strategy; {@code Builder.build()} uses it to pick the trainer context
     * and gradients accumulator.
     */
    public enum TrainingMode {
        /** Default trainer context, no gradients accumulator. */
        AVERAGING,
        /** Symmetric trainer context with an encoded gradients accumulator; needs a threshold algorithm. */
        SHARED_GRADIENTS,
        /** Symmetric trainer context with a caller-supplied gradients accumulator. */
        CUSTOM
    }
}

