/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains one model on several threads at once (synchronous data parallelism).
 *
 * <p>Each {@link Trainer} thread owns a replica of the master {@link #model}. {@code fit} hands
 * consecutive minibatches to the trainers, one per worker; after each round the caller waits for
 * every fed trainer to finish, and every {@code averagingFrequency} full rounds the replicas'
 * parameters (and, if enabled, updater state) are averaged into the master model.
 *
 * <p>Create instances with {@link Builder}. Trainer threads are started by the constructor and
 * stopped by {@link #close()} / {@link #shutdown()}; the next {@code fit} call starts fresh ones.
 */
public class ParallelWrapper implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(ParallelWrapper.class);
    /** Master model: replicas are initialised from it and averaged back into it. */
    private Model model;
    /** Number of trainer threads, i.e. minibatches consumed per round. */
    private int workers = 2;
    /** Async prefetch buffer size used to wrap the source iterator; 0 disables wrapping. */
    private int prefetchSize = 2;
    /** Averaging happens every {@code averagingFrequency} full rounds. */
    private int averagingFrequency = 1;
    /** Trainer threads; {@code null} after {@link #close()} until the next {@code fit} call. */
    private Trainer[] zoo;
    /** Rounds completed in the current {@code fit} call; reset when the call ends. */
    private AtomicLong iterationsCounter = new AtomicLong(0L);
    /** Log the mean replica score after every averaging step. */
    private boolean reportScore = false;
    /** Average updater state as well as parameters. */
    private boolean averageUpdaters = true;
    /** With more than one device: average via explicit sum/divide and push results to replicas. */
    private boolean legacyAveraging = false;

    /**
     * Creates the wrapper and immediately starts {@code workers} trainer threads.
     *
     * @param model        master model; expected to be a MultiLayerNetwork or ComputationGraph
     * @param workers      number of trainer threads
     * @param prefetchSize async prefetch buffer size; 0 disables prefetching
     */
    protected ParallelWrapper(Model model, int workers, int prefetchSize) {
        this.model = model;
        this.workers = workers;
        this.prefetchSize = prefetchSize;
        // Results are discarded: these calls are made for their side effect on the master model,
        // presumably materialising its updater before trainer threads read it (TODO confirm).
        if (this.model instanceof MultiLayerNetwork) {
            ((MultiLayerNetwork) this.model).getUpdater();
        } else if (this.model instanceof ComputationGraph) {
            ((ComputationGraph) this.model).getUpdater();
        }
        prepareTrainers(false);
    }

    /**
     * Signals every trainer thread to stop and forgets them. Threads are not joined; they are
     * daemon threads and exit after their current poll (at most ~100 ms) notices the stop flag.
     * Synchronized so it cannot pull the trainers out from under a running {@code fit}.
     */
    @Override
    public synchronized void close() throws Exception {
        if (this.zoo != null) {
            for (Trainer trainer : this.zoo) {
                if (trainer != null) {
                    trainer.shutdown();
                }
            }
            this.zoo = null;
        }
    }

    /** Same as {@link #close()}, with any exception rethrown unchecked. */
    public synchronized void shutdown() {
        try {
            this.close();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Trains the master ComputationGraph on every MultiDataSet of {@code source}.
     *
     * @param source data to train on; reset before use
     * @throws RuntimeException if the master model is not a ComputationGraph, or if a trainer
     *                          thread failed (the failure is attached as the cause)
     */
    public synchronized void fit(@NonNull MultiDataSetIterator source) {
        if (source == null) {
            throw new NullPointerException("source");
        }
        // Fail fast instead of training a full round before discovering the model type is wrong.
        if (!(this.model instanceof ComputationGraph)) {
            throw new RuntimeException("MultiDataSet might be used only with ComputationGraph model");
        }
        prepareTrainers(true);
        source.reset();
        MultiDataSetIterator iterator = this.prefetchSize > 0 && source.asyncSupported()
                ? new AsyncMultiDataSetIterator(source, this.prefetchSize)
                : source;
        try {
            int fed = 0;
            while (iterator.hasNext()) {
                MultiDataSet dataSet = (MultiDataSet) iterator.next();
                this.zoo[fed++].feedMultiDataSet(dataSet);
                // Keep feeding until every worker has a batch or the data runs out.
                if (fed < this.workers && iterator.hasNext()) {
                    continue;
                }
                completeRound(fed);
                fed = 0;
            }
            logger.debug("Iterations passed: {}", this.iterationsCounter.get());
        } finally {
            this.iterationsCounter.set(0L);
        }
    }

    /**
     * Trains the master model (MultiLayerNetwork or ComputationGraph) on every DataSet of
     * {@code source}.
     *
     * @param source data to train on; reset before use
     * @throws RuntimeException if a trainer thread failed (the failure is attached as the cause)
     */
    public synchronized void fit(@NonNull DataSetIterator source) {
        if (source == null) {
            throw new NullPointerException("source");
        }
        prepareTrainers(false);
        source.reset();
        DataSetIterator iterator = this.prefetchSize > 0 && source.asyncSupported()
                ? new AsyncDataSetIterator(source, this.prefetchSize)
                : source;
        try {
            int fed = 0;
            while (iterator.hasNext()) {
                DataSet dataSet = (DataSet) iterator.next();
                this.zoo[fed++].feedDataSet(dataSet);
                // Keep feeding until every worker has a batch or the data runs out.
                if (fed < this.workers && iterator.hasNext()) {
                    continue;
                }
                completeRound(fed);
                fed = 0;
            }
            logger.debug("Iterations passed: {}", this.iterationsCounter.get());
        } finally {
            this.iterationsCounter.set(0L);
        }
    }

    /**
     * Starts trainer threads if there are none, otherwise switches the existing ones to the
     * requested input kind. Switching is safe because every {@code fit} waits until all fed
     * batches are consumed before it returns.
     */
    private void prepareTrainers(boolean useMDS) {
        if (this.zoo == null) {
            this.zoo = new Trainer[this.workers];
            for (int i = 0; i < this.workers; ++i) {
                this.zoo[i] = new Trainer(i, this.model, useMDS);
                this.zoo[i].start();
            }
        } else {
            for (Trainer trainer : this.zoo) {
                trainer.useMDS = useMDS;
            }
        }
    }

    /**
     * Ends one round: waits for the {@code fed} trainers that received a batch, then averages
     * when the round was full and the round number hits the averaging frequency. A trailing
     * partial round is therefore never averaged.
     */
    private void completeRound(int fed) {
        long round = this.iterationsCounter.incrementAndGet();
        for (int i = 0; i < fed; ++i) {
            this.zoo[i].waitTillRunning();
        }
        if (round % this.averagingFrequency == 0L && fed == this.workers) {
            averageReplicas(fed);
        }
    }

    /**
     * Averages the parameters (and optionally updater state) of the first {@code count}
     * replicas into the master model and records their mean score on it.
     *
     * <p>Non-legacy mode relies on {@code Nd4j.averageAndPropagate}; judging by its name it also
     * writes the average back into each replica array (verify against ND4J docs), since replicas
     * are not refreshed explicitly here. Legacy mode pushes the master state to every trainer.
     */
    private void averageReplicas(int count) {
        boolean legacy = this.legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1;
        double score = 0.0;
        if (legacy) {
            INDArray sum = Nd4j.zeros(this.model.params().shape());
            for (int i = 0; i < count; ++i) {
                Model replica = this.zoo[i].getModel();
                sum.addi(replica.params());
                score += replica.score();
            }
            sum.divi(count);
            this.model.setParams(sum);
        } else {
            Collection<INDArray> replicaParams = new ArrayList<>(count);
            for (int i = 0; i < count; ++i) {
                Model replica = this.zoo[i].getModel();
                replicaParams.add(replica.params());
                score += replica.score();
            }
            Nd4j.averageAndPropagate(this.model.params(), replicaParams);
        }
        score /= count;
        if (this.reportScore) {
            logger.info("Averaged score: {}", score);
        }
        if (this.model instanceof MultiLayerNetwork) {
            averageNetworkUpdaters(count, legacy);
            ((MultiLayerNetwork) this.model).setScore(score);
        } else if (this.model instanceof ComputationGraph) {
            averageGraphUpdaters(count, legacy);
            ((ComputationGraph) this.model).setScore(score);
        }
        if (legacy) {
            for (int i = 0; i < this.workers; ++i) {
                this.zoo[i].updateModel(this.model);
            }
        }
    }

    /** Averages MultiLayerNetwork updater state of the first {@code count} replicas, if enabled. */
    private void averageNetworkUpdaters(int count, boolean legacy) {
        if (!this.averageUpdaters) {
            return;
        }
        MultiLayerNetwork network = (MultiLayerNetwork) this.model;
        Updater updater = network.getUpdater();
        if (updater == null || updater.getStateViewArray() == null) {
            return;
        }
        if (legacy) {
            INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
            for (int i = 0; i < count; ++i) {
                state.addi(((MultiLayerNetwork) this.zoo[i].getModel()).getUpdater().getStateViewArray().dup());
            }
            state.divi(count);
            updater.setStateViewArray((Layer) network, state, false);
        } else {
            Collection<INDArray> states = new ArrayList<>(count);
            for (int i = 0; i < count; ++i) {
                states.add(((MultiLayerNetwork) this.zoo[i].getModel()).getUpdater().getStateViewArray());
            }
            Nd4j.averageAndPropagate(updater.getStateViewArray(), states);
        }
    }

    /** Averages ComputationGraph updater state of the first {@code count} replicas, if enabled. */
    private void averageGraphUpdaters(int count, boolean legacy) {
        if (!this.averageUpdaters) {
            return;
        }
        ComputationGraphUpdater updater = ((ComputationGraph) this.model).getUpdater();
        if (updater == null || updater.getStateViewArray() == null) {
            return;
        }
        if (legacy) {
            INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
            for (int i = 0; i < count; ++i) {
                state.addi(((ComputationGraph) this.zoo[i].getModel()).getUpdater().getStateViewArray());
            }
            state.divi(count);
            updater.setStateViewArray(state);
        } else {
            Collection<INDArray> states = new ArrayList<>(count);
            for (int i = 0; i < count; ++i) {
                states.add(((ComputationGraph) this.zoo[i].getModel()).getUpdater().getStateViewArray());
            }
            Nd4j.averageAndPropagate(updater.getStateViewArray(), states);
        }
    }

    /**
     * Daemon thread that owns one model replica and fits it on batches handed over through its
     * queues. {@code running} counts batches fed but not yet fully processed.
     */
    private static class Trainer extends Thread {
        private final Model originalModel;
        /** Created on this thread in {@link #run()}; read by the wrapper thread. */
        private volatile Model replicatedModel;
        private final LinkedBlockingQueue<DataSet> queue = new LinkedBlockingQueue<>();
        private final LinkedBlockingQueue<MultiDataSet> queueMDS = new LinkedBlockingQueue<>();
        private final AtomicInteger running = new AtomicInteger(0);
        private final int threadId;
        private final AtomicBoolean shouldStop = new AtomicBoolean(false);
        /** Failure that ended {@link #run()}; volatile so the waiting wrapper thread sees it. */
        private volatile Throwable thrownException;
        /** Which queue to consume; re-read on every loop pass so it can change between fits. */
        private volatile boolean useMDS = false;

        public Trainer(int threadId, Model model, boolean useMDS) {
            this(threadId, model);
            this.useMDS = useMDS;
        }

        public Trainer(int threadId, Model model) {
            this.threadId = threadId;
            this.originalModel = model;
            this.setDaemon(true);
            this.setName("ParallelWrapper trainer " + threadId);
        }

        /** Queues a MultiDataSet for this trainer and counts it as pending. */
        public void feedMultiDataSet(@NonNull MultiDataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet");
            }
            this.running.incrementAndGet();
            this.queueMDS.add(dataSet);
        }

        /** Queues a DataSet for this trainer and counts it as pending. */
        public void feedDataSet(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet");
            }
            this.running.incrementAndGet();
            this.queue.add(dataSet);
        }

        /** Returns this trainer's replica, or {@code null} before {@link #run()} created it. */
        public Model getModel() {
            return this.replicatedModel;
        }

        /**
         * Overwrites the replica's parameters and updater state with copies of {@code model}'s.
         * Does nothing to the model if no replica exists yet.
         */
        public void updateModel(@NonNull Model model) {
            if (model == null) {
                throw new NullPointerException("model");
            }
            Model replica = this.replicatedModel;
            if (replica instanceof MultiLayerNetwork) {
                MultiLayerNetwork network = (MultiLayerNetwork) replica;
                network.setParams(model.params().dup());
                Updater source = ((MultiLayerNetwork) model).getUpdater();
                INDArray view = source == null ? null : source.getStateViewArray();
                if (view != null) {
                    Updater target = network.getUpdater();
                    INDArray viewCopy = view.dup();
                    flushExecutioner();
                    target.setStateViewArray((Layer) network, viewCopy, false);
                }
            } else if (replica instanceof ComputationGraph) {
                ComputationGraph graph = (ComputationGraph) replica;
                graph.setParams(model.params().dup());
                ComputationGraphUpdater source = ((ComputationGraph) model).getUpdater();
                INDArray view = source == null ? null : source.getStateViewArray();
                if (view != null) {
                    INDArray viewCopy = view.dup();
                    flushExecutioner();
                    graph.getUpdater().setStateViewArray(viewCopy);
                }
            }
            flushExecutioner();
        }

        /** Asks the loop in {@link #run()} to exit after its current poll. */
        public void shutdown() {
            this.shouldStop.set(true);
        }

        /**
         * Builds the replica on this thread, seeds it with the master model's state, then fits
         * queued batches until shut down. Any failure is recorded for {@link #waitTillRunning()}.
         */
        @Override
        public void run() {
            try {
                // Built here rather than in the constructor so the replica's arrays are created on
                // this worker thread.
                this.replicatedModel = createReplica();
                // Start from the master's (possibly pre-trained) parameters and updater state
                // instead of a fresh random init. NOTE(review): this snapshot is taken when the
                // thread starts; master changes made after build() but before fit() are not seen.
                updateModel(this.originalModel);
                while (!this.shouldStop.get()) {
                    Model replica = this.replicatedModel;
                    if (this.useMDS) {
                        MultiDataSet multiDataSet = this.queueMDS.poll(100L, TimeUnit.MILLISECONDS);
                        if (multiDataSet == null) {
                            continue;
                        }
                        if (!(replica instanceof ComputationGraph)) {
                            throw new RuntimeException("MultiDataSet can be fit into ComputationGraph only");
                        }
                        ((ComputationGraph) replica).fit(multiDataSet);
                    } else {
                        DataSet dataSet = this.queue.poll(100L, TimeUnit.MILLISECONDS);
                        if (dataSet == null) {
                            continue;
                        }
                        if (replica instanceof MultiLayerNetwork) {
                            ((MultiLayerNetwork) replica).fit(dataSet);
                        } else if (replica instanceof ComputationGraph) {
                            ((ComputationGraph) replica).fit(dataSet);
                        }
                    }
                    flushExecutioner();
                    this.running.decrementAndGet();
                }
            } catch (Throwable t) {
                // Catch Errors too; otherwise the waiting wrapper thread would spin forever.
                this.thrownException = t;
            }
        }

        /** Creates an untrained model with a cloned copy of the original's configuration. */
        private Model createReplica() {
            if (this.originalModel instanceof MultiLayerNetwork) {
                MultiLayerConfiguration conf = ((MultiLayerNetwork) this.originalModel).getLayerWiseConfigurations().clone();
                MultiLayerNetwork network = new MultiLayerNetwork(conf);
                network.init();
                return network;
            }
            if (this.originalModel instanceof ComputationGraph) {
                ComputationGraph graph = new ComputationGraph(((ComputationGraph) this.originalModel).getConfiguration().clone());
                graph.init();
                return graph;
            }
            return null;
        }

        /** Flushes the ND4J op queue when a GridExecutioner is in use. */
        private static void flushExecutioner() {
            if (Nd4j.getExecutioner() instanceof GridExecutioner) {
                ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
            }
        }

        /**
         * Blocks until every batch fed to this trainer has been processed.
         *
         * @throws RuntimeException wrapping the trainer's failure, or if the calling thread is
         *                          interrupted (its interrupt flag is restored)
         */
        public void waitTillRunning() {
            while (this.running.get() != 0) {
                Throwable failure = this.thrownException;
                if (failure != null) {
                    throw new RuntimeException(failure);
                }
                try {
                    Thread.sleep(10L);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("Interrupted while waiting for trainer " + this.threadId, e);
                }
            }
        }
    }

    /** Fluent configuration for {@link ParallelWrapper}. */
    public static class Builder {
        private Model model;
        private int workers = 2;
        private int prefetchSize = 16;
        private int averagingFrequency = 1;
        private boolean reportScore = false;
        private boolean averageUpdaters = true;
        private boolean legacyAveraging = true;

        /** Wraps a MultiLayerNetwork. */
        public Builder(@NonNull MultiLayerNetwork mln) {
            if (mln == null) {
                throw new NullPointerException("mln");
            }
            this.model = mln;
        }

        /** Wraps a ComputationGraph. */
        public Builder(@NonNull ComputationGraph graph) {
            if (graph == null) {
                throw new NullPointerException("graph");
            }
            this.model = graph;
        }

        /** Number of trainer threads; must be at least 2. */
        public Builder workers(int num) {
            if (num < 2) {
                throw new RuntimeException("Number of workers can't be lower then 2!");
            }
            this.workers = num;
            return this;
        }

        /**
         * Average every {@code freq} full rounds.
         *
         * @throws IllegalArgumentException if {@code freq < 1} (0 would fail later with a
         *                                  division by zero inside {@code fit})
         */
        public Builder averagingFrequency(int freq) {
            if (freq < 1) {
                throw new IllegalArgumentException("Averaging frequency must be at least 1, got " + freq);
            }
            this.averagingFrequency = freq;
            return this;
        }

        /** Whether updater state is averaged along with parameters. */
        public Builder averageUpdaters(boolean reallyAverage) {
            this.averageUpdaters = reallyAverage;
            return this;
        }

        /** Async prefetch buffer size; negative values are treated as 0 (no prefetching). */
        public Builder prefetchBuffer(int size) {
            if (size < 0) {
                size = 0;
            }
            this.prefetchSize = size;
            return this;
        }

        /** Whether to use the sum/divide averaging path on multi-device setups. */
        public Builder useLegacyAveraging(boolean reallyUse) {
            this.legacyAveraging = reallyUse;
            return this;
        }

        /** Whether to log the mean replica score after every averaging step. */
        public Builder reportScoreAfterAveraging(boolean reallyReport) {
            this.reportScore = reallyReport;
            return this;
        }

        /** Builds the wrapper; its trainer threads start immediately. */
        public ParallelWrapper build() {
            ParallelWrapper wrapper = new ParallelWrapper(this.model, this.workers, this.prefetchSize);
            wrapper.averagingFrequency = this.averagingFrequency;
            wrapper.reportScore = this.reportScore;
            wrapper.averageUpdaters = this.averageUpdaters;
            wrapper.legacyAveraging = this.legacyAveraging;
            return wrapper;
        }
    }
}

