/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.datavec;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.io.converters.SelfWritableConverter;
import org.datavec.api.io.converters.WritableConverterException;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.datavec.common.data.NDArrayWritable;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/**
 * {@link DataSetIterator} that builds {@link DataSet} minibatches from the records produced by a
 * {@link RecordReader}.
 *
 * <p>Each record (a {@code List<Writable>}) is converted into one example by
 * {@code getDataSet(List)}, and up to {@code batchSize} examples are merged into one minibatch with
 * {@link DataSet#merge}. When the reader is a {@link SequenceRecordReader}, every time step of every
 * sequence is treated as an independent example.
 *
 * <p>Label handling:
 * <ul>
 *   <li>classification (default): the column at {@code labelIndex} holds an integer class index in
 *       {@code [0, numPossibleLabels)} and is converted with {@link FeatureUtil#toOutcomeVector};</li>
 *   <li>regression: columns {@code labelIndex..labelIndexTo} (inclusive) are copied as real
 *       values;</li>
 *   <li>{@code labelIndex < 0} and {@code numPossibleLabels < 1}: there is no label column and the
 *       feature vector is also used as the label.</li>
 * </ul>
 * If {@code labelIndex < 0} but {@code numPossibleLabels >= 1}, the last column of each record is
 * used as the label column.
 */
