/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.word2vec.wordstore;

import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds a {@link VocabCache} from one or more {@link SequenceIterator} sources.
 *
 * <p>Instances are obtained from {@link Builder}. Two construction modes are offered:
 * <ul>
 *   <li>{@link #buildJointVocabulary(boolean, boolean)} scans every registered source and counts
 *       element frequencies on a worker pool. It drops elements below each source's minimum
 *       frequency and merges all sources into the target cache.</li>
 *   <li>{@link #buildMergedVocabulary(VocabCache, boolean)} copies an existing vocabulary into the
 *       target cache. It can optionally append the sequence labels found in the sources.</li>
 * </ul>
 *
 * <p>NOTE(review): the build methods mutate instance state (the target cache and the sequence
 * counter) without synchronization, so concurrent builds on one instance are not supported.
 *
 * @param <T> type of the vocabulary elements
 */
public class VocabConstructor<T extends SequenceElement> {
    /** Progress is logged whenever the global sequence counter reaches a multiple of this value. */
    private static final long PROGRESS_LOG_INTERVAL = 100000L;
    /** Minimum number of elements added since the last scavenger pass before it may run again. */
    private static final long SCAVENGER_MIN_NEW_ELEMENTS = 2000000L;
    /** The scavenger only runs once a per-source vocabulary holds more elements than this. */
    private static final int SCAVENGER_MIN_VOCAB_SIZE = 10000000;

    /** Sequence sources, each paired with its own minimum element frequency. */
    private List<VocabSource<T>> sources = new ArrayList<VocabSource<T>>();
    /** Target vocabulary; created lazily when the builder did not supply one. */
    private VocabCache<T> cache;
    /** Tokens that are never counted; may be null. */
    private Collection<String> stopWords;
    /** Copied from the builder; not read by this class. */
    private boolean useAdaGrad = false;
    /** Whether sequence labels are added to the vocabulary during a joint build. */
    private boolean fetchLabels = false;
    /** If positive, a Huffman build removes regular elements whose index exceeds this value. */
    private int limit;
    /** Number of sequences read by this instance, accumulated across all builds. */
    private AtomicLong seqCount = new AtomicLong(0L);
    /** Optional inverted index that receives every processed sequence. */
    private InvertedIndex<T> index;
    /** Enables periodic pruning of rare elements while a very large source is being scanned. */
    private boolean enableScavenger = false;
    /** Optional "unknown" element added to the vocabulary at the end of a joint build. */
    private T unk;
    /** When false, each sequence is fully processed before the next one is submitted. */
    private boolean allowParallelBuilder = true;
    protected static final Logger log = LoggerFactory.getLogger(VocabConstructor.class);

    /** Use {@link Builder} to create instances. */
    private VocabConstructor() {
    }

    /** Not implemented: always returns {@code null}. */
    protected WeightLookupTable<T> buildExtendedLookupTable() {
        return null;
    }

    /** Not implemented: always returns {@code null}. */
    protected VocabCache<T> buildExtendedVocabulary() {
        return null;
    }

    /**
     * Copies the vocabulary of {@code wordVectors} into the target cache.
     *
     * @param wordVectors model whose vocabulary is copied
     * @param fetchLabels if true, label elements are copied and sequence labels from the registered
     *                    sources are appended; otherwise label elements are skipped
     * @return the target vocabulary cache
     * @throws NullPointerException if {@code wordVectors} is null
     */
    public VocabCache<T> buildMergedVocabulary(@NonNull WordVectors wordVectors, boolean fetchLabels) {
        if (wordVectors == null) {
            throw new NullPointerException("wordVectors");
        }
        return this.buildMergedVocabulary(wordVectors.vocab(), fetchLabels);
    }

    /**
     * Returns the number of sequences read by this instance so far, accumulated across all builds.
     */
    public long getNumberOfSequences() {
        return this.seqCount.get();
    }

    /**
     * Copies every indexed element of {@code vocabCache} into the target cache.
     *
     * <p>When {@code fetchLabels} is set, every registered source is then scanned. Each sequence
     * label not yet present is appended as a special label element, with the next free index.
     *
     * @param vocabCache  vocabulary to copy from
     * @param fetchLabels whether label elements are copied and source labels appended
     * @return the target vocabulary cache
     * @throws NullPointerException  if {@code vocabCache} is null
     * @throws IllegalStateException if nothing could be copied
     */
    public VocabCache<T> buildMergedVocabulary(@NonNull VocabCache<T> vocabCache, boolean fetchLabels) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache");
        }
        if (this.cache == null) {
            this.cache = new AbstractCache.Builder<T>().build();
        }
        for (int t = 0; t < vocabCache.numWords(); ++t) {
            String label = vocabCache.wordAtIndex(t);
            if (label == null) {
                continue;
            }
            T element = vocabCache.wordFor(label);
            if (!fetchLabels && element.isLabel()) {
                continue;
            }
            this.cache.addToken(element);
            this.cache.addWordToIndex(element.getIndex(), element.getLabel());
            this.cache.putVocabWord(element.getLabel());
        }
        if (this.cache.numWords() == 0) {
            throw new IllegalStateException("Source VocabCache has no indexes available, transfer is impossible");
        }
        log.info("Vocab size before labels: {}", this.cache.numWords());
        if (fetchLabels) {
            for (VocabSource<T> source : this.sources) {
                SequenceIterator<T> iterator = source.getIterator();
                iterator.reset();
                while (iterator.hasMoreSequences()) {
                    Sequence<T> sequence = iterator.nextSequence();
                    this.seqCount.incrementAndGet();
                    List<T> labels = sequence.getSequenceLabels();
                    if (labels == null) {
                        continue;
                    }
                    for (T label : labels) {
                        if (this.cache.containsWord(label.getLabel())) {
                            continue;
                        }
                        label.markAsLabel(true);
                        label.setSpecial(true);
                        // New labels are appended after all existing entries.
                        label.setIndex(this.cache.numWords());
                        this.cache.addToken(label);
                        this.cache.addWordToIndex(label.getIndex(), label.getLabel());
                        this.cache.putVocabWord(label.getLabel());
                    }
                }
            }
        }
        log.info("Vocab size after labels: {}", this.cache.numWords());
        return this.cache;
    }

    /**
     * Scans every registered source and builds the joint vocabulary in the target cache.
     *
     * <p>Each source is counted into its own temporary cache on a worker pool. That cache is
     * filtered by the source's minimum element frequency and merged into a combined holder, which is
     * finally imported into the target cache. Afterwards the following steps run, in order:
     * <ol>
     *   <li>the optional UNK element is added;</li>
     *   <li>frequencies are optionally reset;</li>
     *   <li>a Huffman tree is optionally built, which also applies the entries limit if one is set.</li>
     * </ol>
     *
     * @param resetCounters    if true, every element frequency in the target cache is set to zero
     * @param buildHuffmanTree if true, Huffman indexes are computed and applied to the target cache
     * @return the target vocabulary cache
     * @throws IllegalStateException if both {@code resetCounters} and {@code buildHuffmanTree} are set
     */
    public VocabCache<T> buildJointVocabulary(boolean resetCounters, boolean buildHuffmanTree) {
        if (resetCounters && buildHuffmanTree) {
            throw new IllegalStateException("You can't reset counters and build Huffman tree at the same time!");
        }
        long startTime = System.currentTimeMillis();
        // seqCount accumulates across builds; remember where this build started so that the
        // throughput figures only cover sequences read by this call.
        long startSequences = this.seqCount.get();
        long lastTime = startTime;
        long lastSequences = startSequences;
        long lastElements = 0L;
        AtomicLong parsedCount = new AtomicLong(0L);
        if (this.cache == null) {
            this.cache = new AbstractCache.Builder<T>().build();
        }
        log.debug("Target vocab size before building: [{}]", this.cache.numWords());
        // Number of elements added by workers since the last scavenger pass.
        AtomicLong loopCounter = new AtomicLong(0L);
        AbstractCache<T> topHolder = new AbstractCache.Builder<T>().minElementFrequency(0).build();
        int cnt = 0;
        int numProc = Runtime.getRuntime().availableProcessors();
        // Half the available processors, but at least two workers.
        int numThreads = Math.max(numProc / 2, 2);
        ThreadPoolExecutor executorService = new ThreadPoolExecutor(numThreads, numThreads, 0L,
                        TimeUnit.MILLISECONDS, new LinkedTransferQueue<Runnable>());
        // Submitted vs. finished tasks; equal values mean no worker is touching the temp cache.
        AtomicLong execCounter = new AtomicLong(0L);
        AtomicLong finCounter = new AtomicLong(0L);
        try {
            for (VocabSource<T> source : this.sources) {
                SequenceIterator<T> iterator = source.getIterator();
                iterator.reset();
                log.debug("Trying source iterator: [{}]", cnt);
                log.debug("Target vocab size before building: [{}]", this.cache.numWords());
                ++cnt;
                AbstractCache<T> tempHolder = new AbstractCache.Builder<T>().build();
                while (iterator.hasMoreSequences()) {
                    Sequence<T> document = iterator.nextSequence();
                    this.seqCount.incrementAndGet();
                    parsedCount.addAndGet(document.size());
                    tempHolder.incrementTotalDocCount();
                    execCounter.incrementAndGet();
                    executorService.execute(new VocabRunnable(tempHolder, document, finCounter, loopCounter));
                    if (!this.allowParallelBuilder) {
                        // Sequential mode: wait for this very sequence before reading the next one.
                        while (execCounter.get() != finCounter.get()) {
                            LockSupport.parkNanos(1000L);
                        }
                    }
                    // Throttle the producer: keep at most numProc sequences in flight.
                    awaitPending(execCounter, finCounter, numProc, 1L);

                    long currentSequences = this.seqCount.get();
                    if (currentSequences % PROGRESS_LOG_INTERVAL == 0L) {
                        long currentTime = System.currentTimeMillis();
                        long currentElements = parsedCount.get();
                        double seconds = (double) (currentTime - lastTime) / 1000.0;
                        double seqPerSec = (double) (currentSequences - lastSequences) / seconds;
                        double elPerSec = (double) (currentElements - lastElements) / seconds;
                        log.info("Sequences checked: [{}]; Current vocabulary size: [{}]; Sequences/sec: {}; Words/sec: {};",
                                        new Object[] {currentSequences, tempHolder.numWords(),
                                                        String.format("%.2f", seqPerSec),
                                                        String.format("%.2f", elPerSec)});
                        lastTime = currentTime;
                        lastElements = currentElements;
                        lastSequences = currentSequences;
                    }

                    if (this.enableScavenger && loopCounter.get() >= SCAVENGER_MIN_NEW_ELEMENTS
                                    && tempHolder.numWords() > SCAVENGER_MIN_VOCAB_SIZE) {
                        log.info("Starting scavenger...");
                        // Workers must be idle before the temp cache is pruned.
                        awaitPending(execCounter, finCounter, 0L, 2L);
                        this.filterVocab(tempHolder, Math.max(1, source.getMinWordFrequency() / 2));
                        loopCounter.set(0L);
                    }
                }
                log.debug("Wating till all processes stop...");
                awaitPending(execCounter, finCounter, 0L, 2L);
                log.debug("Vocab size before truncation: [" + tempHolder.numWords() + "],  NumWords: ["
                                + tempHolder.totalWordOccurrences() + "], sequences parsed: [" + this.seqCount.get()
                                + "], counter: [" + parsedCount.get() + "]");
                if (source.getMinWordFrequency() > 0) {
                    this.filterVocab(tempHolder, source.getMinWordFrequency());
                }
                log.debug("Vocab size after truncation: [" + tempHolder.numWords() + "],  NumWords: ["
                                + tempHolder.totalWordOccurrences() + "], sequences parsed: [" + this.seqCount.get()
                                + "], counter: [" + parsedCount.get() + "]");
                topHolder.importVocabulary(tempHolder);
            }
            pauseForGc();
            this.cache.importVocabulary(topHolder);
            if (this.unk != null) {
                log.info("Adding UNK element to vocab...");
                this.unk.setSpecial(true);
                this.cache.addToken(this.unk);
            }
            if (resetCounters) {
                for (T element : this.cache.vocabWords()) {
                    element.setElementFrequency(0L);
                }
                this.cache.updateWordsOccurencies();
            }
            if (buildHuffmanTree) {
                Huffman huffman = new Huffman(this.cache.vocabWords());
                huffman.build();
                huffman.applyIndexes(this.cache);
                if (this.limit > 0) {
                    applyEntriesLimit();
                }
            }
        } finally {
            // Always release the worker threads, including when a source or worker throws.
            executorService.shutdown();
        }
        pauseForGc();
        long endTime = System.currentTimeMillis();
        double seconds = (double) (endTime - startTime) / 1000.0;
        double seqPerSec = (double) (this.seqCount.get() - startSequences) / seconds;
        log.info("Sequences checked: [{}], Current vocabulary size: [{}]; Sequences/sec: [{}];",
                        new Object[] {this.seqCount.get(), this.cache.numWords(), String.format("%.2f", seqPerSec)});
        return this.cache;
    }

    /**
     * Removes every non-special, non-label element of the target cache whose index is greater than
     * {@link #limit}. Labels are collected first so the cache is not modified while iterating.
     *
     * <p>NOTE(review): elements with {@code index <= limit} are kept, so up to {@code limit + 1}
     * regular entries survive if indexes start at 0. Confirm this is the intended semantics.
     */
    private void applyEntriesLimit() {
        List<String> labelsToRemove = new ArrayList<String>();
        for (T element : this.cache.vocabWords()) {
            if (element.getIndex() > this.limit && !element.isSpecial() && !element.isLabel()) {
                labelsToRemove.add(element.getLabel());
            }
        }
        for (String label : labelsToRemove) {
            this.cache.removeElement(label);
        }
    }

    /**
     * Blocks until no more than {@code maxPending} submitted tasks remain unfinished.
     *
     * <p>The wait polls every {@code pollMillis} milliseconds. It deliberately keeps waiting when the
     * thread is interrupted, because the vocabulary must not be read while workers may still be
     * mutating it. Instead of swallowing the interrupt, the status is restored before returning.
     */
    private static void awaitPending(AtomicLong submitted, AtomicLong finished, long maxPending, long pollMillis) {
        boolean interrupted = false;
        try {
            while (submitted.get() - finished.get() > maxPending) {
                try {
                    Thread.sleep(pollMillis);
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Requests garbage collection and pauses for one second.
     *
     * <p>This is only a hint, so an interrupt ends the pause early. The interrupt status is restored.
     */
    private static void pauseForGc() {
        System.gc();
        System.gc();
        try {
            Thread.sleep(1000L);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Removes from {@code cache} every non-special, non-label element whose frequency is below
     * {@code minWordFrequency}.
     *
     * @param cache            vocabulary to prune in place
     * @param minWordFrequency minimum element frequency to keep
     */
    protected void filterVocab(AbstractCache<T> cache, int minWordFrequency) {
        int numWords = cache.numWords();
        // Collect first, then remove, so the cache is not modified while iterating over it.
        List<String> labelsToRemove = new ArrayList<String>();
        for (T element : cache.vocabWords()) {
            if (element.getElementFrequency() < (double) minWordFrequency && !element.isSpecial()
                            && !element.isLabel()) {
                labelsToRemove.add(element.getLabel());
            }
        }
        for (String label : labelsToRemove) {
            cache.removeElement(label);
        }
        log.debug("Scavenger: Words before: {}; Words after: {};", numWords, cache.numWords());
    }

    /**
     * Counts the elements of one sequence into a shared temporary vocabulary.
     *
     * <p>The task runs on the joint-build worker pool. It always increments {@code finalCounter}
     * when done, even on failure, because the producer waits for that counter to catch up.
     */
    protected class VocabRunnable
    implements Runnable {
        private final AtomicLong finalCounter;
        private final Sequence<T> document;
        private final AbstractCache<T> targetVocab;
        private final AtomicLong loopCounter;

        /**
         * @param targetVocab  vocabulary the sequence is counted into
         * @param sequence     sequence to count
         * @param finalCounter incremented once when this task finishes
         * @param loopCounter  incremented for every element newly added to {@code targetVocab}
         * @throws NullPointerException if any argument is null
         */
        public VocabRunnable(@NonNull AbstractCache<T> targetVocab, @NonNull Sequence<T> sequence, @NonNull AtomicLong finalCounter, AtomicLong loopCounter) {
            if (targetVocab == null) {
                throw new NullPointerException("targetVocab");
            }
            if (sequence == null) {
                throw new NullPointerException("sequence");
            }
            if (finalCounter == null) {
                throw new NullPointerException("finalCounter");
            }
            if (loopCounter == null) {
                throw new NullPointerException("loopCounter");
            }
            this.finalCounter = finalCounter;
            this.document = sequence;
            this.targetVocab = targetVocab;
            this.loopCounter = loopCounter;
        }

        @Override
        public void run() {
            try {
                processSequence();
            } catch (RuntimeException e) {
                log.error("Failed to add sequence to vocabulary", e);
                throw e;
            } finally {
                // Without this, a failing sequence would leave the producer waiting forever.
                this.finalCounter.incrementAndGet();
            }
        }

        /** Registers labels (if enabled), counts tokens, and feeds the optional inverted index. */
        private void processSequence() {
            List<T> labels = this.document.getSequenceLabels();
            if (VocabConstructor.this.fetchLabels && labels != null) {
                for (T labelWord : labels) {
                    if (this.targetVocab.hasToken(labelWord.getLabel())) {
                        continue;
                    }
                    labelWord.setSpecial(true);
                    labelWord.markAsLabel(true);
                    labelWord.setElementFrequency(1L);
                    this.targetVocab.addToken(labelWord);
                }
            }

            Collection<String> stopList = VocabConstructor.this.stopWords;
            // Tokens whose sequence count has already been accounted for in this sequence.
            Set<String> countedInSequence = new HashSet<String>();
            for (String token : this.document.asLabels()) {
                // Null/empty check first, so null-hostile stop-word collections are never queried with null.
                if (token == null || token.isEmpty()) {
                    continue;
                }
                if (stopList != null && stopList.contains(token)) {
                    continue;
                }
                if (!this.targetVocab.containsWord(token)) {
                    T element = this.document.getElementByLabel(token);
                    element.setElementFrequency(1L);
                    element.setSequencesCount(1L);
                    this.targetVocab.addToken(element);
                    this.loopCounter.incrementAndGet();
                    countedInSequence.add(token);
                } else {
                    this.targetVocab.incrementWordCount(token);
                    // Increment the sequence count at most once per sequence.
                    if (countedInSequence.add(token)) {
                        this.targetVocab.wordFor(token).incrementSequencesCount();
                    }
                }
            }

            // Add the sequence to the inverted index exactly once (not once per token).
            InvertedIndex<T> invertedIndex = VocabConstructor.this.index;
            if (invertedIndex != null) {
                T sequenceLabel = this.document.getSequenceLabel();
                if (sequenceLabel != null) {
                    invertedIndex.addWordsToDoc(invertedIndex.numDocuments(), this.document.getElements(), sequenceLabel);
                } else {
                    invertedIndex.addWordsToDoc(invertedIndex.numDocuments(), this.document.getElements());
                }
            }
        }
    }

    /** A sequence source together with the minimum frequency applied to its elements. */
    private static class VocabSource<T extends SequenceElement> {
        @NonNull
        private SequenceIterator<T> iterator;
        private int minWordFrequency;

        @ConstructorProperties(value={"iterator", "minWordFrequency"})
        public VocabSource(@NonNull SequenceIterator<T> iterator, int minWordFrequency) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.iterator = iterator;
            this.minWordFrequency = minWordFrequency;
        }

        @NonNull
        public SequenceIterator<T> getIterator() {
            return this.iterator;
        }

        public int getMinWordFrequency() {
            return this.minWordFrequency;
        }

        public void setIterator(@NonNull SequenceIterator<T> iterator) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.iterator = iterator;
        }

        public void setMinWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof VocabSource)) {
                return false;
            }
            VocabSource<?> other = (VocabSource<?>) o;
            if (!other.canEqual(this)) {
                return false;
            }
            SequenceIterator<T> thisIterator = this.getIterator();
            SequenceIterator<?> otherIterator = other.getIterator();
            if (thisIterator == null ? otherIterator != null : !thisIterator.equals(otherIterator)) {
                return false;
            }
            return this.getMinWordFrequency() == other.getMinWordFrequency();
        }

        protected boolean canEqual(Object other) {
            return other instanceof VocabSource;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            SequenceIterator<T> iter = this.getIterator();
            result = result * prime + (iter == null ? 43 : iter.hashCode());
            result = result * prime + this.getMinWordFrequency();
            return result;
        }

        @Override
        public String toString() {
            return "VocabConstructor.VocabSource(iterator=" + this.getIterator() + ", minWordFrequency=" + this.getMinWordFrequency() + ")";
        }
    }

    /** Fluent builder for {@link VocabConstructor}. */
    public static class Builder<T extends SequenceElement> {
        private List<VocabSource<T>> sources = new ArrayList<VocabSource<T>>();
        private VocabCache<T> cache;
        private Collection<String> stopWords = new ArrayList<String>();
        private boolean useAdaGrad = false;
        private boolean fetchLabels = false;
        private InvertedIndex<T> index;
        private int limit;
        private boolean enableScavenger = false;
        private T unk;
        private boolean allowParallelBuilder = true;

        /** Sets the index threshold applied after a Huffman build; values &lt;= 0 disable it. */
        public Builder<T> setEntriesLimit(int limit) {
            this.limit = limit;
            return this;
        }

        /** If false, sequences are counted one at a time instead of concurrently. */
        public Builder<T> allowParallelTokenization(boolean reallyAllow) {
            this.allowParallelBuilder = reallyAllow;
            return this;
        }

        protected Builder<T> useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /** Sets the cache that receives the built vocabulary. */
        public Builder<T> setTargetVocabCache(@NonNull VocabCache<T> cache) {
            if (cache == null) {
                throw new NullPointerException("cache");
            }
            this.cache = cache;
            return this;
        }

        /** Registers a source whose elements are kept only if they occur at least {@code minElementFrequency} times. */
        public Builder<T> addSource(@NonNull SequenceIterator<T> iterator, int minElementFrequency) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.sources.add(new VocabSource<T>(iterator, minElementFrequency));
            return this;
        }

        /** Sets tokens that are never counted. */
        public Builder<T> setStopWords(@NonNull Collection<String> stopWords) {
            if (stopWords == null) {
                throw new NullPointerException("stopWords");
            }
            this.stopWords = stopWords;
            return this;
        }

        /** Whether sequence labels are added to the vocabulary during a joint build. */
        public Builder<T> fetchLabels(boolean reallyFetch) {
            this.fetchLabels = reallyFetch;
            return this;
        }

        /** Sets an optional inverted index that receives every processed sequence. */
        public Builder<T> setIndex(InvertedIndex<T> index) {
            this.index = index;
            return this;
        }

        /** Enables pruning of rare elements while very large sources are scanned. */
        public Builder<T> enableScavenger(boolean reallyEnable) {
            this.enableScavenger = reallyEnable;
            return this;
        }

        /** Sets the element added as a special "unknown" entry after a joint build. */
        public Builder<T> setUnk(T unk) {
            this.unk = unk;
            return this;
        }

        /**
         * Creates the constructor. The sources list is copied, so sources added to this builder
         * afterwards do not leak into already built instances.
         */
        public VocabConstructor<T> build() {
            VocabConstructor<T> constructor = new VocabConstructor<T>();
            constructor.sources = new ArrayList<VocabSource<T>>(this.sources);
            constructor.cache = this.cache;
            constructor.stopWords = this.stopWords;
            constructor.useAdaGrad = this.useAdaGrad;
            constructor.fetchLabels = this.fetchLabels;
            constructor.limit = this.limit;
            constructor.index = this.index;
            constructor.enableScavenger = this.enableScavenger;
            constructor.unk = this.unk;
            constructor.allowParallelBuilder = this.allowParallelBuilder;
            return constructor;
        }
    }
}

