/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.word2vec;

import com.google.common.util.concurrent.AtomicDouble;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.bagofwords.vectorizer.TextVectorizer;
import org.deeplearning4j.bagofwords.vectorizer.TfidfVectorizer;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.nn.api.Persistable;
import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.stopwords.StopWords;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.UimaTokenizerFactory;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Word2Vec
implements Persistable {
    // Explicit serialization version id for this class.
    private static final long serialVersionUID = -2367495638286018038L;
    // Tokenizer factory handed to the vectorizer in buildVocab(); transient, so it is not serialized.
    private transient TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    // Optional sentence source; reset at the start of fit() when non-null.
    private transient SentenceIterator sentenceIter;
    // Optional document source; reset at the start of fit() when non-null.
    private transient DocumentIterator docIter;
    // Vocabulary/vector store that all lookups and training steps delegate to.
    private transient VocabCache cache;
    // Batch size passed to the TfidfVectorizer built in buildVocab().
    private int batchSize = 1000;
    // Maximum number of candidates insertTopN() keeps for distance()/analogy().
    private int topNSize = 40;
    // Sampling value copied in by the Builder and by load(); not read anywhere else in this class.
    private double sample = 0.0;
    // Set in fit() to the vectorizer's word count times numIterations; denominator of the learning-rate decay.
    private long totalWords = 1L;
    // Words counted since the last learning-rate decay step (see trainSentence()).
    private AtomicInteger rateOfChange = new AtomicInteger(0);
    // Current learning rate; decayed in trainSentence() unless useAdaGrad is set.
    private AtomicDouble alpha = new AtomicDouble(0.025);
    // Minimum word count passed to the vectorizer via minWords(...).
    private int minWordFrequency = 5;
    // Window size used by trainSentence()/skipGram() to pick context positions.
    private int window = 5;
    // Layer size; the Builder also passes it as vectorLength to the default lookup cache.
    private int layerSize = 50;
    // NOTE(review): never referenced in this class — presumably a leftover random generator.
    private transient RandomGenerator g;
    private static Logger log = LoggerFactory.getLogger(Word2Vec.class);
    // Stop words handed to the vectorizer; filled lazily by readStopWords() when null.
    private List<String> stopWords;
    // When true, setup() resets the cache weights; setSentenceIter() clears it.
    private boolean shouldReset = true;
    // Number of passes fit() makes over the vectorizer's mini-batches.
    private int numIterations = 1;
    // Key of the fallback vector returned for words that are not in the vocabulary.
    public static final String UNK = "UNK";
    // Seed copied from the Builder; not read anywhere else in this class.
    private long seed = 123L;
    // When true, fit() saves the vocabulary after building it (not after loading it).
    private boolean saveVocab = false;
    // Lower bound applied to alpha during learning-rate decay.
    private double minLearningRate = 0.01;
    // Text vectorizer that builds the vocabulary and supplies documents and mini-batches.
    private TextVectorizer vectorizer;
    // Number of words between two learning-rate decay steps.
    private int learningRateDecayWords = 10000;
    // When true, trainSentence() leaves alpha untouched.
    private boolean useAdaGrad = false;

    /**
     * Collects the vocabulary words whose string similarity to {@code word}
     * (as computed by {@code MathUtils.stringSimilarity}) is at least {@code accuracy}.
     *
     * @param word     the word to compare every vocabulary entry against
     * @param accuracy minimum similarity score a vocabulary word needs to be included
     * @return the matching words, in the iteration order of the cache
     */
    public List<String> similarWordsInVocabTo(String word, double accuracy) {
        List<String> matches = new ArrayList<String>();
        for (String candidate : cache.words()) {
            double similarity = MathUtils.stringSimilarity(new String[]{word, candidate});
            if (similarity >= accuracy) {
                matches.add(candidate);
            }
        }
        return matches;
    }

    /**
     * Looks up the index of a word by delegating to the vocabulary cache.
     *
     * @param word the word to look up
     * @return whatever the cache reports; other methods in this class treat a negative value as "not in vocabulary"
     */
    public int indexOf(String word) {
        int index = cache.indexOf(word);
        return index;
    }

    /**
     * Returns the vector of a word as a flat {@code double[]}.
     *
     * @param word the word to look up
     * @return the raveled vector data of {@code word}, or of the {@link #UNK} entry
     *         when the cache reports a negative index for {@code word}
     */
    public double[] getWordVector(String word) {
        // Words the cache does not know are answered with the UNK vector.
        String key = cache.indexOf(word) < 0 ? UNK : word;
        return cache.vector(key).ravel().data().asDouble();
    }

    /**
     * Returns the vector of a word as stored in the cache.
     *
     * @param word the word to look up
     * @return the cache's vector for {@code word}, or for {@link #UNK} when the
     *         cache reports a negative index for {@code word}
     */
    public INDArray getWordVectorMatrix(String word) {
        boolean known = cache.indexOf(word) >= 0;
        return cache.vector(known ? word : UNK);
    }

    /**
     * Returns the vector of a word divided by its 2-norm.
     *
     * <p>Words the cache does not know (negative index) fall back to the
     * {@link #UNK} vector, which is normalized the same way so that both
     * paths honour the "normalized" contract of this method.
     *
     * @param word the word to look up
     * @return a new array holding the normalized vector; the stored vector itself
     *         when its norm is zero (dividing would only produce NaN/Infinity);
     *         or {@code null} when the cache has no vector to return
     */
    public INDArray getWordVectorMatrixNormalized(String word) {
        String key = this.cache.indexOf(word) < 0 ? UNK : word;
        INDArray r = this.cache.vector(key);
        if (r == null) {
            return null;
        }
        Number norm = Nd4j.getBlasWrapper().nrm2(r);
        if (norm.doubleValue() == 0.0) {
            // Nothing sensible to scale by; hand back the vector unchanged.
            return r;
        }
        return r.div(norm);
    }

    /**
     * Finds the words whose vectors are closest to the vector of {@code word}.
     *
     * <p>When the cache is an {@link InMemoryLookupCache} the scores are computed
     * in one pass over its {@code syn0} matrix; otherwise every vocabulary word is
     * scored individually with cosine similarity.
     *
     * @param word the query word; unknown words are scored using the {@link #UNK} vector
     * @param n    the maximum number of words to return
     * @return at most {@code n} words, never containing {@code word} itself;
     *         empty when no vector is available for the query
     */
    public Collection<String> wordsNearest(String word, int n) {
        // Check for a missing vector before normalising it, not afterwards.
        INDArray wordVector = this.getWordVectorMatrix(word);
        if (wordVector == null) {
            return new ArrayList<String>();
        }
        INDArray vec = Transforms.unitVec(wordVector);
        if (vec == null) {
            return new ArrayList<String>();
        }
        if (this.cache instanceof InMemoryLookupCache) {
            InMemoryLookupCache l = (InMemoryLookupCache)this.cache;
            INDArray syn0 = l.getSyn0();
            INDArray weights = syn0.norm2(0).rdivi((Number)1).muli(vec);
            INDArray distances = syn0.mulRowVector(weights).sum(1);
            // sorted[0] is used as the row indices in score order
            // (presumably descending, given the 'false' flag — TODO confirm against Nd4j).
            INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
            INDArray sort = sorted[0];
            List<String> ret = new ArrayList<String>();
            // The query may be missing from the vocabulary; then there is no own index to skip.
            VocabWord self = this.cache.wordFor(word);
            int selfIndex = self == null ? -1 : self.getIndex();
            int available = sort.length();
            // Walk the ranking until n neighbours are collected or it is exhausted;
            // bounding by 'available' keeps the index inside the sorted array.
            for (int i = 0; i < available && ret.size() < n; ++i) {
                int index = sort.getInt(new int[]{i});
                if (index == selfIndex) {
                    continue;
                }
                ret.add(this.cache.wordAtIndex(index));
            }
            return ret;
        }
        Counter<String> distances = new Counter<String>();
        for (String s : this.cache.words()) {
            if (s.equals(word)) {
                continue;
            }
            INDArray otherVec = this.getWordVectorMatrix(s);
            double sim = Transforms.cosineSim(vec, otherVec);
            distances.incrementCount(s, sim);
        }
        distances.keepTopNKeys(n);
        return distances.keySet();
    }

    /**
     * Runs {@link #analogy(String, String, String)} and maps the resulting
     * entries back to their words through the cache.
     *
     * @param w1 first word of the analogy
     * @param w2 second word of the analogy
     * @param w3 third word of the analogy
     * @return the words of the analogy candidates, in the iteration order of the
     *         result set; empty when {@code analogy} yields no result
     */
    public List<String> analogyWords(String w1, String w2, String w3) {
        List<String> ret = new ArrayList<String>();
        TreeSet<VocabWord> analogies = this.analogy(w1, w2, w3);
        if (analogies == null) {
            // analogy() returns null when a vector is missing; report "no candidates" instead of failing.
            return ret;
        }
        for (VocabWord w : analogies) {
            ret.add(this.cache.wordAtIndex(w.getIndex()));
        }
        return ret;
    }

    /**
     * Offers a scored word to a bounded candidate list of at most {@code topNSize} entries.
     *
     * <p>While the list has room the word is simply appended. Once it is full the
     * entry with the lowest score is replaced, but only if the new score is higher.
     * The score is stored in the entry's word-frequency slot.
     *
     * @param name        the candidate word
     * @param score       the candidate's score
     * @param wordsEntrys the candidate list, modified in place
     */
    private void insertTopN(String name, double score, List<VocabWord> wordsEntrys) {
        if (wordsEntrys.size() < this.topNSize) {
            wordsEntrys.add(this.scoredEntry(name, score));
            return;
        }
        // Locate the weakest candidate currently held.
        double min = Double.MAX_VALUE;
        int minOffset = 0;
        for (int i = 0; i < this.topNSize; ++i) {
            double held = wordsEntrys.get(i).getWordFrequency();
            if (min > held) {
                min = held;
                minOffset = i;
            }
        }
        if (score > min) {
            // The replacement must describe the NEW word; reusing the evicted entry's
            // index would make callers resolve the result back to the evicted word.
            wordsEntrys.set(minOffset, this.scoredEntry(name, score));
        }
    }

    /** Builds a candidate entry carrying the score and the cache index of {@code name}. */
    private VocabWord scoredEntry(String name, double score) {
        VocabWord entry = new VocabWord(score, name);
        entry.setIndex(this.cache.indexOf(name));
        return entry;
    }

    /**
     * Tells whether the cache knows a word.
     *
     * @param word the word to test
     * @return {@code true} when the cache reports a non-negative index for {@code word}
     */
    public boolean hasWord(String word) {
        int index = cache.indexOf(word);
        return index >= 0;
    }

    /**
     * Builds (or loads) the vocabulary and trains on every mini-batch supplied by
     * the vectorizer, {@code numIterations} times, using a fixed thread pool.
     *
     * <p>The call blocks until all submitted batches have been processed. If the
     * calling thread is interrupted while waiting, outstanding work is cancelled
     * and the interrupt flag is restored.
     *
     * @throws IllegalStateException if no vectorizer is available, if no documents
     *         show up after the retry period, or if a training task fails
     */
    public void fit() {
        boolean loaded = this.buildVocab();
        if (!loaded && this.saveVocab) {
            this.cache.saveVocab();
        }
        if (this.stopWords == null) {
            this.readStopWords();
        }
        if (this.vectorizer == null) {
            // buildVocab() only creates a vectorizer when the vocab was not loaded from the cache.
            throw new IllegalStateException("Unable to train, no text vectorizer available");
        }
        log.info("Training word2vec multithreaded");
        if (this.sentenceIter != null) {
            this.sentenceIter.reset();
        }
        if (this.docIter != null) {
            this.docIter.reset();
        }
        // Wait for documents BEFORE creating the pool so a failure here leaks no threads.
        Collection<Integer> docs = this.waitForDocuments();
        final AtomicInteger numSentencesProcessed = new AtomicInteger(0);
        this.totalWords = this.vectorizer.numWordsEncountered();
        this.totalWords *= (long)this.numIterations;
        log.info("Processing sentences...");
        ExecutorService service = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
        List<Future<Void>> futures = new ArrayList<Future<Void>>();
        boolean finished = false;
        try {
            for (int i = 0; i < this.numIterations; ++i) {
                log.info("Training on " + docs.size());
                final AtomicLong nextRandom = new AtomicLong(5L);
                Iterator<List<VocabWord>> minibatchesIter = this.vectorizer.index().miniBatches();
                while (minibatchesIter.hasNext()) {
                    final List<VocabWord> batch = minibatchesIter.next();
                    futures.add(service.submit(new Callable<Void>(){

                        @Override
                        public Void call() {
                            Word2Vec.this.trainSentence(batch, numSentencesProcessed, nextRandom);
                            return null;
                        }
                    }));
                }
            }
            for (Future<Void> future : futures) {
                future.get();
            }
            service.shutdown();
            while (!service.awaitTermination(1L, TimeUnit.SECONDS)) {
                // keep waiting until the pool has wound down
            }
            finished = true;
        }
        catch (InterruptedException e) {
            // Only a real interruption restores the interrupt flag.
            Thread.currentThread().interrupt();
        }
        catch (ExecutionException e) {
            // A failed batch must not be mistaken for a successful training run.
            throw new IllegalStateException("Word2vec training task failed", e.getCause());
        }
        finally {
            if (!finished) {
                service.shutdownNow();
            }
        }
    }

    /**
     * Polls the vectorizer's index until it reports at least one document,
     * sleeping 10 seconds between attempts and giving up after three waits.
     */
    private Collection<Integer> waitForDocuments() {
        Collection<Integer> docs = this.vectorizer.index().allDocs();
        int tries = 0;
        while (docs.isEmpty()) {
            if (tries >= 3) {
                throw new IllegalStateException("Unable to train, no documents found");
            }
            log.warn("No documents found...waiting 10 seconds on try " + tries);
            try {
                Thread.sleep(10000L);
            }
            catch (InterruptedException e) {
                // With the flag set every further sleep would fail immediately; stop waiting.
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for documents", e);
            }
            ++tries;
            // Re-read the index; otherwise the loop could never observe new documents.
            docs = this.vectorizer.index().allDocs();
        }
        return docs;
    }

    /**
     * Scores every other vocabulary word against {@code word} with a dot product
     * and keeps the best candidates via {@code insertTopN}.
     *
     * @param word the query word
     * @return the retained candidates as a sorted set, or {@code null} when no
     *         vector is available for {@code word}
     */
    public Set<VocabWord> distance(String word) {
        INDArray target = getWordVectorMatrix(word);
        if (target == null) {
            return null;
        }
        List<VocabWord> candidates = new ArrayList<VocabWord>(topNSize);
        for (String candidate : cache.words()) {
            if (!candidate.equals(word)) {
                double score = Nd4j.getBlasWrapper().dot(target, cache.vector(candidate));
                insertTopN(candidate, score, candidates);
            }
        }
        return new TreeSet<VocabWord>(candidates);
    }

    /**
     * Scores every vocabulary word against the vector
     * {@code vec(word1) - vec(word0) + vec(word2)} with a dot product and keeps
     * the best candidates via {@code insertTopN}. The three input words are
     * excluded from the result.
     *
     * @param word0 the word whose vector is subtracted
     * @param word1 the word whose vector is the starting point
     * @param word2 the word whose vector is added
     * @return the retained candidates as a sorted set, or {@code null} when a
     *         vector is missing for any of the three words
     */
    public TreeSet<VocabWord> analogy(String word0, String word1, String word2) {
        INDArray wv0 = this.getWordVectorMatrix(word0);
        INDArray wv1 = this.getWordVectorMatrix(word1);
        INDArray wv2 = this.getWordVectorMatrix(word2);
        // Validate before combining; the arithmetic below dereferences all three vectors.
        if (wv0 == null || wv1 == null || wv2 == null) {
            return null;
        }
        INDArray wordVector = wv1.sub(wv0).add(wv2);
        List<VocabWord> wordEntrys = new ArrayList<VocabWord>(this.topNSize);
        for (int i = 0; i < this.cache.numWords(); ++i) {
            String name = this.cache.wordAtIndex(i);
            if (name.equals(word0) || name.equals(word1) || name.equals(word2)) {
                continue;
            }
            INDArray tempVector = this.cache.vector(name);
            double dist = Nd4j.getBlasWrapper().dot(wordVector, tempVector);
            this.insertTopN(name, dist, wordEntrys);
        }
        return new TreeSet<VocabWord>(wordEntrys);
    }

    /**
     * Builds the Huffman tree over the vocabulary and, when {@code shouldReset}
     * is set, resets the cache weights. Both steps are announced in the log.
     */
    public void setup() {
        log.info("Building binary tree");
        buildBinaryTree();
        log.info("Resetting weights");
        if (!shouldReset) {
            return;
        }
        resetWeights();
    }

    /**
     * Makes a vocabulary available, either by loading an existing one from the
     * cache or by fitting the text vectorizer and running {@link #setup()}.
     *
     * <p>If no vectorizer was supplied, a {@code TfidfVectorizer} is created from
     * this instance's iterators, batch size, minimum word frequency, stop words
     * and tokenizer factory.
     *
     * @return {@code true} when an existing vocabulary was loaded from the cache,
     *         {@code false} when it was built from scratch
     */
    public boolean buildVocab() {
        readStopWords();
        if (!cache.vocabExists()) {
            if (vectorizer == null) {
                vectorizer = new TfidfVectorizer.Builder()
                        .cache(cache)
                        .iterate(docIter)
                        .iterate(sentenceIter)
                        .batchSize(batchSize)
                        .minWords(minWordFrequency)
                        .stopWords(stopWords)
                        .tokenize(tokenizerFactory)
                        .build();
            }
            vectorizer.fit();
            setup();
            return false;
        }
        log.info("Loading vocab...");
        cache.loadVocab();
        cache.resetWeights();
        return true;
    }

    /**
     * Asks the vocabulary cache to plot its vocabulary.
     */
    public void plotTsne() {
        cache.plotVocab();
    }

    /**
     * Trains on one sentence (mini-batch): updates the word counters, decays the
     * learning rate every {@code learningRateDecayWords} words, then runs
     * {@link #skipGram} for each position with a pseudo-randomly reduced window.
     *
     * <p>{@link #fit()} calls this from several pool threads that share the
     * counters and the random state, so every update here is a single atomic
     * operation rather than a separate read followed by a write.
     *
     * @param sentence      the words to train on; {@code null} or empty is a no-op
     * @param numWordsSoFar running total of processed words, shared between calls
     * @param nextRandom    shared state of the linear congruential generator
     */
    public void trainSentence(List<VocabWord> sentence, AtomicInteger numWordsSoFar, AtomicLong nextRandom) {
        if (sentence == null || sentence.isEmpty()) {
            return;
        }
        int sentenceSize = sentence.size();
        int wordsSoFar = numWordsSoFar.addAndGet(sentenceSize);
        if (this.rateOfChange.addAndGet(sentenceSize) >= this.learningRateDecayWords) {
            this.rateOfChange.set(0);
            if (!this.useAdaGrad) {
                this.alpha.set(Math.max(this.minLearningRate, this.alpha.get() * (1.0 - 1.0 * (double)wordsSoFar / (double)this.totalWords)));
                this.cache.setLearningRate(this.alpha.get());
            }
            log.info("Num words so far " + wordsSoFar + " alpha is " + this.alpha.get() + " out of " + this.totalWords);
        }
        for (int i = 0; i < sentenceSize; ++i) {
            // Advance the generator with compare-and-set so concurrent callers never lose a step.
            long previous;
            long next;
            do {
                previous = nextRandom.get();
                next = previous * 25214903917L + 11L;
            } while (!nextRandom.compareAndSet(previous, next));
            // The int cast can be negative; a negative offset would widen the window
            // beyond 'window' words, so take the magnitude of the remainder.
            int b = Math.abs((int)next % this.window);
            this.skipGram(i, sentence, b, nextRandom);
        }
    }

    /**
     * Pairs the word at position {@code i} with each context word inside the
     * (reduced) window around it and trains on every pair via {@link #iterate}.
     *
     * @param i          position of the centre word in {@code sentence}
     * @param sentence   the sentence being trained on; {@code null} or empty is a no-op
     * @param b          window reduction: offsets run from {@code b} up to
     *                   {@code 2 * window + 1 - b} (exclusive)
     * @param nextRandom random state passed through to {@code iterate}
     */
    public void skipGram(int i, List<VocabWord> sentence, int b, AtomicLong nextRandom) {
        // Check the sentence before indexing into it.
        if (sentence == null || sentence.isEmpty()) {
            return;
        }
        VocabWord word = sentence.get(i);
        if (word == null) {
            return;
        }
        int end = this.window * 2 + 1 - b;
        for (int a = b; a < end; ++a) {
            if (a == this.window) {
                // Offset 'window' is the centre word itself.
                continue;
            }
            int c = i - this.window + a;
            if (c < 0 || c >= sentence.size()) {
                continue;
            }
            VocabWord lastWord = sentence.get(c);
            if (lastWord == null) {
                // Sentences may hold null slots (the centre word is checked too); skip them.
                continue;
            }
            this.iterate(word, lastWord, nextRandom);
        }
    }

    /**
     * Trains on a single word pair by delegating to the cache's {@code iterateSample}.
     *
     * @param w1         the first word of the pair
     * @param w2         the second word of the pair
     * @param nextRandom random state handed to the cache
     */
    public void iterate(VocabWord w1, VocabWord w2, AtomicLong nextRandom) {
        cache.iterateSample(w1, w2, nextRandom);
    }

    /**
     * Builds a Huffman tree over the cache's vocabulary words, logging before and after.
     */
    private void buildBinaryTree() {
        log.info("Constructing priority queue");
        new Huffman(cache.vocabWords()).build();
        log.info("Built tree");
    }

    /**
     * Resets the weights held by the vocabulary cache.
     */
    private void resetWeights() {
        cache.resetWeights();
    }

    /**
     * Computes the similarity of two words as the dot product of their unit vectors.
     *
     * @param word  the first word
     * @param word2 the second word
     * @return {@code 1.0} when both words are equal, {@code -1.0} when a vector is
     *         unavailable for either word, otherwise the dot product of the two
     *         unit vectors
     */
    public double similarity(String word, String word2) {
        if (word.equals(word2)) {
            return 1.0;
        }
        // Look the raw vectors up first so a missing one is reported as -1.0
        // instead of being handed to unitVec.
        INDArray raw = this.getWordVectorMatrix(word);
        INDArray raw2 = this.getWordVectorMatrix(word2);
        if (raw == null || raw2 == null) {
            return -1.0;
        }
        INDArray vector = Transforms.unitVec(raw);
        INDArray vector2 = Transforms.unitVec(raw2);
        if (vector == null || vector2 == null) {
            return -1.0;
        }
        return Nd4j.getBlasWrapper().dot(vector, vector2);
    }

    /**
     * Fills {@code stopWords} from {@code StopWords.getStopWords()} unless a list is already set.
     */
    private void readStopWords() {
        if (stopWords == null) {
            stopWords = StopWords.getStopWords();
        }
    }

    /**
     * Serializes this instance to the given stream using Java object serialization.
     *
     * <p>The stream is flushed but not closed; closing remains the caller's job.
     *
     * @param os the stream to write to
     * @throws RuntimeException wrapping the {@link IOException} if writing fails
     */
    public void write(OutputStream os) {
        try {
            ObjectOutputStream dos = new ObjectOutputStream(os);
            dos.writeObject(this);
            // ObjectOutputStream buffers internally; without a flush the tail of the
            // object may never reach 'os'.
            dos.flush();
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads a serialized {@code Word2Vec} from the stream and copies its
     * hyperparameters into this instance.
     *
     * <p>Transient members (cache, iterators, tokenizer factory) are not part of
     * the serialized form and are left untouched, as are the vectorizer and the
     * running training counters.
     *
     * <p>NOTE(review): this uses Java native deserialization; only feed it
     * streams from a trusted source.
     *
     * @param is the stream to read from
     * @throws RuntimeException wrapping any failure while reading or casting
     */
    public void load(InputStream is) {
        try {
            ObjectInputStream ois = new ObjectInputStream(is);
            Word2Vec vec = (Word2Vec)ois.readObject();
            this.alpha = vec.alpha;
            this.minWordFrequency = vec.minWordFrequency;
            this.sample = vec.sample;
            this.stopWords = vec.stopWords;
            this.topNSize = vec.topNSize;
            this.window = vec.window;
            // These are serialized as well; dropping them would silently fall back to
            // this instance's own values after loading.
            this.layerSize = vec.layerSize;
            this.batchSize = vec.batchSize;
            this.numIterations = vec.numIterations;
            this.seed = vec.seed;
            this.saveVocab = vec.saveVocab;
            this.minLearningRate = vec.minLearningRate;
            this.learningRateDecayWords = vec.learningRateDecayWords;
            this.useAdaGrad = vec.useAdaGrad;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Replaces the sentence iterator and clears {@code shouldReset}, so a later
     * {@link #setup()} will not reset the cache weights.
     *
     * @param sentenceIter the new sentence source
     */
    public void setSentenceIter(SentenceIterator sentenceIter) {
        shouldReset = false;
        this.sentenceIter = sentenceIter;
    }

    /**
     * @return the configured layer size
     */
    public int getLayerSize() {
        return layerSize;
    }

    /**
     * Sets the layer size stored on this instance.
     *
     * @param layerSize the new layer size
     */
    public void setLayerSize(int layerSize) {
        this.layerSize = layerSize;
    }

    /**
     * @return the window size used when pairing a word with its context
     */
    public int getWindow() {
        return window;
    }

    /**
     * @return the stop-word list currently held (the internal list, not a copy);
     *         may be {@code null} until stop words have been read
     */
    public List<String> getStopWords() {
        return stopWords;
    }

    /**
     * @return the sentence iterator currently configured, possibly {@code null}
     */
    public synchronized SentenceIterator getSentenceIter() {
        return sentenceIter;
    }

    /**
     * @return the tokenizer factory passed to the vectorizer when the vocabulary is built
     */
    public TokenizerFactory getTokenizerFactory() {
        return tokenizerFactory;
    }

    /**
     * Replaces the tokenizer factory passed to the vectorizer when the vocabulary is built.
     *
     * @param tokenizerFactory the new tokenizer factory
     */
    public void setTokenizerFactory(TokenizerFactory tokenizerFactory) {
        this.tokenizerFactory = tokenizerFactory;
    }

    /**
     * @return the vocabulary cache backing this model
     */
    public VocabCache getCache() {
        return cache;
    }

    /**
     * Replaces the vocabulary cache backing this model.
     *
     * @param cache the new cache
     */
    public void setCache(VocabCache cache) {
        this.cache = cache;
    }

    /**
     * Fluent builder for {@link Word2Vec}. Every setter stores its argument and
     * returns this builder; {@link #build()} copies the collected settings onto a
     * new instance.
     */
    public static class Builder {
        private int minWordFrequency = 1;
        private int layerSize = 50;
        private SentenceIterator iter;
        private List<String> stopWords = StopWords.getStopWords();
        private int window = 5;
        private TokenizerFactory tokenizerFactory;
        private VocabCache vocabCache;
        private DocumentIterator docIter;
        private double lr = 0.25;
        private int iterations = 1;
        private long seed = 123L;
        private boolean saveVocab = false;
        private int batchSize = 1000;
        private int learningRateDecayWords = 10000;
        private boolean useAdaGrad = false;
        private TextVectorizer textVectorizer;
        private double minLearningRate = 0.01;
        private double negative = 0.0;
        private double sampling = 1.0E-5;

        /** Sets the sampling value copied to the model's {@code sample} field. */
        public Builder sampling(double sample) {
            this.sampling = sample;
            return this;
        }

        /** Sets the negative-sampling value passed to the default lookup cache. */
        public Builder negativeSample(double negative) {
            this.negative = negative;
            return this;
        }

        /** Sets the lower bound the learning rate may decay to. */
        public Builder minLearningRate(double minLearningRate) {
            this.minLearningRate = minLearningRate;
            return this;
        }

        /** Enables or disables AdaGrad; also passed to the default lookup cache. */
        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /** Supplies the text vectorizer the model should use instead of building its own. */
        public Builder vectorizer(TextVectorizer textVectorizer) {
            this.textVectorizer = textVectorizer;
            return this;
        }

        /** Sets how many words are processed between two learning-rate decay steps. */
        public Builder learningRateDecayWords(int learningRateDecayWords) {
            this.learningRateDecayWords = learningRateDecayWords;
            return this;
        }

        /** Sets the batch size handed to the vectorizer. */
        public Builder batchSize(int batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        /** Chooses whether a freshly built vocabulary is saved during {@code fit()}. */
        public Builder saveVocab(boolean saveVocab) {
            this.saveVocab = saveVocab;
            return this;
        }

        /** Sets the seed stored on the model. */
        public Builder seed(long seed) {
            this.seed = seed;
            return this;
        }

        /** Sets the number of training passes. */
        public Builder iterations(int iterations) {
            this.iterations = iterations;
            return this;
        }

        /** Sets the initial learning rate; also passed to the default lookup cache. */
        public Builder learningRate(double lr) {
            this.lr = lr;
            return this;
        }

        /** Sets the document iterator to read text from. */
        public Builder iterate(DocumentIterator iter) {
            this.docIter = iter;
            return this;
        }

        /** Supplies the vocabulary cache; when left unset a default in-memory cache is created. */
        public Builder vocabCache(VocabCache cache) {
            this.vocabCache = cache;
            return this;
        }

        /** Sets the minimum word frequency handed to the vectorizer. */
        public Builder minWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
            return this;
        }

        /** Supplies the tokenizer factory; when left unset a {@code UimaTokenizerFactory} is created. */
        public Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /** Sets the layer size; also used as the vector length of the default lookup cache. */
        public Builder layerSize(int layerSize) {
            this.layerSize = layerSize;
            return this;
        }

        /** Replaces the stop-word list (defaults to {@code StopWords.getStopWords()}). */
        public Builder stopWords(List<String> stopWords) {
            this.stopWords = stopWords;
            return this;
        }

        /** Sets the context window size. */
        public Builder windowSize(int window) {
            this.window = window;
            return this;
        }

        /** Sets the sentence iterator to read text from. */
        public Builder iterate(SentenceIterator iter) {
            this.iter = iter;
            return this;
        }

        /**
         * Creates a {@link Word2Vec} configured with the values collected so far.
         *
         * <p>A missing tokenizer factory is replaced by a new
         * {@code UimaTokenizerFactory} and a missing vocabulary cache by a new
         * {@code InMemoryLookupCache}; both defaults are remembered on this builder.
         *
         * @return the configured model
         * @throws RuntimeException if the default tokenizer factory cannot be created
         */
        public Word2Vec build() {
            Word2Vec ret = new Word2Vec();
            ret.alpha.set(this.lr);
            ret.layerSize = this.layerSize;
            // Assigned directly (not via setSentenceIter) so shouldReset keeps its default;
            // stays null when no sentence iterator was supplied.
            ret.sentenceIter = this.iter;
            ret.docIter = this.docIter;
            ret.window = this.window;
            ret.useAdaGrad = this.useAdaGrad;
            ret.minLearningRate = this.minLearningRate;
            // Previously never copied, so the builder setting had no effect on the model.
            ret.learningRateDecayWords = this.learningRateDecayWords;
            ret.vectorizer = this.textVectorizer;
            ret.stopWords = this.stopWords;
            ret.minWordFrequency = this.minWordFrequency;
            ret.numIterations = this.iterations;
            ret.seed = this.seed;
            ret.saveVocab = this.saveVocab;
            ret.batchSize = this.batchSize;
            ret.sample = this.sampling;
            try {
                if (this.tokenizerFactory == null) {
                    this.tokenizerFactory = new UimaTokenizerFactory();
                }
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
            if (this.vocabCache == null) {
                this.vocabCache = new InMemoryLookupCache.Builder().negative(this.negative).useAdaGrad(this.useAdaGrad).lr(this.lr).vectorLength(this.layerSize).build();
            }
            ret.setCache(this.vocabCache);
            ret.tokenizerFactory = this.tokenizerFactory;
            return ret;
        }
    }
}