public class RecordReaderDataSetIterator
implements DataSetIterator {
    protected RecordReader recordReader;
    protected WritableConverter converter;
    protected int batchSize = 10;
    /** Maximum number of batches per pass; negative means unlimited. */
    protected int maxNumBatches = -1;
    /** Number of batches produced since construction or the last {@link #reset()}. */
    protected int batchNum = 0;
    protected int labelIndex = -1;
    protected int labelIndexTo = -1;
    protected int numPossibleLabels = -1;
    /** Remaining time steps of the sequence currently being consumed (sequence readers only). */
    protected Iterator<List<Writable>> sequenceIter;
    /** Most recently built batch; also the pre-fetched batch when {@link #useCurrent} is set. */
    protected DataSet last;
    /** True when {@link #last} was pre-fetched and must be returned by the next call to next(). */
    protected boolean useCurrent = false;
    protected boolean regression = false;
    protected DataSetPreProcessor preProcessor;
    private boolean collectMetaData = false;
    /** Whether {@link #preProcessor} has already been applied to {@link #last}. */
    private boolean lastPreProcessed = false;

    /**
     * Creates an iterator with a custom converter. The label index is left unset (-1). The number
     * of labels is taken from {@code recordReader.getLabels()}, or -1 if the reader reports none.
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize) {
        this(recordReader, converter, batchSize, -1, recordReader.getLabels() == null ? -1 : recordReader.getLabels().size());
    }

    /**
     * Creates an iterator using a {@link SelfWritableConverter}. The label index is left unset
     * (-1). The number of labels is taken from {@code recordReader.getLabels()}, or -1 if absent.
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize) {
        this(recordReader, (WritableConverter) new SelfWritableConverter(), batchSize, -1, recordReader.getLabels() == null ? -1 : recordReader.getLabels().size());
    }

    /** Creates a classification iterator using a {@link SelfWritableConverter}. */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndex, int numPossibleLabels) {
        this(recordReader, (WritableConverter) new SelfWritableConverter(), batchSize, labelIndex, numPossibleLabels);
    }

    /** Creates an iterator with a single label column, optionally for regression. */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize, int labelIndex, int numPossibleLabels, boolean regression) {
        this(recordReader, converter, batchSize, labelIndex, numPossibleLabels, -1, regression);
    }

    /** Creates a classification iterator with a single label column and no batch limit. */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize, int labelIndex, int numPossibleLabels) {
        this(recordReader, converter, batchSize, labelIndex, numPossibleLabels, -1, false);
    }

    /** Creates a classification iterator that yields at most {@code maxNumBatches} batches per pass. */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndex, int numPossibleLabels, int maxNumBatches) {
        this(recordReader, (WritableConverter) new SelfWritableConverter(), batchSize, labelIndex, numPossibleLabels, maxNumBatches, false);
    }

    /**
     * Creates an iterator whose labels are the columns {@code labelIndexFrom..labelIndexTo}
     * (inclusive). The number of possible labels is -1. Multi-column labels are only read when
     * {@code regression} is true.
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndexFrom, int labelIndexTo, boolean regression) {
        this(recordReader, (WritableConverter) new SelfWritableConverter(), batchSize, labelIndexFrom, labelIndexTo, -1, -1, regression);
    }

    /** Creates an iterator with a single label column ({@code labelIndexTo == labelIndex}). */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize, int labelIndex, int numPossibleLabels, int maxNumBatches, boolean regression) {
        this(recordReader, converter, batchSize, labelIndex, labelIndex, numPossibleLabels, maxNumBatches, regression);
    }

    /** Full constructor; all other constructors delegate here. */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize, int labelIndexFrom, int labelIndexTo, int numPossibleLabels, int maxNumBatches, boolean regression) {
        this.recordReader = recordReader;
        this.converter = converter;
        this.batchSize = batchSize;
        this.maxNumBatches = maxNumBatches;
        this.labelIndex = labelIndexFrom;
        this.labelIndexTo = labelIndexTo;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
    }

    /**
     * Returns the next minibatch of at most {@code num} examples.
     *
     * <p>If a batch was pre-fetched by {@link #inputColumns()} or {@link #totalOutcomes()}, that
     * batch is returned instead, regardless of {@code num}. Returns an empty {@link DataSet} when no
     * records are available.
     */
    @Override
    public DataSet next(int num) {
        if (this.useCurrent) {
            this.useCurrent = false;
            // The pre-fetched batch is normally pre-processed when it is built. Only apply the
            // pre-processor here if that did not happen (e.g. it was set after the pre-fetch), so
            // the batch is never pre-processed twice.
            if (this.preProcessor != null && !this.lastPreProcessed) {
                this.preProcessor.preProcess(this.last);
                this.lastPreProcessed = true;
            }
            return this.last;
        }
        List<DataSet> dataSets = new ArrayList<DataSet>();
        List<RecordMetaData> meta = this.collectMetaData ? new ArrayList<RecordMetaData>() : null;
        for (int i = 0; i < num && this.hasNext(); ++i) {
            if (this.recordReader instanceof SequenceRecordReader) {
                // Each time step of a sequence is one example; load the next sequence once the
                // current one is exhausted.
                if (this.sequenceIter == null || !this.sequenceIter.hasNext()) {
                    List<List<Writable>> sequence = ((SequenceRecordReader) this.recordReader).sequenceRecord();
                    this.sequenceIter = sequence.iterator();
                }
                dataSets.add(this.getDataSet(this.sequenceIter.next()));
            } else if (this.collectMetaData) {
                Record record = this.recordReader.nextRecord();
                dataSets.add(this.getDataSet(record.getRecord()));
                meta.add(record.getMetaData());
            } else {
                List<Writable> record = this.recordReader.next();
                dataSets.add(this.getDataSet(record));
            }
        }
        ++this.batchNum;
        return this.finishBatch(dataSets, meta);
    }

    /**
     * Merges per-example DataSets into one batch. Attaches the example metadata (when
     * {@code meta} is non-null), applies the pre-processor, sets the reader's label names and
     * records the batch as {@link #last}. Returns an empty {@link DataSet} if there are no examples.
     */
    private DataSet finishBatch(List<DataSet> dataSets, List<RecordMetaData> meta) {
        if (dataSets.isEmpty()) {
            return new DataSet();
        }
        DataSet ret = DataSet.merge(dataSets);
        if (meta != null) {
            ret.setExampleMetaData(meta);
        }
        if (this.preProcessor != null) {
            this.preProcessor.preProcess(ret);
        }
        List<String> labelNames = this.recordReader.getLabels();
        if (labelNames != null) {
            ret.setLabelNames(labelNames);
        }
        // Never overwrite a pre-fetched batch that has not been handed out yet; it would be lost.
        if (!this.useCurrent) {
            this.last = ret;
            this.lastPreProcessed = this.preProcessor != null;
        }
        return ret;
    }

    /**
     * Converts a single record into a one-example {@link DataSet}.
     *
     * <p>Special cases, checked first:
     * <ul>
     *   <li>two columns holding the same {@link NDArrayWritable} instance: that array is used as
     *       both features and labels;</li>
     *   <li>two columns whose first is an {@link NDArrayWritable}: the array is the feature vector
     *       and the second column is the label (an outcome vector over {@code numPossibleLabels}
     *       classes, or a scalar when {@code regression} is set).</li>
     * </ul>
     * Otherwise every non-empty column is either a label column or a feature value. An
     * {@link NDArrayWritable} column whose {@code toDouble()} throws
     * {@link UnsupportedOperationException} becomes the feature array.
     *
     * <p>Side effect: if {@code numPossibleLabels >= 1} and {@code labelIndex < 0},
     * {@code labelIndex} is set to the last column of this record.
     *
     * @throws IllegalStateException if the label column is reached but {@code numPossibleLabels < 1}
     * @throws DL4JInvalidInputException if a classification label is outside
     *     {@code [0, numPossibleLabels)}
     */
    private DataSet getDataSet(List<Writable> record) {
        List<Writable> currList = record;
        if (this.numPossibleLabels >= 1 && this.labelIndex < 0) {
            this.labelIndex = record.size() - 1;
        }
        INDArray label = null;
        INDArray featureVector = null;
        int featureCount = 0;
        int labelCount = 0;
        // Identity comparison is intentional: the same array instance serves as features and labels.
        if (currList.size() == 2 && currList.get(1) instanceof NDArrayWritable && currList.get(0) instanceof NDArrayWritable && currList.get(0) == currList.get(1)) {
            NDArrayWritable writable = (NDArrayWritable) currList.get(0);
            return new DataSet(writable.get(), writable.get());
        }
        if (currList.size() == 2 && currList.get(0) instanceof NDArrayWritable) {
            String labelText = currList.get(1).toString();
            if (this.regression) {
                label = Nd4j.scalar(Double.parseDouble(labelText));
            } else {
                label = FeatureUtil.toOutcomeVector((int) Double.parseDouble(labelText), this.numPossibleLabels);
            }
            featureVector = ((NDArrayWritable) currList.get(0)).get();
            return new DataSet(featureVector, label);
        }
        for (int j = 0; j < currList.size(); ++j) {
            Writable current = currList.get(j);
            // Skip empty (non-array) columns entirely.
            if (!(current instanceof NDArrayWritable) && current.toString().isEmpty()) continue;
            if (this.regression && j >= this.labelIndex && j <= this.labelIndexTo) {
                if (label == null) {
                    label = Nd4j.create(1, this.labelIndexTo - this.labelIndex + 1);
                }
                label.putScalar(labelCount++, current.toDouble());
                continue;
            }
            if (this.labelIndex >= 0 && j == this.labelIndex) {
                if (this.converter != null) {
                    try {
                        current = this.converter.convert(current);
                    }
                    catch (WritableConverterException e) {
                        // NOTE(review): best-effort; the unconverted value is used if conversion
                        // fails. Consider logging or failing instead of printing the stack trace.
                        e.printStackTrace();
                    }
                }
                if (this.numPossibleLabels < 1) {
                    throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
                }
                if (this.regression) {
                    label = Nd4j.scalar(current.toDouble());
                    continue;
                }
                int curr = current.toInt();
                if (curr < 0 || curr >= this.numPossibleLabels) {
                    throw new DL4JInvalidInputException("Invalid classification data: expect label value (at label index column = " + this.labelIndex + ") to be in range 0 to " + (this.numPossibleLabels - 1) + " inclusive (0 to numClasses-1, with numClasses=" + this.numPossibleLabels + "); got label value of " + current);
                }
                label = FeatureUtil.toOutcomeVector(curr, this.numPossibleLabels);
                continue;
            }
            try {
                double value = current.toDouble();
                if (featureVector == null) {
                    if (this.regression && this.labelIndex >= 0) {
                        // labelIndexTo is still -1 when labelIndex was defaulted above; in that
                        // case there is exactly one label column.
                        int nLabels = this.labelIndexTo >= this.labelIndex ? this.labelIndexTo - this.labelIndex + 1 : 1;
                        featureVector = Nd4j.create(1, currList.size() - nLabels);
                    } else {
                        featureVector = Nd4j.create(this.labelIndex >= 0 ? currList.size() - 1 : currList.size());
                    }
                }
                featureVector.putScalar(featureCount++, value);
            }
            catch (UnsupportedOperationException e) {
                // A column that cannot be read as a scalar must be a whole feature array.
                if (current instanceof NDArrayWritable) {
                    assert (featureVector == null);
                    featureVector = ((NDArrayWritable) current).get();
                    continue;
                }
                throw e;
            }
        }
        return new DataSet(featureVector, this.labelIndex >= 0 ? label : featureVector);
    }

    @Override
    public int totalExamples() {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the number of input columns. If no batch has been seen yet, the first batch is
     * fetched and kept for the next call to next().
     */
    @Override
    public int inputColumns() {
        return this.peekBatch().numInputs();
    }

    /**
     * Returns the number of outcomes. If no batch has been seen yet, the first batch is fetched
     * and kept for the next call to next().
     */
    @Override
    public int totalOutcomes() {
        return this.peekBatch().numOutcomes();
    }

    /**
     * Returns {@link #last}. If it is null, first fetches a batch with {@link #next()} and caches it.
     * The cached batch is marked pending only if records were actually available, so that an empty
     * placeholder is never reported by {@link #hasNext()}.
     */
    private DataSet peekBatch() {
        if (this.last == null) {
            boolean available = this.hasNext();
            this.last = this.next();
            this.useCurrent = available;
        }
        return this.last;
    }

    public boolean resetSupported() {
        return true;
    }

    public boolean asyncSupported() {
        return true;
    }

    /**
     * Resets the underlying reader and discards per-pass state. This covers the batch counter, the
     * partially consumed sequence and any pre-fetched batch, so no data is repeated after a reset.
     */
    @Override
    public void reset() {
        this.batchNum = 0;
        this.recordReader.reset();
        this.sequenceIter = null;
        this.last = null;
        this.useCurrent = false;
        this.lastPreProcessed = false;
    }

    @Override
    public int batch() {
        return this.batchSize;
    }

    @Override
    public int cursor() {
        throw new UnsupportedOperationException();
    }

    @Override
    public int numExamples() {
        throw new UnsupportedOperationException();
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /**
     * Returns true if a pre-fetched batch is pending, or if more records or time steps remain and
     * the {@code maxNumBatches} limit (if any) has not been reached.
     */
    @Override
    public boolean hasNext() {
        if (this.useCurrent) {
            // The pending batch was already counted in batchNum, so the limit must not reject it.
            return true;
        }
        boolean moreData = (this.sequenceIter != null && this.sequenceIter.hasNext()) || this.recordReader.hasNext();
        return moreData && (this.maxNumBatches < 0 || this.batchNum < this.maxNumBatches);
    }

    @Override
    public DataSet next() {
        return this.next(this.batchSize);
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException();
    }

    @Override
    public List<String> getLabels() {
        return this.recordReader.getLabels();
    }

    /** Loads the single example described by {@code recordMetaData}. */
    public DataSet loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return this.loadFromMetaData(Collections.singletonList(recordMetaData));
    }

    /**
     * Loads the examples described by {@code list} from the record reader and merges them into one
     * batch with the metadata attached. Returns an empty {@link DataSet} if nothing was loaded.
     *
     * @throws IOException if the underlying reader fails to load the records
     */
    public DataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
        List<Record> records = this.recordReader.loadFromMetaData(list);
        List<DataSet> dataSets = new ArrayList<DataSet>(records.size());
        List<RecordMetaData> meta = new ArrayList<RecordMetaData>(records.size());
        for (Record r : records) {
            dataSets.add(this.getDataSet(r.getRecord()));
            meta.add(r.getMetaData());
        }
        return this.finishBatch(dataSets, meta);
    }

    public DataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    public boolean isCollectMetaData() {
        return this.collectMetaData;
    }

    /** When enabled, batches from next() carry the {@link RecordMetaData} of each example. */
    public void setCollectMetaData(boolean collectMetaData) {
        this.collectMetaData = collectMetaData;
    }
}

