/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.word2vec.iterator;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.DataSetPreProcessor;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.inputsanitation.InputHomogenization;
import org.deeplearning4j.text.movingwindow.Window;
import org.deeplearning4j.text.movingwindow.WindowConverter;
import org.deeplearning4j.text.movingwindow.Windows;
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/**
 * Iterates over labelled sentences and turns them into data sets of
 * moving-window examples.
 *
 * <p>Each non-empty sentence pulled from the wrapped
 * {@link LabelAwareSentenceIterator} is split into {@link Window}s via
 * {@link Windows#windows}; every window is tagged with the iterator's current
 * label and buffered. A batch is built by draining windows from that buffer:
 * one feature row per window (from {@link WindowConverter#asExampleMatrix}) and
 * one outcome row per window (from {@link FeatureUtil#toOutcomeVector}).
 */
public class Word2VecDataSetIterator
implements DataSetIterator {
    private final Word2Vec vec;
    private final LabelAwareSentenceIterator iter;
    /** Windows produced from sentences but not yet handed out in a batch. */
    private final List<Window> cachedWindow;
    /** Known labels; a label's position in this list is its outcome index. */
    private final List<String> labels;
    private boolean homogenization = true;
    private boolean addLabels = true;
    private int batch = 10;
    private org.nd4j.linalg.dataset.DataSet curr;
    private DataSetPreProcessor preProcessor;

    /**
     * Creates the iterator and installs a sentence pre-processor on
     * {@code iter} according to the two flags.
     *
     * <ul>
     *   <li>both flags: the homogenized sentence wrapped as
     *       {@code <label> sentence </label>} (tags separated by spaces);</li>
     *   <li>{@code addLabels} only: the raw sentence wrapped as
     *       {@code <label>sentence</label>} (no separating spaces);</li>
     *   <li>{@code homogenization} only: the homogenized sentence;</li>
     *   <li>neither: the pre-processor of {@code iter} is left untouched.</li>
     * </ul>
     *
     * @param vec            model supplying tokenizer factory, window size and layer size
     * @param iter           source of sentences and their labels
     * @param labels         all possible labels
     * @param batch          default batch size used by {@link #next()}
     * @param homogenization whether sentences are run through {@link InputHomogenization}
     * @param addLabels      whether sentences are wrapped in label tags
     */
    public Word2VecDataSetIterator(Word2Vec vec, LabelAwareSentenceIterator iter, List<String> labels, int batch, boolean homogenization, boolean addLabels) {
        this.vec = vec;
        this.iter = iter;
        this.labels = labels;
        this.batch = batch;
        this.cachedWindow = new CopyOnWriteArrayList<Window>();
        this.addLabels = addLabels;
        this.homogenization = homogenization;
        if (addLabels && homogenization) {
            iter.setPreProcessor(new SentencePreProcessor(){

                @Override
                public String preProcess(String sentence) {
                    String label = Word2VecDataSetIterator.this.iter.currentLabel();
                    return "<" + label + "> " + new InputHomogenization(sentence).transform() + " </" + label + ">";
                }
            });
        } else if (addLabels) {
            iter.setPreProcessor(new SentencePreProcessor(){

                @Override
                public String preProcess(String sentence) {
                    String label = Word2VecDataSetIterator.this.iter.currentLabel();
                    return "<" + label + ">" + sentence + "</" + label + ">";
                }
            });
        } else if (homogenization) {
            iter.setPreProcessor(new SentencePreProcessor(){

                @Override
                public String preProcess(String sentence) {
                    return new InputHomogenization(sentence).transform();
                }
            });
        }
    }

    /**
     * Creates the iterator with a batch size of 10, homogenization enabled and
     * label tags enabled.
     */
    public Word2VecDataSetIterator(Word2Vec vec, LabelAwareSentenceIterator iter, List<String> labels) {
        this(vec, iter, labels, 10);
    }

    /**
     * Creates the iterator with the given batch size, homogenization enabled
     * and label tags enabled.
     */
    public Word2VecDataSetIterator(Word2Vec vec, LabelAwareSentenceIterator iter, List<String> labels, int batch) {
        this(vec, iter, labels, batch, true, true);
    }

    /**
     * Returns a data set of up to {@code num} window examples.
     *
     * <p>The buffer is topped up from the sentence iterator until it holds
     * {@code num} windows or the iterator is exhausted; whatever is available
     * (at most {@code num}) is then returned.
     *
     * @param num maximum number of examples wanted
     * @return the batch, or {@code null} when no window is available
     * @throws IllegalStateException if a non-empty sentence yields no windows
     */
    public org.nd4j.linalg.dataset.DataSet next(int num) {
        this.fillCache(num);
        return this.fromCached(num);
    }

    /**
     * Pulls sentences from the iterator until the buffer holds at least
     * {@code num} windows or no sentences remain. Empty sentences are skipped.
     */
    private void fillCache(int num) {
        while (this.cachedWindow.size() < num && this.iter.hasNext()) {
            String sentence = this.iter.nextSentence();
            if (sentence.isEmpty()) {
                continue;
            }
            List<Window> windows = Windows.windows(sentence, this.vec.getTokenizerFactory(), this.vec.getWindow());
            if (windows.isEmpty()) {
                throw new IllegalStateException("Empty window on sentence");
            }
            String label = this.iter.currentLabel();
            for (Window w : windows) {
                w.setLabel(label);
            }
            this.cachedWindow.addAll(windows);
        }
    }

    /**
     * Drains up to {@code num} windows from the front of the buffer and
     * converts them into a data set.
     *
     * @return the batch, or {@code null} when the buffer yielded no windows
     */
    private org.nd4j.linalg.dataset.DataSet fromCached(int num) {
        List<Window> windows = new ArrayList<Window>(num);
        for (int i = 0; i < num && !this.cachedWindow.isEmpty(); ++i) {
            windows.add(this.cachedWindow.remove(0));
        }
        if (windows.isEmpty()) {
            return null;
        }
        // Size the matrices by what was actually drained, not by the request:
        // the last batch may be short when the sentence iterator runs out, and
        // sizing by num would index past the end of the drained list.
        int rows = windows.size();
        int outcomes = this.labels.size();
        INDArray inputs = Nd4j.create(rows, this.inputColumns());
        INDArray labelOutput = Nd4j.create(rows, outcomes);
        for (int i = 0; i < rows; ++i) {
            Window window = windows.get(i);
            inputs.putRow(i, WindowConverter.asExampleMatrix(window, this.vec));
            // indexOf yields -1 for a label missing from the label list; that
            // value is passed on to FeatureUtil unchanged.
            labelOutput.putRow(i, FeatureUtil.toOutcomeVector(this.labels.indexOf(window.getLabel()), outcomes));
        }
        org.nd4j.linalg.dataset.DataSet ret = new org.nd4j.linalg.dataset.DataSet(inputs, labelOutput);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess((DataSet)ret);
        }
        return ret;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public int totalExamples() {
        throw new UnsupportedOperationException();
    }

    /** Returns the feature width: the model's layer size times its window size. */
    public int inputColumns() {
        return this.vec.getLayerSize() * this.vec.getWindow();
    }

    /** Returns the number of labels. */
    public int totalOutcomes() {
        return this.labels.size();
    }

    /** Resets the sentence iterator and discards all buffered windows. */
    public void reset() {
        this.iter.reset();
        this.cachedWindow.clear();
    }

    /** Returns the default batch size used by {@link #next()}. */
    public int batch() {
        return this.batch;
    }

    /** Always returns 0; no position is tracked. */
    public int cursor() {
        return 0;
    }

    /** Always returns 0; no example count is tracked. */
    public int numExamples() {
        return 0;
    }

    /** Sets the pre-processor applied to every data set before it is returned. */
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /** Returns true while sentences remain or buffered windows are left. */
    public boolean hasNext() {
        return this.iter.hasNext() || !this.cachedWindow.isEmpty();
    }

    /** Returns the next batch using the default batch size; see {@link #next(int)}. */
    public org.nd4j.linalg.dataset.DataSet next() {
        return this.next(this.batch);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void remove() {
        throw new UnsupportedOperationException();
    }
}

