/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.dataset;

import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Sampler;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * A {@link Sampler} that groups the indices produced by a {@link Sampler.SubSampler} into
 * batches of at most {@code batchSize} indices each.
 *
 * <p>The number of batches per pass is derived from {@code dataset.size()}: it is the size
 * divided by {@code batchSize} rounded up by default, or rounded down when {@code dropLast} is
 * {@code true} (in which case a trailing partial batch is not emitted).
 */
public class BatchSampler
implements Sampler {
    private final Sampler.SubSampler subSampler;
    private final long batchSize;
    private final boolean dropLast;

    /**
     * Creates a batch sampler that keeps the final partial batch.
     *
     * @param subSampler the sampler that supplies individual indices
     * @param batchSize the maximum number of indices per batch; must be positive
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public BatchSampler(Sampler.SubSampler subSampler, long batchSize) {
        this(subSampler, batchSize, false);
    }

    /**
     * Creates a batch sampler.
     *
     * @param subSampler the sampler that supplies individual indices
     * @param batchSize the maximum number of indices per batch; must be positive
     * @param dropLast whether to omit the final batch when it would hold fewer than
     *     {@code batchSize} indices
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public BatchSampler(Sampler.SubSampler subSampler, long batchSize, boolean dropLast) {
        if (batchSize <= 0L) {
            // Fail fast here instead of with an ArithmeticException when sampling starts.
            throw new IllegalArgumentException("batchSize must be positive, got: " + batchSize);
        }
        this.subSampler = subSampler;
        this.batchSize = batchSize;
        this.dropLast = dropLast;
    }

    /**
     * Returns an iterator over batches of indices for one pass through {@code dataset}.
     *
     * @param dataset the dataset to sample from
     * @return an iterator whose elements are lists of indices produced by the sub-sampler
     */
    @Override
    public Iterator<List<Long>> sample(RandomAccessDataset dataset) {
        return new Iterate(dataset);
    }

    /** Iterator that slices the sub-sampler's index stream into consecutive batches. */
    class Iterate
    implements Iterator<List<Long>> {
        /** Total number of batches this iterator will emit. */
        private final long size;
        /** Number of batches emitted so far. */
        private long current = 0L;
        private final Iterator<Long> itemSampler;

        Iterate(RandomAccessDataset dataset) {
            long total = dataset.size();
            long fullBatches = total / BatchSampler.this.batchSize;
            // Quotient/remainder form avoids the overflow of (total + batchSize - 1).
            boolean hasPartial = total % BatchSampler.this.batchSize != 0L;
            this.size = BatchSampler.this.dropLast || !hasPartial ? fullBatches : fullBatches + 1L;
            this.itemSampler = BatchSampler.this.subSampler.sample(dataset);
        }

        @Override
        public boolean hasNext() {
            return this.current < this.size;
        }

        /**
         * Returns the next batch of up to {@code batchSize} indices.
         *
         * @return the next batch; the last batch may be shorter when {@code dropLast} is false
         * @throws NoSuchElementException if all batches have already been returned
         */
        @Override
        public List<Long> next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            List<Long> batchIndices = new ArrayList<>();
            // Stop as soon as the batch is full so the remaining indices go to later batches.
            while (batchIndices.size() < BatchSampler.this.batchSize && this.itemSampler.hasNext()) {
                batchIndices.add(this.itemSampler.next());
            }
            ++this.current;
            return batchIndices;
        }
    }
}

