/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ParallelWrapper {
    private static Logger logger = LoggerFactory.getLogger(ParallelWrapper.class);
    private Model model;
    private int workers = 2;
    private int prefetchSize = 2;
    private int averagingFrequency = 1;
    private Trainer[] zoo;
    private AtomicLong iterationsCounter = new AtomicLong(0L);

    protected ParallelWrapper(Model model, int workers, int prefetchSize) {
        this.model = model;
        this.workers = workers;
        this.prefetchSize = prefetchSize;
        this.zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; ++cnt) {
            this.zoo[cnt] = new Trainer(cnt, model);
            this.zoo[cnt].start();
        }
    }

    public synchronized void fit(@NonNull DataSetIterator source) {
        if (source == null) {
            throw new NullPointerException("source");
        }
        DataSetIterator iterator = this.prefetchSize > 0 && !(source instanceof AsyncDataSetIterator) && !(source instanceof ListDataSetIterator) ? new AsyncDataSetIterator(source, this.prefetchSize) : source;
        AtomicInteger locker = new AtomicInteger(0);
        iterator.reset();
        while (iterator.hasNext()) {
            DataSet dataSet = (DataSet)iterator.next();
            int pos = locker.getAndIncrement();
            this.zoo[pos].feedDataSet(dataSet);
            if (pos + 1 != this.workers && iterator.hasNext()) continue;
            this.iterationsCounter.incrementAndGet();
            for (int cnt = 0; cnt < this.workers && cnt < locker.get(); ++cnt) {
                try {
                    this.zoo[cnt].waitTillRunning();
                    continue;
                }
                catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            if (this.iterationsCounter.get() % (long)this.averagingFrequency == 0L || !iterator.hasNext()) {
                double score = 0.0;
                INDArray result = Nd4j.zeros((int[])this.model.params().shape());
                for (int cnt = 0; cnt < this.workers && cnt < locker.get(); ++cnt) {
                    INDArray params = this.zoo[cnt].getModel().params();
                    result.addi(params);
                    score += this.zoo[cnt].getModel().score();
                }
                result.divi((Number)Math.min(this.workers, locker.get()));
                this.model.setParams(result);
                logger.info("Averaged score: " + (score /= (double)Math.min(this.workers, locker.get())));
                if (this.model instanceof MultiLayerNetwork) {
                    UpdaterAggregator uag = ((MultiLayerNetwork)this.zoo[0].getModel()).getUpdater().getAggregator(false);
                    for (int cnt = 0; cnt < this.workers; ++cnt) {
                        uag.merge(((MultiLayerNetwork)this.zoo[cnt].getModel()).getUpdater().getAggregator(true));
                    }
                    ((MultiLayerNetwork)this.model).setScore(score);
                    ((MultiLayerNetwork)this.model).setUpdater(uag.getUpdater());
                } else if (this.model instanceof ComputationGraph) {
                    ComputationGraphUpdater.Aggregator uag = ((ComputationGraph)this.zoo[0].getModel()).getUpdater().getAggregator(false);
                    for (int cnt = 0; cnt < this.workers; ++cnt) {
                        uag.merge(((ComputationGraph)this.zoo[cnt].getModel()).getUpdater().getAggregator(true));
                    }
                    ((ComputationGraph)this.model).setScore(score);
                    ((ComputationGraph)this.model).setUpdater(uag.getUpdater());
                }
                for (int i = 0; i < this.workers; ++i) {
                    this.zoo[i].updateModel(this.model);
                }
            }
            locker.set(0);
        }
    }

    private static class Trainer
    extends Thread
    implements Runnable {
        private Model originalModel;
        private Model replicatedModel;
        private LinkedBlockingQueue<DataSet> queue = new LinkedBlockingQueue();
        private AtomicInteger running = new AtomicInteger(0);
        private int threadId;

        public Trainer(int threadId, Model model) {
            this.threadId = threadId;
            this.setDaemon(true);
            this.originalModel = model;
            if (model instanceof MultiLayerNetwork) {
                this.replicatedModel = ((MultiLayerNetwork)model).clone();
                if (threadId != 0) {
                    ((MultiLayerNetwork)this.replicatedModel).setListeners(new ArrayList<IterationListener>());
                }
            } else if (model instanceof ComputationGraph) {
                this.replicatedModel = ((ComputationGraph)model).clone();
                if (threadId != 0) {
                    ((ComputationGraph)this.replicatedModel).setListeners(new ArrayList<IterationListener>());
                }
            }
        }

        public void feedDataSet(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet");
            }
            this.running.incrementAndGet();
            this.queue.add(dataSet);
        }

        public Model getModel() {
            return this.replicatedModel;
        }

        public void updateModel(@NonNull Model model) {
            if (model == null) {
                throw new NullPointerException("model");
            }
            if (model instanceof MultiLayerNetwork) {
                this.replicatedModel = ((MultiLayerNetwork)model).clone();
            } else if (model instanceof ComputationGraph) {
                this.replicatedModel = ((ComputationGraph)model).clone();
            }
        }

        public boolean isRunning() {
            return this.running.get() == 0;
        }

        @Override
        public void run() {
            try {
                while (true) {
                    DataSet dataSet;
                    if ((dataSet = this.queue.poll(1L, TimeUnit.SECONDS)) == null) {
                        continue;
                    }
                    if (this.replicatedModel instanceof MultiLayerNetwork) {
                        ((MultiLayerNetwork)this.replicatedModel).fit(dataSet);
                    } else if (this.replicatedModel instanceof ComputationGraph) {
                        ((ComputationGraph)this.replicatedModel).fit(dataSet);
                    }
                    this.running.decrementAndGet();
                }
            }
            catch (Exception exception) {
                return;
            }
        }

        public void waitTillRunning() {
            while (this.running.get() != 0) {
                try {
                    Thread.sleep(10L);
                }
                catch (Exception exception) {}
            }
        }
    }

    public static class Builder {
        private Model model;
        private int workers = 2;
        private int prefetchSize = 2;
        private int averagingFrequency = 1;

        public Builder(@NonNull MultiLayerNetwork mln) {
            if (mln == null) {
                throw new NullPointerException("mln");
            }
            this.model = mln;
        }

        public Builder(@NonNull ComputationGraph graph) {
            if (graph == null) {
                throw new NullPointerException("graph");
            }
            this.model = graph;
        }

        public Builder workers(int num) {
            if (num < 1) {
                throw new RuntimeException("Number of workers can't be lower then 1!");
            }
            this.workers = num;
            return this;
        }

        public Builder averagingFrequency(int freq) {
            this.averagingFrequency = freq;
            return this;
        }

        public Builder prefetchBuffer(int size) {
            if (size < 0) {
                size = 0;
            }
            this.prefetchSize = size;
            return this;
        }

        public ParallelWrapper build() {
            ParallelWrapper wrapper = new ParallelWrapper(this.model, this.workers, this.prefetchSize);
            wrapper.averagingFrequency = this.averagingFrequency;
            return wrapper;
        }
    }
}

