/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.sequencevectors.transformers.impl.iterables;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.BasicTransformerIterator;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.text.documentiterator.AsyncLabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ParallelTransformerIterator
extends BasicTransformerIterator {
    private static final Logger log = LoggerFactory.getLogger(ParallelTransformerIterator.class);
    protected static final int capacity = 1024;
    protected BlockingQueue<Future<Sequence<VocabWord>>> buffer = new LinkedBlockingQueue<Future<Sequence<VocabWord>>>(1024);
    protected AtomicBoolean underlyingHas = new AtomicBoolean(true);
    protected AtomicInteger processing = new AtomicInteger(0);
    private ExecutorService executorService;
    protected static final AtomicInteger count = new AtomicInteger(0);
    private static final int PREFETCH_SIZE = 100;

    public ParallelTransformerIterator(@NonNull LabelAwareIterator iterator, @NonNull SentenceTransformer transformer) {
        this(iterator, transformer, true);
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        if (transformer == null) {
            throw new NullPointerException("transformer is marked non-null but is null");
        }
    }

    private void prefetchIterator() {
    }

    public ParallelTransformerIterator(@NonNull LabelAwareIterator iterator, @NonNull SentenceTransformer transformer, boolean allowMultithreading) {
        super(new AsyncLabelAwareIterator(iterator, 512), transformer);
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        if (transformer == null) {
            throw new NullPointerException("transformer is marked non-null but is null");
        }
        this.allowMultithreading = allowMultithreading;
        this.executorService = Executors.newFixedThreadPool(allowMultithreading ? Math.max(Runtime.getRuntime().availableProcessors(), 2) : 1);
        this.prefetchIterator();
    }

    @Override
    public void reset() {
        this.executorService.shutdownNow();
        this.iterator.reset();
        this.underlyingHas.set(true);
        this.prefetchIterator();
        this.buffer.clear();
        this.executorService = Executors.newFixedThreadPool(this.allowMultithreading ? Math.max(Runtime.getRuntime().availableProcessors(), 2) : 1);
    }

    public void shutdown() {
        this.executorService.shutdown();
    }

    @Override
    public boolean hasNext() {
        if (this.buffer.size() < 1024 && this.iterator.hasNextDocument()) {
            CallableTransformer transformer = new CallableTransformer(this.iterator.nextDocument(), this.sentenceTransformer);
            Future<Sequence<VocabWord>> futureSequence = this.executorService.submit(transformer);
            try {
                this.buffer.put(futureSequence);
            }
            catch (InterruptedException e) {
                log.error("", (Throwable)e);
            }
        }
        return !this.buffer.isEmpty() || this.processing.get() > 0;
    }

    @Override
    public Sequence<VocabWord> next() {
        try {
            this.processing.incrementAndGet();
            Future<Sequence<VocabWord>> future = this.buffer.take();
            Sequence<VocabWord> sequence = future.get();
            this.processing.decrementAndGet();
            return sequence;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private static class CallableTransformer
    implements Callable<Sequence<VocabWord>> {
        private LabelledDocument document;
        private SentenceTransformer transformer;

        public CallableTransformer(LabelledDocument document, SentenceTransformer transformer) {
            this.transformer = transformer;
            this.document = document;
        }

        @Override
        public Sequence<VocabWord> call() {
            Sequence<VocabWord> sequence = new Sequence<VocabWord>();
            if (this.document != null && this.document.getContent() != null) {
                sequence = this.transformer.transformToSequence(this.document.getContent());
                if (this.document.getLabels() != null) {
                    for (String label : this.document.getLabels()) {
                        if (label == null || label.isEmpty()) continue;
                        sequence.addSequenceLabel(new VocabWord(1.0, label));
                    }
                }
            }
            return sequence;
        }
    }
}

