/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.iterator;

import java.util.ArrayList;
import java.util.Deque;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class IteratorMultiDataSetIterator
implements MultiDataSetIterator {
    private final Iterator<org.nd4j.linalg.dataset.api.MultiDataSet> iterator;
    private final int batchSize;
    private final LinkedList<org.nd4j.linalg.dataset.api.MultiDataSet> queued;
    private MultiDataSetPreProcessor preProcessor;

    public IteratorMultiDataSetIterator(Iterator<org.nd4j.linalg.dataset.api.MultiDataSet> iterator, int batchSize) {
        this.iterator = iterator;
        this.batchSize = batchSize;
        this.queued = new LinkedList();
    }

    public boolean hasNext() {
        return !this.queued.isEmpty() || this.iterator.hasNext();
    }

    public org.nd4j.linalg.dataset.api.MultiDataSet next() {
        return this.next(this.batchSize);
    }

    public org.nd4j.linalg.dataset.api.MultiDataSet next(int num) {
        int nExamples;
        if (!this.hasNext()) {
            throw new NoSuchElementException();
        }
        ArrayList<Object> list = new ArrayList<Object>();
        for (int countSoFar = 0; (!this.queued.isEmpty() || this.iterator.hasNext()) && countSoFar < this.batchSize; countSoFar += nExamples) {
            org.nd4j.linalg.dataset.api.MultiDataSet next = !this.queued.isEmpty() ? this.queued.removeFirst() : this.iterator.next();
            nExamples = next.getFeatures(0).size(0);
            if (countSoFar + nExamples <= this.batchSize) {
                list.add(next);
                continue;
            }
            int nFeatures = next.numFeatureArrays();
            int nLabels = next.numLabelsArrays();
            INDArray[] fToKeep = new INDArray[nFeatures];
            INDArray[] lToKeep = new INDArray[nLabels];
            INDArray[] fToCache = new INDArray[nFeatures];
            INDArray[] lToCache = new INDArray[nLabels];
            INDArray[] fMaskToKeep = next.getFeaturesMaskArrays() != null ? new INDArray[nFeatures] : null;
            INDArray[] lMaskToKeep = next.getLabelsMaskArrays() != null ? new INDArray[nLabels] : null;
            INDArray[] fMaskToCache = next.getFeaturesMaskArrays() != null ? new INDArray[nFeatures] : null;
            INDArray[] lMaskToCache = next.getLabelsMaskArrays() != null ? new INDArray[nLabels] : null;
            for (int i = 0; i < nFeatures; ++i) {
                INDArray fi = next.getFeatures(i);
                INDArray li = next.getFeatures(i);
                fToKeep[i] = IteratorMultiDataSetIterator.getRange(fi, 0, this.batchSize - countSoFar);
                fToCache[i] = IteratorMultiDataSetIterator.getRange(fi, this.batchSize - countSoFar, nExamples);
                lToKeep[i] = IteratorMultiDataSetIterator.getRange(li, 0, this.batchSize - countSoFar);
                lToCache[i] = IteratorMultiDataSetIterator.getRange(li, this.batchSize - countSoFar, nExamples);
                if (fMaskToKeep != null) {
                    INDArray fmi = next.getFeaturesMaskArray(i);
                    fMaskToKeep[i] = IteratorMultiDataSetIterator.getRange(fmi, 0, this.batchSize - countSoFar);
                    fMaskToCache[i] = IteratorMultiDataSetIterator.getRange(fmi, this.batchSize - countSoFar, nExamples);
                }
                if (lMaskToKeep == null) continue;
                INDArray lmi = next.getLabelsMaskArray(i);
                lMaskToKeep[i] = IteratorMultiDataSetIterator.getRange(lmi, 0, this.batchSize - countSoFar);
                lMaskToCache[i] = IteratorMultiDataSetIterator.getRange(lmi, this.batchSize - countSoFar, nExamples);
            }
            MultiDataSet toKeep = new MultiDataSet(fToKeep, lToKeep, fMaskToKeep, lMaskToKeep);
            MultiDataSet toCache = new MultiDataSet(fToCache, lToCache, fMaskToCache, lMaskToCache);
            list.add(toKeep);
            this.queued.add((org.nd4j.linalg.dataset.api.MultiDataSet)toCache);
        }
        Object out = list.size() == 1 ? (org.nd4j.linalg.dataset.api.MultiDataSet)list.get(0) : MultiDataSet.merge(list);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess((org.nd4j.linalg.dataset.api.MultiDataSet)out);
        }
        return out;
    }

    private static INDArray getRange(INDArray arr, int exampleFrom, int exampleToExclusive) {
        if (arr == null) {
            return null;
        }
        int rank = arr.rank();
        switch (rank) {
            case 2: {
                return arr.get(new INDArrayIndex[]{NDArrayIndex.interval((int)exampleFrom, (int)exampleToExclusive), NDArrayIndex.all()});
            }
            case 3: {
                return arr.get(new INDArrayIndex[]{NDArrayIndex.interval((int)exampleFrom, (int)exampleToExclusive), NDArrayIndex.all(), NDArrayIndex.all()});
            }
            case 4: {
                return arr.get(new INDArrayIndex[]{NDArrayIndex.interval((int)exampleFrom, (int)exampleToExclusive), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()});
            }
        }
        throw new RuntimeException("Invalid rank: " + rank);
    }

    public void reset() {
        throw new UnsupportedOperationException("Reset not supported");
    }

    public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }
}

