/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.paragraphvectors;

import akka.actor.ActorSystem;
import com.google.common.base.Function;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.Nullable;
import org.deeplearning4j.bagofwords.vectorizer.TextVectorizer;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.parallel.Parallelization;
import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.UimaTokenizerFactory;

public class ParagraphVectors
extends Word2Vec {
    /**
     * Queue of training batches, bounded to 10000 entries. Each entry is a list of
     * (document words, document labels) pairs; {@code fit()} adds a snapshot of the
     * documents it collected to this queue.
     */
    protected Queue<LinkedList<Pair<List<VocabWord>, Collection<VocabWord>>>> jobQueue = new LinkedBlockingDeque<LinkedList<Pair<List<VocabWord>, Collection<VocabWord>>>>(10000);

    /**
     * Builds the vocabulary if necessary, collects every indexed document together with
     * its labels, and then runs {@code numIterations} training passes over the result.
     *
     * <p>Documents are visited through {@code vectorizer.index().eachDocWithLabels(...)}
     * with a thread pool sized to the number of available processors. Each document that
     * still contains words after {@code addWords} filtering becomes one
     * (words, labels) pair handed to {@code doIteration}.
     *
     * @throws IOException declared by the overridden method
     * @throws IllegalStateException if the index reports no documents
     */
    @Override
    public void fit() throws IOException {
        boolean loaded = this.buildVocab();
        if (!loaded && this.saveVocab) {
            this.vocab().saveVocab();
        }
        if (this.stopWords == null) {
            this.readStopWords();
        }
        log.info("Training word2vec multithreaded");
        if (this.sentenceIter != null) {
            this.sentenceIter.reset();
        }
        if (this.docIter != null) {
            this.docIter.reset();
        }
        this.totalWords = this.vectorizer.numWordsEncountered();
        this.totalWords *= (long)this.numIterations;
        log.info("Processing sentences...");
        final AtomicLong numWordsSoFar = new AtomicLong(0L);
        final AtomicLong nextRandom = new AtomicLong(5L);
        final AtomicInteger doc = new AtomicInteger(0);
        int poolSize = Runtime.getRuntime().availableProcessors();
        ThreadPoolExecutor exec = new ThreadPoolExecutor(poolSize, poolSize, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>(), new RejectedExecutionHandler(){

            @Override
            public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
                // A shut-down executor rejects every submission, so retrying would recurse
                // forever (one second per level); fail fast instead.
                if (executor.isShutdown()) {
                    throw new java.util.concurrent.RejectedExecutionException("Executor already shut down; cannot schedule " + r);
                }
                try {
                    Thread.sleep(1000L);
                }
                catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                executor.submit(r);
            }
        });
        final ConcurrentLinkedDeque<Pair<List<VocabWord>, Collection<VocabWord>>> batch2 = new ConcurrentLinkedDeque<Pair<List<VocabWord>, Collection<VocabWord>>>();
        int[] docs = this.vectorizer.index().allDocs();
        if (docs.length < 1) {
            throw new IllegalStateException("No documents found");
        }
        this.vectorizer.index().eachDocWithLabels(new Function<Pair<List<VocabWord>, Collection<String>>, Void>(){

            @Override
            public Void apply(@Nullable Pair<List<VocabWord>, Collection<String>> input) {
                // The parameter is declared @Nullable, so do not dereference it blindly.
                if (input == null || input.getFirst() == null) {
                    return null;
                }
                List<VocabWord> batch = new ArrayList<VocabWord>();
                ParagraphVectors.this.addWords(input.getFirst(), nextRandom, batch);
                if (batch.isEmpty()) {
                    return null;
                }
                // Kept as a List: dbow() indexes the labels positionally.
                List<VocabWord> docLabels = new ArrayList<VocabWord>();
                if (input.getSecond() != null) {
                    for (String s : input.getSecond()) {
                        docLabels.add(ParagraphVectors.this.vocab().wordFor(s));
                    }
                }
                batch2.add(new Pair<List<VocabWord>, Collection<VocabWord>>(batch, docLabels));
                // Use the value returned by the increment so concurrent callers neither
                // skip nor repeat a progress line.
                int done = doc.incrementAndGet();
                if (done % 10000 == 0) {
                    log.info("Doc " + done + " done so far");
                }
                return null;
            }
        }, exec);
        exec.shutdown();
        try {
            exec.awaitTermination(1L, TimeUnit.DAYS);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        // Snapshot only after the pool has drained; taking it earlier could miss
        // documents that worker threads were still processing.
        if (!batch2.isEmpty()) {
            this.jobQueue.add(new LinkedList<Pair<List<VocabWord>, Collection<VocabWord>>>(batch2));
        }
        for (int i = 0; i < this.numIterations; ++i) {
            this.doIteration(batch2, numWordsSoFar, nextRandom);
        }
    }

    /**
     * Trains on one labelled sentence by calling {@code dbow} once per word position,
     * each time with a freshly drawn window offset.
     *
     * @param sentenceWithLabel pair of (sentence words, label words); a {@code null} pair
     *                          or a {@code null}/empty sentence is ignored
     * @param nextRandom        state of the linear congruential generator; advanced once
     *                          per word position
     * @param alpha             learning rate forwarded to {@code dbow}
     * @throws ArithmeticException if {@code window} is zero (integer remainder by zero)
     */
    public void trainSentence(Pair<List<VocabWord>, Collection<VocabWord>> sentenceWithLabel, AtomicLong nextRandom, double alpha) {
        if (sentenceWithLabel == null) {
            return;
        }
        List<VocabWord> sentence = sentenceWithLabel.getFirst();
        if (sentence == null || sentence.isEmpty()) {
            return;
        }
        for (int i = 0; i < sentence.size(); ++i) {
            // Advance the generator with compare-and-set so a concurrent caller sharing
            // nextRandom cannot overwrite this step with a stale value.
            long previous;
            long next;
            do {
                previous = nextRandom.get();
                next = previous * 25214903917L + 11L;
            } while (!nextRandom.compareAndSet(previous, next));
            // Java's % keeps the sign of the dividend; a negative offset would widen the
            // context window beyond the configured size, so fold it into [0, window).
            int b = Math.abs((int)next % this.window);
            this.dbow(i, sentenceWithLabel, b, nextRandom, alpha);
        }
    }

    /**
     * Runs one training step for the word at position {@code i}: for each offset
     * {@code a} in {@code [b, 2 * window + 1 - b)} other than {@code window}, the label at
     * index {@code i - window + a} (when that index exists) is passed to {@code iterate}
     * together with the word.
     *
     * @param i                 index of the current word within the sentence
     * @param sentenceWithLabel pair of (sentence words, label words)
     * @param b                 window offset; larger values shrink the scanned range
     * @param nextRandom        random state forwarded to {@code iterate}
     * @param alpha             learning rate forwarded to {@code iterate}
     * @throws IndexOutOfBoundsException if {@code i} is outside a non-empty sentence
     */
    public void dbow(int i, Pair<List<VocabWord>, Collection<VocabWord>> sentenceWithLabel, int b, AtomicLong nextRandom, double alpha) {
        List<VocabWord> sentence = sentenceWithLabel.getFirst();
        // Check for an empty sentence before indexing into it; the original fetched the
        // word first, which made its isEmpty() guard unreachable.
        if (sentence == null || sentence.isEmpty()) {
            return;
        }
        VocabWord word = sentence.get(i);
        if (word == null) {
            return;
        }
        Collection<VocabWord> labelSource = sentenceWithLabel.getSecond();
        if (labelSource == null || labelSource.isEmpty()) {
            return;
        }
        // The pair only promises a Collection; copy when it is not already a List rather
        // than failing with a ClassCastException.
        List<VocabWord> labels = labelSource instanceof List ? (List<VocabWord>)labelSource : new ArrayList<VocabWord>(labelSource);
        int end = this.window * 2 + 1 - b;
        for (int a = b; a < end; ++a) {
            if (a == this.window) {
                continue;
            }
            int c = i - this.window + a;
            if (c < 0 || c >= labels.size()) {
                continue;
            }
            VocabWord lastWord = labels.get(c);
            // NOTE(review): label entries can presumably be null when a label is missing
            // from the vocabulary - confirm; such entries are skipped here.
            if (lastWord == null) {
                continue;
            }
            this.iterate(word, lastWord, nextRandom, alpha);
        }
    }

    /**
     * Copies the non-null words of {@code sentence} into {@code currMiniBatch}, randomly
     * dropping frequent words when {@code sample} is positive.
     *
     * <p>With sub-sampling enabled, a word is kept when
     * {@code (sqrt(f / t) + 1) * t / f} is at least a uniform draw in {@code [0, 1)},
     * where {@code f} is the word frequency and {@code t = sample * numDocuments}.
     *
     * @param sentence      words of one document; {@code null} entries are skipped
     * @param nextRandom    state of the linear congruential generator; advanced once per
     *                      sub-sampling decision
     * @param currMiniBatch output list that receives the retained words
     */
    @Override
    protected void addWords(List<VocabWord> sentence, AtomicLong nextRandom, List<VocabWord> currMiniBatch) {
        boolean subSampling = this.sample > 0.0;
        boolean thresholdKnown = false;
        double threshold = 0.0;
        for (VocabWord word : sentence) {
            if (word == null) {
                continue;
            }
            if (subSampling) {
                // The document count does not change inside this loop: query it once,
                // and only when there is at least one word to score.
                if (!thresholdKnown) {
                    double numDocs = this.vectorizer.index().numDocuments();
                    threshold = this.sample * numDocs;
                    thresholdKnown = true;
                }
                double frequency = word.getWordFrequency();
                double ran = (Math.sqrt(frequency / threshold) + 1.0) * threshold / frequency;
                // Draw a fresh value for every word. Without advancing the generator each
                // word was compared against the same number, so the filter was not random.
                long previous;
                long next;
                do {
                    previous = nextRandom.get();
                    next = previous * 25214903917L + 11L;
                } while (!nextRandom.compareAndSet(previous, next));
                if (ran < (double)(next & 0xFFFFL) / 65536.0) {
                    continue;
                }
            }
            currMiniBatch.add(word);
        }
    }

    /**
     * Runs one training pass: every (words, labels) pair in {@code batch2} is handed to
     * {@code trainSentence} through {@code Parallelization.iterateInParallel}.
     *
     * <p>The learning rate for each sentence decays linearly with
     * {@code numWordsSoFar / totalWords} and never drops below {@code minLearningRate}.
     *
     * @param batch2        the labelled sentences to train on
     * @param numWordsSoFar running count of trained words, shared across passes
     * @param nextRandom    shared random state forwarded to {@code trainSentence}
     */
    private void doIteration(Queue<Pair<List<VocabWord>, Collection<VocabWord>>> batch2, final AtomicLong numWordsSoFar, final AtomicLong nextRandom) {
        // NOTE(review): this ActorSystem is never shut down here. If iterateInParallel
        // blocks until all work is done, it should be shut down afterwards - confirm.
        ActorSystem actorSystem = ActorSystem.create();
        // Word count at the last progress line. It is compared against numWordsSoFar, so
        // it starts at zero rather than at a wall-clock timestamp.
        final AtomicLong lastReport = new AtomicLong(0L);
        Parallelization.iterateInParallel(batch2, (Parallelization.RunnableWithParams)new Parallelization.RunnableWithParams<Pair<List<VocabWord>, Collection<VocabWord>>>(){

            @Override
            public void run(Pair<List<VocabWord>, Collection<VocabWord>> sentenceWithLabel, Object[] args) {
                if (sentenceWithLabel == null || sentenceWithLabel.getFirst() == null) {
                    return;
                }
                long wordsSoFar = numWordsSoFar.get();
                double learningRate = Math.max(ParagraphVectors.this.minLearningRate, ParagraphVectors.this.alpha.get() * (1.0 - (double)wordsSoFar / (double)ParagraphVectors.this.totalWords));
                long reportedAt = lastReport.get();
                // compareAndSet lets exactly one caller emit the line for a given step.
                if (wordsSoFar > 0L && Math.abs(reportedAt - wordsSoFar) >= 10000L && lastReport.compareAndSet(reportedAt, wordsSoFar)) {
                    log.info("Words so far " + wordsSoFar + " with alpha at " + learningRate);
                }
                int sentenceSize = sentenceWithLabel.getFirst().size();
                long startedAt = System.nanoTime();
                ParagraphVectors.this.trainSentence(sentenceWithLabel, nextRandom, learningRate);
                if (sentenceSize > 0) {
                    // Report the measured per-word time in milliseconds; the value used to
                    // be a constant 0.0.
                    double elapsedMillis = (double)(System.nanoTime() - startedAt) / 1000000.0;
                    log.info("Train sentence avg took " + elapsedMillis / (double)sentenceSize);
                }
                // Atomic add: a separate get() followed by set() loses updates when
                // sentences are processed concurrently.
                numWordsSoFar.addAndGet((long)sentenceSize);
            }
        }, actorSystem);
    }

    /**
     * Fluent builder for {@link ParagraphVectors}.
     *
     * <p>Each setter delegates to the same-named {@code Word2Vec.Builder} method and
     * returns this builder typed as {@code ParagraphVectors.Builder}, so chained calls end
     * in a {@link #build()} that yields a {@link ParagraphVectors}.
     */
    public static class Builder
    extends Word2Vec.Builder {
        /** Delegates to {@code super.index(index)} and returns this builder. */
        @Override
        public Builder index(InvertedIndex index) {
            super.index(index);
            return this;
        }

        /** Delegates to {@code super.workers(workers)} and returns this builder. */
        @Override
        public Builder workers(int workers) {
            super.workers(workers);
            return this;
        }

        /** Delegates to {@code super.sampling(sample)} and returns this builder. */
        @Override
        public Builder sampling(double sample) {
            super.sampling(sample);
            return this;
        }

        /** Delegates to {@code super.negativeSample(negative)} and returns this builder. */
        @Override
        public Builder negativeSample(double negative) {
            super.negativeSample(negative);
            return this;
        }

        /** Delegates to {@code super.minLearningRate(minLearningRate)} and returns this builder. */
        @Override
        public Builder minLearningRate(double minLearningRate) {
            super.minLearningRate(minLearningRate);
            return this;
        }

        /** Delegates to {@code super.useAdaGrad(useAdaGrad)} and returns this builder. */
        @Override
        public Builder useAdaGrad(boolean useAdaGrad) {
            super.useAdaGrad(useAdaGrad);
            return this;
        }

        /** Delegates to {@code super.vectorizer(textVectorizer)} and returns this builder. */
        @Override
        public Builder vectorizer(TextVectorizer textVectorizer) {
            super.vectorizer(textVectorizer);
            return this;
        }

        /** Delegates to {@code super.learningRateDecayWords(...)} and returns this builder. */
        @Override
        public Builder learningRateDecayWords(int learningRateDecayWords) {
            super.learningRateDecayWords(learningRateDecayWords);
            return this;
        }

        /** Delegates to {@code super.batchSize(batchSize)} and returns this builder. */
        @Override
        public Builder batchSize(int batchSize) {
            super.batchSize(batchSize);
            return this;
        }

        /** Delegates to {@code super.saveVocab(saveVocab)} and returns this builder. */
        @Override
        public Builder saveVocab(boolean saveVocab) {
            super.saveVocab(saveVocab);
            return this;
        }

        /** Delegates to {@code super.seed(seed)} and returns this builder. */
        @Override
        public Builder seed(long seed) {
            super.seed(seed);
            return this;
        }

        /** Delegates to {@code super.iterations(iterations)} and returns this builder. */
        @Override
        public Builder iterations(int iterations) {
            super.iterations(iterations);
            return this;
        }

        /** Delegates to {@code super.learningRate(lr)} and returns this builder. */
        @Override
        public Builder learningRate(double lr) {
            super.learningRate(lr);
            return this;
        }

        /** Delegates to {@code super.iterate(DocumentIterator)} and returns this builder. */
        @Override
        public Builder iterate(DocumentIterator iter) {
            super.iterate(iter);
            return this;
        }

        /** Delegates to {@code super.vocabCache(cache)} and returns this builder. */
        @Override
        public Builder vocabCache(VocabCache cache) {
            super.vocabCache(cache);
            return this;
        }

        /** Delegates to {@code super.minWordFrequency(minWordFrequency)} and returns this builder. */
        @Override
        public Builder minWordFrequency(int minWordFrequency) {
            super.minWordFrequency(minWordFrequency);
            return this;
        }

        /** Delegates to {@code super.tokenizerFactory(tokenizerFactory)} and returns this builder. */
        @Override
        public Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
            super.tokenizerFactory(tokenizerFactory);
            return this;
        }

        /** Delegates to {@code super.layerSize(layerSize)} and returns this builder. */
        @Override
        public Builder layerSize(int layerSize) {
            super.layerSize(layerSize);
            return this;
        }

        /** Delegates to {@code super.stopWords(stopWords)} and returns this builder. */
        @Override
        public Builder stopWords(List<String> stopWords) {
            super.stopWords(stopWords);
            return this;
        }

        /** Delegates to {@code super.windowSize(window)} and returns this builder. */
        @Override
        public Builder windowSize(int window) {
            super.windowSize(window);
            return this;
        }

        /** Delegates to {@code super.iterate(SentenceIterator)} and returns this builder. */
        @Override
        public Builder iterate(SentenceIterator iter) {
            super.iterate(iter);
            return this;
        }

        /** Delegates to {@code super.lookupTable(lookupTable)} and returns this builder. */
        @Override
        public Builder lookupTable(WeightLookupTable lookupTable) {
            super.lookupTable(lookupTable);
            return this;
        }

        /**
         * Creates a {@link ParagraphVectors} configured from this builder.
         *
         * <p>Unset collaborators are defaulted and stored back on the builder: a
         * {@link UimaTokenizerFactory} for the tokenizer, an {@link InMemoryLookupCache}
         * for the vocabulary, and an {@link InMemoryLookupTable} built from the negative
         * sampling, AdaGrad, learning-rate, cache and layer-size settings.
         *
         * @return the configured model
         * @throws RuntimeException wrapping any exception raised while creating the
         *                          default tokenizer factory
         */
        @Override
        public ParagraphVectors build() {
            ParagraphVectors ret = new ParagraphVectors();
            ret.alpha.set(this.lr);
            // The sentence iterator is only assigned when one was supplied; otherwise the
            // model keeps whatever value it was constructed with.
            if (this.iter != null) {
                ret.sentenceIter = this.iter;
            }
            ret.docIter = this.docIter;
            ret.window = this.window;
            ret.useAdaGrad = this.useAdaGrad;
            ret.minLearningRate = this.minLearningRate;
            ret.vectorizer = this.textVectorizer;
            ret.stopWords = this.stopWords;
            ret.minWordFrequency = this.minWordFrequency;
            ret.numIterations = this.iterations;
            ret.seed = this.seed;
            ret.saveVocab = this.saveVocab;
            ret.batchSize = this.batchSize;
            ret.sample = this.sampling;
            ret.workers = this.workers;
            ret.invertedIndex = this.index;

            // Resolve defaults before handing the collaborators to the model, so each one
            // is assigned exactly once.
            if (this.tokenizerFactory == null) {
                try {
                    this.tokenizerFactory = new UimaTokenizerFactory();
                }
                catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            if (this.vocabCache == null) {
                this.vocabCache = new InMemoryLookupCache();
            }
            if (this.lookupTable == null) {
                this.lookupTable = new InMemoryLookupTable.Builder().negative(this.negative).useAdaGrad(this.useAdaGrad).lr(this.lr).cache(this.vocabCache).vectorLength(this.layerSize).build();
            }
            ret.setVocab(this.vocabCache);
            ret.lookupTable = this.lookupTable;
            ret.tokenizerFactory = this.tokenizerFactory;
            return ret;
        }
    }
}

