/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.word2vec.iterator;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.datasets.iterator.DataSetFetcher;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.movingwindow.Window;
import org.deeplearning4j.text.movingwindow.WindowConverter;
import org.deeplearning4j.text.movingwindow.Windows;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Data set fetcher that walks the files under a directory, converts lines of
 * text into moving {@link Window}s and turns those windows into
 * {@link DataSet}s: one input row per window (built with
 * {@link WindowConverter#asExample}) and one outcome row per window (built
 * with {@link FeatureUtil#toOutcomeVector} from the window's label).
 *
 * <p>Typical use: call {@link #reset()} to (re)start the file walk, call
 * {@link #fetch(int)} to set the batch size, then call {@link #next()} while
 * {@link #hasMore()} is true. Windows that do not fit into a batch are kept
 * in an internal cache and served by later calls to {@link #next()}.
 *
 * <p>This class keeps mutable state (file iterator, cache, batch size) without
 * any synchronization.
 */
public class Word2VecDataFetcher
implements DataSetFetcher {
    private static final long serialVersionUID = 3245955804749769475L;
    private static final Pattern BEGIN = Pattern.compile("<[A-Z]+>");
    private static final Pattern END = Pattern.compile("</[A-Z]+>");
    private static final Logger log = LoggerFactory.getLogger(Word2VecDataFetcher.class);

    /** Iterator over the files under {@link #path}; transient, so it is rebuilt lazily when null. */
    private transient Iterator<File> files;
    private final Word2Vec vec;
    private final List<String> labels;
    /** Maximum number of examples per data set, as set by {@link #fetch(int)}. */
    private int batch;
    /** Windows read from a file that did not fit into the batch returned at that time. */
    private final List<Window> cache = new ArrayList<Window>();
    /** Never assigned in this class, so {@link #totalExamples()} always reports 0. */
    private int totalExamples;
    private final String path;

    /**
     * Creates a fetcher. No file access happens here; the directory is first
     * read by {@link #reset()}, {@link #next()} or {@link #hasMore()}.
     *
     * @param path   root directory whose files are iterated recursively
     * @param vec    word2vec model used to vectorize windows; must not be null
     * @param labels possible window labels; must not be null or empty
     * @throws IllegalArgumentException if {@code vec} is null or {@code labels} is null or empty
     */
    public Word2VecDataFetcher(String path, Word2Vec vec, List<String> labels) {
        if (vec == null || labels == null || labels.isEmpty()) {
            throw new IllegalArgumentException("Unable to initialize due to missing argument or empty label set");
        }
        this.vec = vec;
        this.labels = labels;
        this.path = path;
    }

    /**
     * Returns the file iterator, creating it on first use. The field is
     * transient and is otherwise only assigned by {@link #reset()}, so without
     * this guard {@link #next()} and {@link #hasMore()} would fail with a
     * NullPointerException before the first reset or after deserialization.
     */
    private Iterator<File> fileIterator() {
        if (this.files == null) {
            this.files = FileUtils.iterateFiles(new File(this.path), null, true);
        }
        return this.files;
    }

    /**
     * Builds a data set from the first {@code count} windows of {@code windows}.
     * Input and outcome matrices always get the same number of rows.
     *
     * @param windows source windows; must contain at least {@code count} elements
     * @param count   number of examples (rows) to produce
     * @return data set with {@code count} rows
     */
    private DataSet toDataSet(List<Window> windows, int count) {
        int labelCount = this.labels.size();
        int columns = this.vec.getLayerSize() * this.vec.getWindow();
        INDArray input = Nd4j.create(count, columns);
        INDArray outcomes = Nd4j.create(count, labelCount);
        for (int row = 0; row < count; ++row) {
            Window window = windows.get(row);
            input.putRow(row, Nd4j.create(WindowConverter.asExample(window, this.vec)));
            // A label that is not in the configured list falls back to index 0.
            int labelIndex = Math.max(this.labels.indexOf(window.getLabel()), 0);
            outcomes.putRow(row, FeatureUtil.toOutcomeVector(labelIndex, labelCount));
        }
        return new DataSet(input, outcomes);
    }

    /**
     * Serves up to {@code batch} windows from the cache and removes them, so
     * the cache drains and {@link #hasMore()} can eventually become false.
     *
     * @return the data set, or {@code null} if the cache is empty or the batch size is not positive
     */
    private DataSet fromCache() {
        int count = Math.min(this.batch, this.cache.size());
        if (count <= 0) {
            return null;
        }
        List<Window> head = this.cache.subList(0, count);
        DataSet result = this.toDataSet(head, count);
        // Clearing the sub-list view removes the consumed windows from the cache itself.
        head.clear();
        return result;
    }

    /**
     * Returns the next data set of at most {@code batch} examples.
     *
     * <p>If the cache already holds a full batch, or no files are left, the
     * data set comes from the cache. Otherwise the next file is opened and the
     * first line that yields windows is converted; windows beyond the batch
     * size are appended to the cache.
     *
     * @return the next data set, or {@code null} when nothing could be
     *         produced (empty cache with no files left, non-positive batch
     *         size, or a file in which no line yields windows)
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
     */
    public DataSet next() {
        Iterator<File> pending = this.fileIterator();
        if (this.cache.size() >= this.batch || !pending.hasNext()) {
            return this.fromCache();
        }
        File f = pending.next();
        LineIterator lines = null;
        try {
            lines = FileUtils.lineIterator(f);
            // NOTE(review): only the first line that yields windows is consumed; later
            // lines of the same file are never read. Kept as in the original logic.
            while (lines.hasNext()) {
                List<Window> windows = Windows.windows(lines.nextLine());
                if (windows.isEmpty()) {
                    continue;
                }
                // batch > 0 here: the guard above returned unless cache.size() < batch.
                int count = Math.min(this.batch, windows.size());
                DataSet result = this.toDataSet(windows, count);
                if (windows.size() > count) {
                    this.cache.addAll(windows.subList(count, windows.size()));
                }
                return result;
            }
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        finally {
            // Null-safe; releases the underlying reader on every exit path.
            LineIterator.closeQuietly(lines);
        }
        return null;
    }

    /**
     * @return the total example count; this class never updates it, so it is always 0
     */
    public int totalExamples() {
        return this.totalExamples;
    }

    /**
     * @return number of input columns: the model's layer size times its window size
     */
    public int inputColumns() {
        return this.vec.getLayerSize() * this.vec.getWindow();
    }

    /**
     * @return number of possible outcomes, i.e. the number of labels
     */
    public int totalOutcomes() {
        return this.labels.size();
    }

    /**
     * Restarts the recursive file walk from the root path and discards any cached windows.
     */
    public void reset() {
        this.files = FileUtils.iterateFiles(new File(this.path), null, true);
        this.cache.clear();
    }

    /**
     * @return always 0; this fetcher does not track a cursor position
     */
    public int cursor() {
        return 0;
    }

    /**
     * @return true if unread files remain or cached windows are still waiting to be served
     */
    public boolean hasMore() {
        return this.fileIterator().hasNext() || !this.cache.isEmpty();
    }

    /**
     * Sets the batch size used by subsequent calls to {@link #next()}. Nothing is read here.
     *
     * @param numExamples maximum number of examples per data set
     */
    public void fetch(int numExamples) {
        this.batch = numExamples;
    }

    /**
     * @return the current file iterator, or {@code null} if it has not been created yet
     */
    public Iterator<File> getFiles() {
        return this.files;
    }

    /**
     * @return the word2vec model passed to the constructor
     */
    public Word2Vec getVec() {
        return this.vec;
    }

    /**
     * @return pattern matching an opening tag made of upper-case letters, e.g. {@code <LABEL>}
     */
    public static Pattern getBegin() {
        return BEGIN;
    }

    /**
     * @return pattern matching a closing tag made of upper-case letters, e.g. {@code </LABEL>}
     */
    public static Pattern getEnd() {
        return END;
    }

    /**
     * @return the label list passed to the constructor (not a copy)
     */
    public List<String> getLabels() {
        return this.labels;
    }

    /**
     * @return the batch size last set through {@link #fetch(int)}
     */
    public int getBatch() {
        return this.batch;
    }

    /**
     * @return the live internal cache of pending windows (not a copy)
     */
    public List<Window> getCache() {
        return this.cache;
    }
}

