/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset.utils;

import ai.djl.basicdataset.TextDataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Sampler;
import ai.djl.util.RandomUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link Sampler} for {@link TextDataset}s that groups samples into buckets by sentence length
 * and builds every mini-batch from the samples of a single bucket.
 *
 * <p>Bucket upper bounds are derived from the sentence lengths of the first and last sample of
 * the dataset, spaced evenly and clamped to the smallest length. Each bucket is then cut into
 * chunks of at most {@code batchSize} samples. When shuffling is enabled, both the order of the
 * batches and the order of the samples inside each bucket are randomized with
 * {@link RandomUtils#RANDOM}.
 */
public class FixedBucketSampler
implements Sampler {
    private static final Logger logger = LoggerFactory.getLogger(FixedBucketSampler.class);
    private final int numBuckets;
    private final int batchSize;
    private final boolean shuffle;

    /**
     * Creates a sampler.
     *
     * @param batchSize the maximum number of samples per batch, must be at least 1
     * @param numBuckets the number of length buckets to aim for, must be at least 1
     * @param shuffle whether to randomize the batch order and the sample order within buckets
     * @throws IllegalArgumentException if {@code batchSize} or {@code numBuckets} is less than 1
     */
    public FixedBucketSampler(int batchSize, int numBuckets, boolean shuffle) {
        // Fail fast: a non-positive batch size would make batch slicing loop forever and a
        // non-positive bucket count would fail later with a division by zero.
        if (batchSize < 1) {
            throw new IllegalArgumentException("batchSize must be at least 1, got " + batchSize);
        }
        if (numBuckets < 1) {
            throw new IllegalArgumentException("numBuckets must be at least 1, got " + numBuckets);
        }
        this.numBuckets = numBuckets;
        this.batchSize = batchSize;
        this.shuffle = shuffle;
        if (batchSize == 1) {
            logger.warn("FixedBucketSampler is not meaningful with batch size 1.");
        }
    }

    /**
     * Creates a shuffling sampler.
     *
     * @param batchSize the maximum number of samples per batch, must be at least 1
     * @param numBuckets the number of length buckets to aim for, must be at least 1
     * @throws IllegalArgumentException if {@code batchSize} or {@code numBuckets} is less than 1
     */
    public FixedBucketSampler(int batchSize, int numBuckets) {
        this(batchSize, numBuckets, true);
    }

    /**
     * Creates a shuffling sampler with 10 buckets.
     *
     * @param batchSize the maximum number of samples per batch, must be at least 1
     * @throws IllegalArgumentException if {@code batchSize} is less than 1
     */
    public FixedBucketSampler(int batchSize) {
        this(batchSize, 10);
    }

    /**
     * Returns an iterator over batches of sample indices for the given dataset.
     *
     * @param dataset the dataset to sample from; must be a {@link TextDataset}
     * @return an iterator whose elements are the index lists of one batch each
     * @throws IllegalArgumentException if {@code dataset} is not a {@link TextDataset}
     */
    public Iterator<List<Long>> sample(RandomAccessDataset dataset) {
        if (!(dataset instanceof TextDataset)) {
            throw new IllegalArgumentException("FixedBucketSampler can only be used with TextDataset");
        }
        return new Iterate((TextDataset)dataset);
    }

    /**
     * Returns the configured batch size.
     *
     * @return the maximum number of samples per batch
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Computes the ascending, de-duplicated upper bounds of the length buckets.
     *
     * @param min the smallest sentence length
     * @param max the largest sentence length
     * @return the bucket upper bounds in ascending order; the last one equals {@code max}
     */
    private int[] computeBucketKeys(int min, int max) {
        int step = Math.max((1 + max - min) / numBuckets, 1);
        Set<Integer> keys = new HashSet<>(numBuckets);
        for (int i = 0; i < numBuckets; ++i) {
            // Clamping to min makes several low buckets collapse into one; the set removes them.
            keys.add(Math.max(max - (numBuckets - i - 1) * step, min));
        }
        // HashSet has no defined iteration order, but the bucketing pass walks the keys as
        // increasing thresholds, so they must be sorted explicitly.
        return keys.stream().mapToInt(Integer::intValue).sorted().toArray();
    }

    /**
     * Iterator that partitions a dataset into length buckets once, at construction time, and
     * then serves one batch of sample indices per call to {@link #next()}.
     */
    private class Iterate
    implements Iterator<List<Long>> {
        /** Non-empty buckets of samples, in order of increasing length threshold. */
        private final List<List<TextDataset.Sample>> buckets;
        /** One entry per batch: {@code {bucket index, start offset inside that bucket}}. */
        private final List<int[]> bucketBatch;
        /** Position of the next batch to serve in {@link #bucketBatch}. */
        private int current;

        /**
         * Buckets the samples of {@code dataset} and lays out the batches.
         *
         * <p>NOTE(review): the length range is read from the first and last sample and the
         * bucketing is a single forward pass, so this assumes {@code dataset.getSamples()} is
         * sorted by ascending sentence length -- confirm against {@link TextDataset}.
         *
         * @param dataset the dataset whose samples are bucketed
         */
        public Iterate(TextDataset dataset) {
            this.buckets = new ArrayList<>(numBuckets);
            this.bucketBatch = new ArrayList<>();
            List<TextDataset.Sample> samples = dataset.getSamples();
            if (samples.isEmpty()) {
                // Nothing to sample: leave the iterator empty instead of failing on get(0).
                return;
            }
            int min = samples.get(0).getSentenceLength();
            int max = samples.get(samples.size() - 1).getSentenceLength();
            int[] bucketKeys = computeBucketKeys(min, max);

            int lastKey = bucketKeys.length - 1;
            int index = 0;
            List<TextDataset.Sample> bucket = new ArrayList<>();
            for (TextDataset.Sample sample : samples) {
                int length = sample.getSentenceLength();
                // Skip over every threshold the sample exceeds, not just one, so that a jump in
                // length across several keys still lands in the right bucket. The bound keeps
                // the index valid even if a sample is longer than the last key.
                int target = index;
                while (target < lastKey && length > bucketKeys[target]) {
                    ++target;
                }
                if (target != index) {
                    if (!bucket.isEmpty()) {
                        this.buckets.add(bucket);
                        bucket = new ArrayList<>();
                    }
                    index = target;
                }
                bucket.add(sample);
            }
            if (!bucket.isEmpty()) {
                this.buckets.add(bucket);
            }

            // Slice each bucket into batchSize-sized windows; the last window may be shorter.
            for (int i = 0; i < this.buckets.size(); ++i) {
                int size = this.buckets.get(i).size();
                for (int j = 0; j < size; j += batchSize) {
                    this.bucketBatch.add(new int[]{i, j});
                }
            }
            if (shuffle) {
                Collections.shuffle(this.bucketBatch, RandomUtils.RANDOM);
                this.buckets.forEach(l -> Collections.shuffle(l, RandomUtils.RANDOM));
            }
        }

        /** {@inheritDoc} */
        @Override
        public boolean hasNext() {
            return this.current < this.bucketBatch.size();
        }

        /**
         * Returns the dataset indices of the samples in the next batch.
         *
         * @return the indices of at most {@code batchSize} samples taken from a single bucket
         * @throws NoSuchElementException if all batches have already been returned
         */
        @Override
        public List<Long> next() {
            if (!hasNext()) {
                throw new NoSuchElementException("No more batches");
            }
            int[] batch = this.bucketBatch.get(this.current);
            List<TextDataset.Sample> bucket = this.buckets.get(batch[0]);
            int start = batch[1];
            int end = Math.min(bucket.size(), start + batchSize);
            List<Long> ret = new ArrayList<>(end - start);
            for (int i = start; i < end; ++i) {
                ret.add(bucket.get(i).getIndex());
            }
            ++this.current;
            return ret;
        }
    }
}

