/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.glove;

import akka.actor.ActorRef;
import akka.actor.ActorSystem;
import akka.actor.Props;
import akka.routing.RoundRobinPool;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.CounterMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.glove.actor.CoOccurrenceActor;
import org.deeplearning4j.models.glove.actor.SentenceWork;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.deeplearning4j.text.movingwindow.Util;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CoOccurrences
implements Serializable {
    private transient TokenizerFactory tokenizerFactory;
    private transient SentenceIterator sentenceIterator;
    private int windowSize = 15;
    protected transient VocabCache cache;
    protected InvertedIndex index;
    protected transient ActorSystem trainingSystem;
    protected boolean symmetric = true;
    private Counter<Integer> sentenceOccurrences = Util.parallelCounter();
    private CounterMap<String, String> coOCurreneCounts = Util.parallelCounterMap();
    private static final Logger log = LoggerFactory.getLogger(CoOccurrences.class);
    /** Pair list populated by {@link #load(InputStream)}; null for instances built via the constructor. */
    private List<Pair<String, String>> coOccurrences;

    /** Used only by {@link #load(InputStream)}, which fills in the counts directly. */
    private CoOccurrences() {
    }

    /**
     * Creates a co-occurrence counter; prefer {@link Builder} for construction.
     *
     * @param tokenizerFactory factory handed to the worker actors
     * @param sentenceIterator source of sentences consumed by {@link #fit()}
     * @param windowSize       window size handed to the worker actors
     * @param cache            vocabulary used to map words back to {@link VocabWord}s
     * @param coOCurreneCounts map that receives the co-occurrence counts
     * @param symmetric        symmetry flag handed to the worker actors
     */
    public CoOccurrences(TokenizerFactory tokenizerFactory, SentenceIterator sentenceIterator, int windowSize, VocabCache cache, CounterMap<String, String> coOCurreneCounts, boolean symmetric) {
        this.tokenizerFactory = tokenizerFactory;
        this.sentenceIterator = sentenceIterator;
        this.windowSize = windowSize;
        this.cache = cache;
        this.coOCurreneCounts = coOCurreneCounts;
        this.symmetric = symmetric;
    }

    /**
     * Dispatches every sentence from the sentence iterator to a round-robin pool of
     * {@link CoOccurrenceActor}s (one per available processor), then polls until the actors
     * report that every queued sentence has been processed. The actor system is always shut
     * down afterwards, even if dispatching fails.
     *
     * <p>If the calling thread is interrupted while waiting, the interrupt flag is restored and
     * the method returns early; the counts may then be incomplete.
     *
     * @throws IllegalStateException if no sentence iterator is configured (for example on an
     *                               instance created by {@link #load(InputStream)})
     */
    public void fit() {
        if (this.sentenceIterator == null) {
            throw new IllegalStateException("No sentence iterator configured; cannot fit co-occurrences");
        }
        if (this.trainingSystem == null) {
            this.trainingSystem = ActorSystem.create();
        }
        boolean completed;
        try {
            // Shared with the actors; presumably each actor increments it once per finished sentence.
            AtomicInteger processed = new AtomicInteger(0);
            ActorRef actor = this.trainingSystem.actorOf(new RoundRobinPool(Runtime.getRuntime().availableProcessors())
                    .props(Props.create(CoOccurrenceActor.class, processed, this.tokenizerFactory, this.windowSize,
                            this.cache, this.coOCurreneCounts, this.symmetric, this.sentenceOccurrences)));
            this.sentenceIterator.reset();
            // Doubles as the sentence id (0-based) and the number of queued sentences.
            int queued = 0;
            while (this.sentenceIterator.hasNext()) {
                actor.tell(new SentenceWork(queued, this.sentenceIterator.nextSentence()), actor);
                queued++;
            }
            completed = awaitProcessed(processed, queued);
        } finally {
            this.trainingSystem.shutdown();
            this.trainingSystem = null;
        }
        if (completed) {
            log.info("Done processing co occurrences: ended with {}", this.numCoOccurrences());
        }
    }

    /**
     * Blocks until {@code processed} reaches {@code queued}: waits 5 s first, then re-checks every 10 s.
     *
     * @return true when all queued work was reported processed, false if interrupted
     */
    private static boolean awaitProcessed(AtomicInteger processed, int queued) {
        try {
            Thread.sleep(5000L);
            while (processed.get() < queued) {
                Thread.sleep(10000L);
            }
            return true;
        } catch (InterruptedException e) {
            // Preserve the interrupt so callers can observe the cancellation request.
            Thread.currentThread().interrupt();
            log.warn("Interrupted while waiting for co-occurrence actors; {} of {} sentences processed", processed.get(), queued);
            return false;
        }
    }

    /**
     * Returns an iterator over the co-occurring word pairs, grouped into batches.
     *
     * @param batchSize maximum number of pairs per batch; must be positive
     * @return batch iterator backed by {@link #coOccurrenceIteratorVocab()}
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public Iterator<List<Pair<VocabWord, VocabWord>>> coOccurrenceIteratorVocabBatch(int batchSize) {
        return new CoOccurrenceBatchIterator(batchSize);
    }

    /**
     * Returns an iterator over the co-occurring pairs, with each word resolved via the vocab cache.
     * Words the cache does not resolve are passed through as whatever {@code wordFor} returns.
     */
    public Iterator<Pair<VocabWord, VocabWord>> coOccurrenceIteratorVocab() {
        return new CoOccurrenceIterator();
    }

    /**
     * Reads co-occurrences from space-separated lines of the form {@code word1 word2 count}.
     * Lines with fewer than three tokens, or with an empty first or second token, are skipped.
     * The stream is not closed.
     *
     * @param from stream to read from
     * @return an instance whose counts and pair list reflect the stream contents
     * @throws NumberFormatException if a count token is not a valid double
     */
    public static CoOccurrences load(InputStream from) {
        CoOccurrences ret = new CoOccurrences();
        ret.coOccurrences = new ArrayList<Pair<String, String>>();
        CounterMap<String, String> counter = new CounterMap<String, String>();
        // NOTE(review): decodes with the platform default charset; confirm this matches the writer.
        Reader reader = new InputStreamReader(from);
        LineIterator iter = IOUtils.lineIterator(reader);
        while (iter.hasNext()) {
            String[] split = iter.nextLine().split(" ");
            if (split.length < 3 || split[0].isEmpty() || split[1].isEmpty()) {
                continue;
            }
            ret.coOccurrences.add(new Pair<String, String>(split[0], split[1]));
            counter.incrementCount(split[0], split[1], Double.parseDouble(split[2]));
        }
        ret.coOCurreneCounts = counter;
        return ret;
    }

    /** Returns the per-sentence counter shared with the worker actors. */
    public Counter<Integer> getSentenceOccurrences() {
        return this.sentenceOccurrences;
    }

    /** Replaces the per-sentence counter handed to the worker actors on the next {@link #fit()}. */
    public void setSentenceOccurrences(Counter<Integer> sentenceOccurrences) {
        this.sentenceOccurrences = sentenceOccurrences;
    }

    /**
     * Returns all co-occurring pairs. For a loaded instance this is the internal list built by
     * {@link #load(InputStream)}; otherwise a fresh list is built from the count map.
     */
    public List<Pair<String, String>> coOccurrenceList() {
        if (this.coOccurrences != null) {
            return this.coOccurrences;
        }
        Iterator<Pair<String, String>> pairIter = this.coOccurrenceIterator();
        ArrayList<Pair<String, String>> pairList = new ArrayList<Pair<String, String>>();
        while (pairIter.hasNext()) {
            pairList.add(pairIter.next());
        }
        return pairList;
    }

    /**
     * Returns the co-occurring pairs in random order. The shuffle is applied to a copy, so the
     * list cached by {@link #load(InputStream)} keeps its original order.
     */
    public List<Pair<String, String>> randomizedList() {
        List<Pair<String, String>> shuffled = new ArrayList<Pair<String, String>>(this.coOccurrenceList());
        Collections.shuffle(shuffled);
        return shuffled;
    }

    /** Returns {@code totalSize()} of the underlying count map. */
    public int numCoOccurrences() {
        return this.coOCurreneCounts.totalSize();
    }

    /** Returns the recorded count for the pair ({@code w1}, {@code w2}). */
    public double count(String w1, String w2) {
        return this.coOCurreneCounts.getCount(w1, w2);
    }

    /** Returns an iterator over the (word1, word2) keys of the count map. */
    public Iterator<Pair<String, String>> coOccurrenceIterator() {
        return this.coOCurreneCounts.getPairIterator();
    }

    /** Returns the underlying (mutable) count map. */
    public CounterMap<String, String> getCoOCurreneCounts() {
        return this.coOCurreneCounts;
    }

    /** Replaces the underlying count map. */
    public void setCoOCurreneCounts(CounterMap<String, String> coOCurreneCounts) {
        this.coOCurreneCounts = coOCurreneCounts;
    }

    /**
     * Fluent builder for {@link CoOccurrences}. A vocab cache and a sentence iterator are
     * required; the tokenizer defaults to {@link DefaultTokenizerFactory}, the window size to 15,
     * symmetric to true, and the count map to {@code Util.parallelCounterMap()}.
     */
    public static class Builder {
        private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        private SentenceIterator sentenceIterator;
        private int windowSize = 15;
        private VocabCache cache;
        private CounterMap<String, String> coOCurreneCounts = Util.parallelCounterMap();
        private boolean symmetric = true;

        public Builder symmetric(boolean symmetric) {
            this.symmetric = symmetric;
            return this;
        }

        public Builder tokenizer(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        public Builder iterate(SentenceIterator sentenceIterator) {
            this.sentenceIterator = sentenceIterator;
            return this;
        }

        public Builder windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        public Builder cache(VocabCache cache) {
            this.cache = cache;
            return this;
        }

        public Builder coOCurreneCounts(CounterMap<String, String> coOCurreneCounts) {
            this.coOCurreneCounts = coOCurreneCounts;
            return this;
        }

        /**
         * @return a new {@link CoOccurrences} with the configured settings
         * @throws IllegalArgumentException if the vocab cache or sentence iterator is missing
         */
        public CoOccurrences build() {
            if (this.cache == null) {
                throw new IllegalArgumentException("Vocab cache must not be null!");
            }
            if (this.sentenceIterator == null) {
                throw new IllegalArgumentException("Sentence iterator must not be null");
            }
            return new CoOccurrences(this.tokenizerFactory, this.sentenceIterator, this.windowSize, this.cache, this.coOCurreneCounts, this.symmetric);
        }
    }

    /** Adapts {@link #coOccurrenceIterator()} by resolving each word through the vocab cache. */
    public class CoOccurrenceIterator
    implements Iterator<Pair<VocabWord, VocabWord>> {
        private Iterator<Pair<String, String>> iter;

        public CoOccurrenceIterator() {
            this.iter = CoOccurrences.this.coOccurrenceIterator();
        }

        @Override
        public boolean hasNext() {
            return this.iter.hasNext();
        }

        @Override
        public Pair<VocabWord, VocabWord> next() {
            Pair<String, String> next = this.iter.next();
            VocabWord first = (VocabWord) CoOccurrences.this.cache.wordFor(next.getFirst());
            VocabWord second = (VocabWord) CoOccurrences.this.cache.wordFor(next.getSecond());
            return new Pair<VocabWord, VocabWord>(first, second);
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }

    /** Groups the pairs of {@link #coOccurrenceIteratorVocab()} into lists of at most {@code batchSize}. */
    public class CoOccurrenceBatchIterator
    implements Iterator<List<Pair<VocabWord, VocabWord>>> {
        private Iterator<Pair<VocabWord, VocabWord>> iter;
        private int batchSize;

        /**
         * @param batchSize maximum number of pairs per batch; must be positive, otherwise
         *                  {@link #next()} would return empty batches forever
         * @throws IllegalArgumentException if {@code batchSize <= 0}
         */
        public CoOccurrenceBatchIterator(int batchSize) {
            if (batchSize <= 0) {
                throw new IllegalArgumentException("Batch size must be positive, got " + batchSize);
            }
            this.iter = CoOccurrences.this.coOccurrenceIteratorVocab();
            this.batchSize = batchSize;
        }

        /** Creates a batch iterator with the default batch size of 100. */
        public CoOccurrenceBatchIterator() {
            this(100);
        }

        @Override
        public boolean hasNext() {
            return this.iter.hasNext();
        }

        @Override
        public List<Pair<VocabWord, VocabWord>> next() {
            ArrayList<Pair<VocabWord, VocabWord>> list = new ArrayList<Pair<VocabWord, VocabWord>>(this.batchSize);
            for (int i = 0; i < this.batchSize && this.iter.hasNext(); ++i) {
                list.add(this.iter.next());
            }
            return list;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }
}

