/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.glove;

import java.io.File;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.deeplearning4j.models.glove.count.ASCIICoOccurrenceWriter;
import org.deeplearning4j.models.glove.count.BinaryCoOccurrenceReader;
import org.deeplearning4j.models.glove.count.BinaryCoOccurrenceWriter;
import org.deeplearning4j.models.glove.count.CoOccurrenceWeight;
import org.deeplearning4j.models.glove.count.CoOccurrenceWriter;
import org.deeplearning4j.models.glove.count.CountMap;
import org.deeplearning4j.models.glove.count.RoundCount;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.iterators.FilteredSequenceIterator;
import org.deeplearning4j.models.sequencevectors.iterators.SynchronizedSequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator;
import org.deeplearning4j.util.DL4JFileUtils;
import org.deeplearning4j.util.ThreadUtils;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AbstractCoOccurrences<T extends SequenceElement>
implements Serializable {
    protected boolean symmetric;
    protected int windowSize;
    protected VocabCache<T> vocabCache;
    protected SequenceIterator<T> sequenceIterator;
    protected int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
    protected File targetFile;
    protected ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
    protected long memory_threshold = 0L;
    private ShadowCopyThread shadowThread;
    private volatile CountMap<T> coOccurrenceCounts = new CountMap();
    private AtomicLong processedSequences = new AtomicLong(0L);
    protected static final Logger logger = LoggerFactory.getLogger(AbstractCoOccurrences.class);

    private AbstractCoOccurrences() {
    }

    public double getCoOccurrenceCount(@NonNull T element1, @NonNull T element2) {
        if (element1 == null) {
            throw new NullPointerException("element1 is marked @NonNull but is null");
        }
        if (element2 == null) {
            throw new NullPointerException("element2 is marked @NonNull but is null");
        }
        return this.coOccurrenceCounts.getCount(element1, element2);
    }

    protected long getMemoryFootprint() {
        try {
            this.lock.readLock().lock();
            long l = (long)this.coOccurrenceCounts.size() * 24L * 5L;
            return l;
        }
        finally {
            this.lock.readLock().unlock();
        }
    }

    protected long getMemoryThreshold() {
        return this.memory_threshold / 2L;
    }

    public void fit() {
        int x;
        this.shadowThread = new ShadowCopyThread();
        this.shadowThread.start();
        this.sequenceIterator.reset();
        ArrayList<CoOccurrencesCalculatorThread> threads = new ArrayList<CoOccurrencesCalculatorThread>();
        for (x = 0; x < this.workers; ++x) {
            threads.add(x, new CoOccurrencesCalculatorThread(x, new FilteredSequenceIterator<T>(new SynchronizedSequenceIterator<T>(this.sequenceIterator), this.vocabCache), this.processedSequences));
            ((CoOccurrencesCalculatorThread)threads.get(x)).start();
        }
        for (x = 0; x < this.workers; ++x) {
            try {
                ((CoOccurrencesCalculatorThread)threads.get(x)).join();
                continue;
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        this.shadowThread.finish();
        logger.info("CoOccurrences map was built.");
    }

    public Iterator<Pair<Pair<T, T>, Double>> iterator() {
        SynchronizedSentenceIterator iterator;
        try {
            iterator = new SynchronizedSentenceIterator(new PrefetchingSentenceIterator.Builder(new BasicLineIterator(this.targetFile)).setFetchSize(500000).build());
        }
        catch (Exception e) {
            logger.error("Target file was not found on last stage!");
            throw new RuntimeException(e);
        }
        return new Iterator<Pair<Pair<T, T>, Double>>(){

            @Override
            public boolean hasNext() {
                return iterator.hasNext();
            }

            @Override
            public Pair<Pair<T, T>, Double> next() {
                String line = iterator.nextSentence();
                String[] strings = line.split(" ");
                Object element1 = AbstractCoOccurrences.this.vocabCache.elementAtIndex(Integer.valueOf(strings[0]));
                Object element2 = AbstractCoOccurrences.this.vocabCache.elementAtIndex(Integer.valueOf(strings[1]));
                Double weight = Double.valueOf(strings[2]);
                return new Pair((Object)new Pair(element1, element2), (Object)weight);
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("remove() method can't be supported on read-only interface");
            }
        };
    }

    private class ShadowCopyThread
    extends Thread
    implements Runnable {
        private AtomicBoolean isFinished = new AtomicBoolean(false);
        private AtomicBoolean isTerminate = new AtomicBoolean(false);
        private AtomicBoolean isInvoked = new AtomicBoolean(false);
        private AtomicBoolean shouldInvoke = new AtomicBoolean(false);
        private File[] tempFiles;
        private RoundCount counter;

        public ShadowCopyThread() {
            try {
                this.counter = new RoundCount(1);
                this.tempFiles = new File[2];
                this.tempFiles[0] = DL4JFileUtils.createTempFile((String)"aco", (String)"tmp");
                this.tempFiles[1] = DL4JFileUtils.createTempFile((String)"aco", (String)"tmp");
                this.tempFiles[0].deleteOnExit();
                this.tempFiles[1].deleteOnExit();
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
            this.setName("ACO ShadowCopy thread");
        }

        @Override
        public void run() {
            while (!this.isFinished.get() && !this.isTerminate.get()) {
                if (AbstractCoOccurrences.this.getMemoryFootprint() > AbstractCoOccurrences.this.getMemoryThreshold() || this.shouldInvoke.get() && !this.isInvoked.get()) {
                    this.shouldInvoke.compareAndSet(true, false);
                    this.invokeBlocking();
                    continue;
                }
                ThreadUtils.uncheckedSleep((long)1000L);
            }
        }

        public void invoke() {
            this.shouldInvoke.compareAndSet(false, true);
        }

        public synchronized void invokeBlocking() {
            CountMap localMap;
            if (AbstractCoOccurrences.this.getMemoryFootprint() < AbstractCoOccurrences.this.getMemoryThreshold() && !this.isFinished.get()) {
                return;
            }
            int numberOfLinesSaved = 0;
            this.isInvoked.set(true);
            logger.debug("Memory purge started.");
            this.counter.tick();
            try {
                AbstractCoOccurrences.this.lock.writeLock().lock();
                localMap = AbstractCoOccurrences.this.coOccurrenceCounts;
                AbstractCoOccurrences.this.coOccurrenceCounts = new CountMap();
            }
            finally {
                AbstractCoOccurrences.this.lock.writeLock().unlock();
            }
            try {
                CoOccurrenceWriter writer;
                File file = null;
                file = !this.isFinished.get() ? this.tempFiles[this.counter.previous()] : AbstractCoOccurrences.this.targetFile;
                int linesRead = 0;
                logger.debug("Saving to: [" + this.counter.get() + "], Reading from: [" + this.counter.previous() + "]");
                BinaryCoOccurrenceReader reader = new BinaryCoOccurrenceReader(this.tempFiles[this.counter.previous()], AbstractCoOccurrences.this.vocabCache, localMap);
                CoOccurrenceWriter coOccurrenceWriter = writer = this.isFinished.get() ? new ASCIICoOccurrenceWriter(AbstractCoOccurrences.this.targetFile) : new BinaryCoOccurrenceWriter(this.tempFiles[this.counter.get()]);
                while (reader.hasMoreObjects()) {
                    CoOccurrenceWeight line = reader.nextObject();
                    if (line == null) continue;
                    writer.writeObject(line);
                    ++numberOfLinesSaved;
                    ++linesRead;
                }
                reader.finish();
                logger.debug("Lines read: [" + linesRead + "]");
                Iterator iterator = localMap.getPairIterator();
                while (iterator.hasNext()) {
                    Pair pair = iterator.next();
                    double mWeight = localMap.getCount(pair);
                    CoOccurrenceWeight<SequenceElement> object = new CoOccurrenceWeight<SequenceElement>();
                    object.setElement1((SequenceElement)pair.getFirst());
                    object.setElement2((SequenceElement)pair.getSecond());
                    object.setWeight(mWeight);
                    writer.writeObject(object);
                    ++numberOfLinesSaved;
                }
                writer.finish();
                localMap = null;
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
            logger.info("Number of word pairs saved so far: [" + numberOfLinesSaved + "]");
            this.isInvoked.set(false);
        }

        public void finish() {
            if (this.isFinished.get()) {
                return;
            }
            this.isFinished.set(true);
            this.invokeBlocking();
        }

        public void terminate() {
            this.isTerminate.set(true);
        }
    }

    private class CoOccurrencesCalculatorThread
    extends Thread
    implements Runnable {
        private final SequenceIterator<T> iterator;
        private final AtomicLong sequenceCounter;
        private int threadId;

        public CoOccurrencesCalculatorThread(@NonNull int threadId, @NonNull SequenceIterator<T> iterator, AtomicLong sequenceCounter) {
            if (iterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            if (sequenceCounter == null) {
                throw new NullPointerException("sequenceCounter is marked @NonNull but is null");
            }
            this.iterator = iterator;
            this.sequenceCounter = sequenceCounter;
            this.threadId = threadId;
            this.setName("CoOccurrencesCalculatorThread " + threadId);
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        @Override
        public void run() {
            while (this.iterator.hasMoreSequences()) {
                Sequence sequence = this.iterator.nextSequence();
                ArrayList<String> tokens = new ArrayList<String>(sequence.asLabels());
                for (int x = 0; x < sequence.getElements().size(); ++x) {
                    int wordIdx = AbstractCoOccurrences.this.vocabCache.indexOf((String)tokens.get(x));
                    if (wordIdx < 0) continue;
                    String w1 = ((SequenceElement)AbstractCoOccurrences.this.vocabCache.wordFor((String)tokens.get(x))).getLabel();
                    int windowStop = Math.min(x + AbstractCoOccurrences.this.windowSize + 1, tokens.size());
                    for (int j = x; j < windowStop; ++j) {
                        String w2;
                        int otherWord = AbstractCoOccurrences.this.vocabCache.indexOf((String)tokens.get(j));
                        if (otherWord < 0 || (w2 = ((SequenceElement)AbstractCoOccurrences.this.vocabCache.wordFor((String)tokens.get(j))).getLabel()).equals("UNK") || otherWord == wordIdx) continue;
                        Object tokenX = AbstractCoOccurrences.this.vocabCache.wordFor((String)tokens.get(x));
                        Object tokenJ = AbstractCoOccurrences.this.vocabCache.wordFor((String)tokens.get(j));
                        double nWeight = 1.0 / ((double)(j - x) + Nd4j.EPS_THRESHOLD);
                        while (AbstractCoOccurrences.this.getMemoryFootprint() >= AbstractCoOccurrences.this.getMemoryThreshold()) {
                            AbstractCoOccurrences.this.shadowThread.invoke();
                            if (this.threadId == 0) {
                                logger.debug("Memory consuimption > threshold: {footrpint: [" + AbstractCoOccurrences.this.getMemoryFootprint() + "], threshold: [" + AbstractCoOccurrences.this.getMemoryThreshold() + "] }");
                            }
                            ThreadUtils.uncheckedSleep((long)10000L);
                        }
                        try {
                            AbstractCoOccurrences.this.lock.readLock().lock();
                            if (wordIdx < otherWord) {
                                AbstractCoOccurrences.this.coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight);
                                if (!AbstractCoOccurrences.this.symmetric) continue;
                                AbstractCoOccurrences.this.coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight);
                                continue;
                            }
                            AbstractCoOccurrences.this.coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight);
                            if (!AbstractCoOccurrences.this.symmetric) continue;
                            AbstractCoOccurrences.this.coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight);
                            continue;
                        }
                        finally {
                            AbstractCoOccurrences.this.lock.readLock().unlock();
                        }
                    }
                }
                this.sequenceCounter.incrementAndGet();
            }
        }
    }

    public static class Builder<T extends SequenceElement> {
        protected boolean symmetric;
        protected int windowSize = 5;
        protected VocabCache<T> vocabCache;
        protected SequenceIterator<T> sequenceIterator;
        protected int workers = Runtime.getRuntime().availableProcessors();
        protected File target;
        protected long maxmemory = Runtime.getRuntime().maxMemory();

        public Builder<T> symmetric(boolean reallySymmetric) {
            this.symmetric = reallySymmetric;
            return this;
        }

        public Builder<T> windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        public Builder<T> vocabCache(@NonNull VocabCache<T> cache) {
            if (cache == null) {
                throw new NullPointerException("cache is marked @NonNull but is null");
            }
            this.vocabCache = cache;
            return this;
        }

        public Builder<T> iterate(@NonNull SequenceIterator<T> iterator) {
            if (iterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            this.sequenceIterator = new SynchronizedSequenceIterator<T>(iterator);
            return this;
        }

        public Builder<T> workers(int numWorkers) {
            this.workers = numWorkers;
            return this;
        }

        public Builder<T> maxMemory(int gbytes) {
            if (gbytes > 0) {
                this.maxmemory = (long)(Math.max(gbytes - 1, 1) * 1024 * 1024) * 1024L;
            }
            return this;
        }

        public Builder<T> targetFile(@NonNull String path) {
            if (path == null) {
                throw new NullPointerException("path is marked @NonNull but is null");
            }
            this.targetFile(new File(path));
            return this;
        }

        public Builder<T> targetFile(@NonNull File file) {
            if (file == null) {
                throw new NullPointerException("file is marked @NonNull but is null");
            }
            this.target = file;
            return this;
        }

        public AbstractCoOccurrences<T> build() {
            AbstractCoOccurrences ret = new AbstractCoOccurrences();
            ret.sequenceIterator = this.sequenceIterator;
            ret.windowSize = this.windowSize;
            ret.vocabCache = this.vocabCache;
            ret.symmetric = this.symmetric;
            ret.workers = this.workers;
            if (this.maxmemory < 1L) {
                this.maxmemory = Runtime.getRuntime().maxMemory();
            }
            ret.memory_threshold = this.maxmemory;
            logger.info("Actual memory limit: [" + this.maxmemory + "]");
            try {
                if (this.target == null) {
                    this.target = DL4JFileUtils.createTempFile((String)"cooccurrence", (String)"map");
                }
                this.target.deleteOnExit();
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
            ret.targetFile = this.target;
            return ret;
        }
    }
}

