/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.datavec;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.NoSuchElementException;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataComposable;
import org.datavec.api.records.metadata.RecordMetaDataComposableMap;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.deeplearning4j.datasets.datavec.exception.ZeroLengthSequenceException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * A {@link DataSetIterator} over sequence (time series) data read from {@link SequenceRecordReader}s.
 *
 * <p>Two configurations are supported:
 * <ul>
 *   <li><b>Separate readers</b>: features come from one reader and labels from another. For
 *       classification, column 0 of the labels reader is one-hot encoded into
 *       {@code numPossibleLabels} classes; for regression, every column of the labels reader is
 *       used as a label.</li>
 *   <li><b>Single reader</b>: features and labels come from the same reader. The label is the
 *       column at {@code labelIndex} (for multi-column regression, the columns from
 *       {@code labelIndex} to the last one). All remaining columns are features.</li>
 * </ul>
 *
 * <p>Reading, batching and sequence alignment are delegated to an internal
 * {@link RecordReaderMultiDataSetIterator}. It is created lazily from the first sequence of the
 * features reader, which is used to discover the number of columns per time step. Each
 * {@link MultiDataSet} it produces is converted back into a single-input, single-output
 * {@link DataSet}.
 */
public class SequenceRecordReaderDataSetIterator
implements DataSetIterator {
    /** Key of the features reader (or the single combined reader) in the underlying iterator. */
    private static final String READER_KEY = "reader";
    /** Key of the separate labels reader in the underlying iterator (two-reader mode only). */
    private static final String READER_KEY_LABEL = "reader_labels";
    private final SequenceRecordReader recordReader;
    private final SequenceRecordReader labelsReader;
    private final int miniBatchSize;
    private final boolean regression;
    /** Label column (single-reader mode); a negative value means "not specified". */
    private int labelIndex = -1;
    private final int numPossibleLabels;
    /** Only ever reset to 0; this class does not advance it. */
    private int cursor = 0;
    /** Shape information cached from the first batch; -1 until known. */
    private int inputColumns = -1;
    private int totalOutcomes = -1;
    /**
     * A batch read ahead by inputColumns()/totalOutcomes() and not yet returned by next(). It is
     * held without preprocessing so the preprocessor is applied exactly once, when it is returned.
     */
    private boolean useStored = false;
    private DataSet stored = null;
    private DataSetPreProcessor preProcessor;
    /** Alignment mode for the underlying iterator; null in single-reader mode. */
    private final AlignmentMode alignmentMode;
    private final boolean singleSequenceReaderMode;
    private boolean collectMetaData = false;
    /** Lazily created delegate that performs the actual reading and batching. */
    private RecordReaderMultiDataSetIterator underlying;
    /** True when the features are two column ranges (either side of the label) to be concatenated. */
    private boolean underlyingIsDisjoint;

    /**
     * Creates a classification iterator with features and labels in separate readers, using
     * {@link AlignmentMode#EQUAL_LENGTH}.
     *
     * @param featuresReader    reader supplying the feature sequences
     * @param labels            reader supplying the label sequences
     * @param miniBatchSize     number of sequences per minibatch
     * @param numPossibleLabels number of classes for one-hot encoding of the labels
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels, int miniBatchSize, int numPossibleLabels) {
        this(featuresReader, labels, miniBatchSize, numPossibleLabels, false);
    }

    /**
     * Creates an iterator with features and labels in separate readers, using
     * {@link AlignmentMode#EQUAL_LENGTH}.
     *
     * @param featuresReader    reader supplying the feature sequences
     * @param labels            reader supplying the label sequences
     * @param miniBatchSize     number of sequences per minibatch
     * @param numPossibleLabels number of classes (ignored for regression)
     * @param regression        true to use label columns as-is; false to one-hot encode column 0
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels, int miniBatchSize, int numPossibleLabels, boolean regression) {
        this(featuresReader, labels, miniBatchSize, numPossibleLabels, regression, AlignmentMode.EQUAL_LENGTH);
    }

    /**
     * Creates an iterator with features and labels in separate readers.
     *
     * @param featuresReader    reader supplying the feature sequences
     * @param labels            reader supplying the label sequences
     * @param miniBatchSize     number of sequences per minibatch
     * @param numPossibleLabels number of classes (ignored for regression)
     * @param regression        true to use label columns as-is; false to one-hot encode column 0
     * @param alignmentMode     how feature and label sequences are aligned by the underlying iterator
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels, int miniBatchSize, int numPossibleLabels, boolean regression, AlignmentMode alignmentMode) {
        this.recordReader = featuresReader;
        this.labelsReader = labels;
        this.miniBatchSize = miniBatchSize;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
        this.alignmentMode = alignmentMode;
        this.singleSequenceReaderMode = false;
    }

    /**
     * Creates a classification iterator where features and labels come from the same reader.
     *
     * @param reader            reader supplying sequences containing both features and the label
     * @param miniBatchSize     number of sequences per minibatch
     * @param numPossibleLabels number of classes for one-hot encoding of the label column
     * @param labelIndex        index of the label column; if negative and
     *                          {@code numPossibleLabels >= 1}, the last column is used
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader reader, int miniBatchSize, int numPossibleLabels, int labelIndex) {
        this(reader, miniBatchSize, numPossibleLabels, labelIndex, false);
    }

    /**
     * Creates an iterator where features and labels come from the same reader.
     *
     * <p>Classification one-hot encodes column {@code labelIndex}. Regression with
     * {@code numPossibleLabels <= 1} uses column {@code labelIndex} as the label. Regression with
     * {@code numPossibleLabels > 1} uses columns {@code labelIndex..last} as labels and the columns
     * before them as features. If {@code labelIndex} is negative and {@code numPossibleLabels < 1},
     * every column is a feature and no labels are produced.
     *
     * @param reader            reader supplying sequences containing both features and labels
     * @param miniBatchSize     number of sequences per minibatch
     * @param numPossibleLabels number of classes (classification) or label columns (regression)
     * @param labelIndex        index of the (first) label column; if negative and
     *                          {@code numPossibleLabels >= 1}, the last column is used
     * @param regression        true for regression, false for classification
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader reader, int miniBatchSize, int numPossibleLabels, int labelIndex, boolean regression) {
        this.recordReader = reader;
        this.labelsReader = null;
        this.miniBatchSize = miniBatchSize;
        this.regression = regression;
        this.labelIndex = labelIndex;
        this.numPossibleLabels = numPossibleLabels;
        this.alignmentMode = null;
        this.singleSequenceReaderMode = true;
    }

    /** Builds the underlying iterator from the first sequence of the features reader, then rewinds it. */
    private void initializeUnderlyingFromReader() {
        this.initializeUnderlying(this.recordReader.nextSequence());
        this.underlying.reset();
    }

    /**
     * Builds the underlying {@link RecordReaderMultiDataSetIterator}, using {@code nextF} only to
     * determine the number of columns per time step.
     *
     * @param nextF a sample sequence from the features reader
     * @throws ZeroLengthSequenceException if the sample sequence has no time steps
     */
    private void initializeUnderlying(SequenceRecord nextF) {
        if (nextF.getSequenceRecord().isEmpty()) {
            throw new ZeroLengthSequenceException();
        }
        // Number of values per time step, taken from the first time step.
        int totalSizeF = ((List<?>) nextF.getSequenceRecord().get(0)).size();
        if (this.numPossibleLabels >= 1 && this.labelIndex < 0) {
            // Default label column: last column of a single reader, column 0 of a separate labels reader.
            this.labelIndex = this.singleSequenceReaderMode ? totalSizeF - 1 : 0;
        }
        // Rewind, in case the sample sequence was consumed from this reader.
        this.recordReader.reset();
        RecordReaderMultiDataSetIterator.Builder builder = new RecordReaderMultiDataSetIterator.Builder(this.miniBatchSize);
        builder.addSequenceReader(READER_KEY, this.recordReader);
        if (this.labelsReader != null) {
            builder.addSequenceReader(READER_KEY_LABEL, this.labelsReader);
        }
        if (this.singleSequenceReaderMode) {
            this.configureSingleReader(builder, totalSizeF);
        } else {
            builder.addInput(READER_KEY);
            this.underlyingIsDisjoint = false;
            if (this.regression) {
                builder.addOutput(READER_KEY_LABEL);
            } else {
                builder.addOutputOneHot(READER_KEY_LABEL, 0, this.numPossibleLabels);
            }
        }
        RecordReaderMultiDataSetIterator.AlignmentMode underlyingMode = toUnderlyingAlignment(this.alignmentMode);
        if (underlyingMode != null) {
            builder.sequenceAlignmentMode(underlyingMode);
        }
        this.underlying = builder.build();
        if (this.collectMetaData) {
            this.underlying.setCollectMetaData(true);
        }
    }

    /**
     * Registers the feature and label column ranges for single-reader mode. Sets
     * {@code underlyingIsDisjoint} when the features are split around the label column.
     *
     * @param builder    builder for the underlying iterator
     * @param totalSizeF number of columns per time step
     * @throws IllegalStateException if multi-column regression is requested with
     *                               {@code labelIndex == 0}, which leaves no feature columns
     */
    private void configureSingleReader(RecordReaderMultiDataSetIterator.Builder builder, int totalSizeF) {
        int lastColumn = totalSizeF - 1;
        if (this.labelIndex < 0) {
            // No label column: every column is a feature and no output is registered.
            builder.addInput(READER_KEY, 0, lastColumn);
            this.underlyingIsDisjoint = false;
            return;
        }
        if (this.regression && this.numPossibleLabels > 1) {
            // Multi-column regression: features are [0, labelIndex-1], labels are [labelIndex, last].
            if (this.labelIndex == 0) {
                throw new IllegalStateException("Multi-column regression (numPossibleLabels = " + this.numPossibleLabels
                        + ") requires labelIndex > 0: labels are columns labelIndex.." + lastColumn
                        + " and features are columns 0..labelIndex-1");
            }
            builder.addInput(READER_KEY, 0, this.labelIndex - 1);
            builder.addOutput(READER_KEY, this.labelIndex, lastColumn);
            this.underlyingIsDisjoint = false;
            return;
        }
        // A single label column: every other column is a feature.
        if (this.labelIndex == 0) {
            builder.addInput(READER_KEY, 1, lastColumn);
            this.underlyingIsDisjoint = false;
        } else if (this.labelIndex == lastColumn) {
            builder.addInput(READER_KEY, 0, lastColumn - 1);
            this.underlyingIsDisjoint = false;
        } else {
            // Label is neither first nor last: two feature ranges, re-joined in toDataSet.
            builder.addInput(READER_KEY, 0, this.labelIndex - 1);
            builder.addInput(READER_KEY, this.labelIndex + 1, lastColumn);
            this.underlyingIsDisjoint = true;
        }
        if (this.regression) {
            builder.addOutput(READER_KEY, this.labelIndex, this.labelIndex);
        } else {
            builder.addOutputOneHot(READER_KEY, this.labelIndex, this.numPossibleLabels);
        }
    }

    /** Maps this class's alignment mode onto the underlying iterator's; null maps to null (builder default). */
    private static RecordReaderMultiDataSetIterator.AlignmentMode toUnderlyingAlignment(AlignmentMode mode) {
        if (mode == null) {
            return null;
        }
        switch (mode) {
            case EQUAL_LENGTH:
                return RecordReaderMultiDataSetIterator.AlignmentMode.EQUAL_LENGTH;
            case ALIGN_START:
                return RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START;
            case ALIGN_END:
                return RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END;
            default:
                return null;
        }
    }

    /** Converts a {@link MultiDataSet} into a {@link DataSet} and applies the preprocessor, if any. */
    private DataSet mdsToDataSet(MultiDataSet mds) {
        DataSet ds = this.toDataSet(mds);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess(ds);
        }
        return ds;
    }

    /**
     * Converts a {@link MultiDataSet} from the underlying iterator into a {@link DataSet} WITHOUT
     * applying the preprocessor. Split feature ranges are concatenated along dimension 1. When
     * metadata collection is enabled, the per-reader metadata is unwrapped.
     */
    private DataSet toDataSet(MultiDataSet mds) {
        INDArray features;
        // Both feature ranges come from the same reader, so input 0's mask applies to either layout.
        INDArray featuresMask = RecordReaderDataSetIterator.getOrNull(mds.getFeaturesMaskArrays(), 0);
        if (this.underlyingIsDisjoint) {
            INDArray before = RecordReaderDataSetIterator.getOrNull(mds.getFeatures(), 0);
            INDArray after = RecordReaderDataSetIterator.getOrNull(mds.getFeatures(), 1);
            int nBefore = before.size(1);
            int nTotal = nBefore + after.size(1);
            features = Nd4j.createUninitialized(new int[]{before.size(0), nTotal, before.size(2)});
            features.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, nBefore), NDArrayIndex.all()}, before);
            features.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(nBefore, nTotal), NDArrayIndex.all()}, after);
        } else {
            features = RecordReaderDataSetIterator.getOrNull(mds.getFeatures(), 0);
        }
        INDArray labels = RecordReaderDataSetIterator.getOrNull(mds.getLabels(), 0);
        INDArray labelsMask = RecordReaderDataSetIterator.getOrNull(mds.getLabelsMaskArrays(), 0);
        DataSet ds = new DataSet(features, labels, featuresMask, labelsMask);
        if (this.collectMetaData) {
            List<? extends Serializable> rawMeta = mds.getExampleMetaData();
            List<Serializable> exampleMeta = new ArrayList<>(rawMeta.size());
            for (Serializable s : rawMeta) {
                RecordMetaDataComposableMap m = (RecordMetaDataComposableMap) s;
                if (this.singleSequenceReaderMode) {
                    exampleMeta.add(m.getMeta().get(READER_KEY));
                } else {
                    exampleMeta.add(new RecordMetaDataComposable(new RecordMetaData[]{
                            m.getMeta().get(READER_KEY), m.getMeta().get(READER_KEY_LABEL)}));
                }
            }
            ds.setExampleMetaData(exampleMeta);
        }
        return ds;
    }

    /**
     * Returns whether another batch is available, including one read ahead by
     * {@link #inputColumns()} / {@link #totalOutcomes()}. Creates the underlying iterator on first call.
     */
    public boolean hasNext() {
        if (this.useStored) {
            return true;
        }
        if (this.underlying == null) {
            this.initializeUnderlyingFromReader();
        }
        return this.underlying.hasNext();
    }

    /** Returns the next batch of {@link #batch()} sequences. */
    public DataSet next() {
        return this.next(this.miniBatchSize);
    }

    /**
     * Returns the next batch, applying the preprocessor (if any) exactly once.
     *
     * @param num number of sequences to read; ignored when returning a batch that was read ahead
     *            by {@link #inputColumns()} / {@link #totalOutcomes()}
     * @throws NoSuchElementException if no more data is available
     */
    public DataSet next(int num) {
        DataSet ds;
        if (this.useStored) {
            this.useStored = false;
            ds = this.stored;
            this.stored = null;
        } else {
            ds = this.nextUnprocessed(num);
        }
        if (this.preProcessor != null) {
            this.preProcessor.preProcess(ds);
        }
        if (this.totalOutcomes == -1) {
            this.recordShape(ds);
        }
        return ds;
    }

    /** Reads and converts the next batch from the underlying iterator, without preprocessing. */
    private DataSet nextUnprocessed(int num) {
        if (this.underlying == null) {
            this.initializeUnderlyingFromReader();
        }
        if (!this.underlying.hasNext()) {
            throw new NoSuchElementException();
        }
        return this.toDataSet(this.underlying.next(num));
    }

    /** Caches the feature and label column counts (dimension 1) of {@code ds}; 0 outcomes if unlabeled. */
    private void recordShape(DataSet ds) {
        this.inputColumns = ds.getFeatures().size(1);
        this.totalOutcomes = ds.getLabels() == null ? 0 : ds.getLabels().size(1);
    }

    /** Not supported. */
    public int totalExamples() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Returns the number of feature columns. If no batch has been read yet, one is read ahead and
     * kept for the next call to {@link #next(int)}.
     */
    public int inputColumns() {
        if (this.inputColumns != -1) {
            return this.inputColumns;
        }
        this.preLoad();
        return this.inputColumns;
    }

    /**
     * Returns the number of label columns. If no batch has been read yet, one is read ahead and
     * kept for the next call to {@link #next(int)}.
     */
    public int totalOutcomes() {
        if (this.totalOutcomes != -1) {
            return this.totalOutcomes;
        }
        this.preLoad();
        return this.totalOutcomes;
    }

    /** Reads one batch ahead (unpreprocessed) to learn the data's shape. */
    private void preLoad() {
        this.stored = this.nextUnprocessed(this.miniBatchSize);
        this.useStored = true;
        this.recordShape(this.stored);
    }

    public boolean resetSupported() {
        return true;
    }

    public boolean asyncSupported() {
        return true;
    }

    /** Resets the underlying iterator (if created) and discards any batch that was read ahead. */
    public void reset() {
        if (this.underlying != null) {
            this.underlying.reset();
        }
        this.cursor = 0;
        this.stored = null;
        this.useStored = false;
    }

    /** Returns the configured minibatch size. */
    public int batch() {
        return this.miniBatchSize;
    }

    /** Always 0: this iterator does not advance its cursor. */
    public int cursor() {
        return this.cursor;
    }

    /** Not supported. */
    public int numExamples() {
        throw new UnsupportedOperationException("Not supported");
    }

    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /** Label names are not available from this iterator; always returns null. */
    public List<String> getLabels() {
        return null;
    }

    public void remove() {
        throw new UnsupportedOperationException("Remove not supported for this iterator");
    }

    /**
     * Loads the single example described by {@code recordMetaData}.
     *
     * @throws IOException if the reader fails to load the data
     */
    public DataSet loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return this.loadFromMetaData(Collections.singletonList(recordMetaData));
    }

    /**
     * Loads the examples described by {@code list} as one {@link DataSet}, with the preprocessor
     * applied. In two-reader mode, each element must be a {@link RecordMetaDataComposable} whose
     * first entry is the features metadata and whose second entry is the labels metadata.
     *
     * @throws IOException if a reader fails to load the data
     */
    public DataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
        if (this.underlying == null) {
            SequenceRecord r = this.recordReader.loadSequenceFromMetaData(list.get(0));
            this.initializeUnderlying(r);
        }
        List<RecordMetaData> wrapped = new ArrayList<>(list.size());
        if (this.singleSequenceReaderMode) {
            for (RecordMetaData m : list) {
                wrapped.add(new RecordMetaDataComposableMap(Collections.singletonMap(READER_KEY, m)));
            }
        } else {
            for (RecordMetaData m : list) {
                RecordMetaDataComposable rmdc = (RecordMetaDataComposable) m;
                HashMap<String, RecordMetaData> map = new HashMap<>(2);
                map.put(READER_KEY, rmdc.getMeta()[0]);
                map.put(READER_KEY_LABEL, rmdc.getMeta()[1]);
                wrapped.add(new RecordMetaDataComposableMap(map));
            }
        }
        return this.mdsToDataSet(this.underlying.loadFromMetaData(wrapped));
    }

    public DataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    public boolean isCollectMetaData() {
        return this.collectMetaData;
    }

    /** Enables metadata collection; takes effect for an underlying iterator created afterwards. */
    public void setCollectMetaData(boolean collectMetaData) {
        this.collectMetaData = collectMetaData;
    }

    /** Sequence alignment mode; each value maps to the same-named RecordReaderMultiDataSetIterator mode. */
    public enum AlignmentMode {
        EQUAL_LENGTH,
        ALIGN_START,
        ALIGN_END;

    }
}

