/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.dataset;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.training.dataset.Sampler;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link Iterable} over the {@link Batch}es of a {@link RandomAccessDataset}.
 *
 * <p>This object is its own {@link Iterator}: {@link #iterator()} returns {@code this}, so an
 * instance can be traversed only once. The indices of each batch come from the {@link Sampler}
 * given to the constructor.
 *
 * <p>When an {@link ExecutorService} is supplied, batches are loaded asynchronously on that
 * executor and {@code preFetchNumber} batches are requested up front; each call to {@link #next()}
 * requests one more batch before returning the oldest outstanding one. Without an executor, each
 * batch is loaded synchronously inside {@link #next()}.
 *
 * <p>Every batch is allocated in its own sub-manager of a private manager owned by this iterable;
 * that private manager is closed when {@link #hasNext()} reports that the iteration is exhausted.
 */
public class DataIterable implements Iterable<Batch>, Iterator<Batch> {

    private static final Logger logger = LoggerFactory.getLogger(DataIterable.class);

    private final RandomAccessDataset dataset;
    private final NDManager manager;
    private final Batchifier batchifier;
    private final Pipeline pipeline;
    private final Pipeline targetPipeline;
    private final ExecutorService executor;
    private final Device device;
    private final Iterator<List<Long>> sample;
    // Futures of batches already submitted to the executor; null in synchronous mode.
    private final Queue<Future<Batch>> queue;
    // Running count of records handed out so far; used as the per-batch progress value.
    private final AtomicInteger progressCounter;

    /**
     * Creates an iterable over the batches of {@code dataset}.
     *
     * @param dataset the dataset to read records from
     * @param manager the parent manager; a private sub-manager is created from it
     * @param sampler produces the list of record indices for each batch
     * @param batchifier combines the per-record {@link NDList}s into batched {@link NDList}s
     * @param pipeline optional transform applied to each record's data; may be {@code null}
     * @param targetPipeline optional transform applied to the batched labels; may be {@code null}
     * @param executor executor used to load batches asynchronously, or {@code null} to load them
     *     synchronously in {@link #next()}
     * @param preFetchNumber number of batches submitted to {@code executor} immediately; ignored
     *     when {@code executor} is {@code null}
     * @param device if non-null, batch data and labels are passed through {@code asInDevice} for
     *     this device
     */
    public DataIterable(
            RandomAccessDataset dataset,
            NDManager manager,
            Sampler sampler,
            Batchifier batchifier,
            Pipeline pipeline,
            Pipeline targetPipeline,
            ExecutorService executor,
            int preFetchNumber,
            Device device) {
        this.dataset = dataset;
        this.manager = manager.newSubManager();
        this.batchifier = batchifier;
        this.pipeline = pipeline;
        this.targetPipeline = targetPipeline;
        this.executor = executor;
        this.device = device;
        this.progressCounter = new AtomicInteger(0);
        this.sample = sampler.sample(dataset);
        if (executor != null) {
            this.queue = new LinkedList<>();
            for (int i = 0; i < preFetchNumber; ++i) {
                preFetch();
            }
        } else {
            this.queue = null;
        }
    }

    /**
     * Returns this object itself, so a {@code DataIterable} can be iterated only once.
     *
     * @return {@code this}
     */
    @Override
    public Iterator<Batch> iterator() {
        return this;
    }

    /**
     * Reports whether another batch is available.
     *
     * <p>When this returns {@code false}, the private manager created in the constructor is
     * closed.
     *
     * @return {@code true} if {@link #next()} can return another batch
     */
    @Override
    public boolean hasNext() {
        boolean more;
        if (executor == null) {
            more = sample.hasNext();
        } else {
            // A non-empty queue means a batch is already in flight. An empty queue with samples
            // still left happens when preFetchNumber was 0; next() then submits the batch itself.
            more = !queue.isEmpty() || sample.hasNext();
        }
        if (!more) {
            manager.close();
        }
        return more;
    }

    /**
     * Returns the next batch, loading it synchronously or waiting for the prefetched result.
     *
     * @return the next batch
     * @throws NoSuchElementException if no batch is left (in synchronous mode this is propagated
     *     from the sampler's iterator, assuming it follows the {@link Iterator} contract)
     * @throws IllegalStateException if loading the batch fails or the wait is interrupted
     */
    @Override
    public Batch next() {
        if (executor == null) {
            List<Long> indices = sample.next();
            try {
                int progress = progressCounter.addAndGet(indices.size());
                return fetch(indices, progress);
            } catch (TranslateException | IOException e) {
                logger.error(e.getMessage());
                throw new IllegalStateException("Data loading failed", e);
            }
        }
        // Keep the prefetch window full: request one more batch before taking the oldest one.
        preFetch();
        Future<Batch> future = queue.poll();
        if (future == null) {
            throw new NoSuchElementException("No more batches in the dataset");
        }
        try {
            return future.get();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            logger.error(e.getMessage());
            throw new IllegalStateException("Data loading failed", e);
        } catch (ExecutionException e) {
            logger.error(e.getMessage());
            throw new IllegalStateException("Data loading failed", e);
        }
    }

    /**
     * Reads, transforms and batchifies the records at {@code indices}.
     *
     * <p>All arrays are allocated in a fresh sub-manager, which is handed to the returned {@link
     * Batch}. If anything fails before the batch is built, that sub-manager is closed here so it
     * does not linger until the parent manager is closed.
     *
     * @param indices dataset indices of the records in this batch
     * @param progress progress value stored in the returned batch
     * @return the assembled batch
     * @throws IOException if reading a record fails
     * @throws TranslateException if a pipeline or the batchifier fails
     */
    private Batch fetch(List<Long> indices, int progress) throws IOException, TranslateException {
        NDManager subManager = manager.newSubManager();
        boolean built = false;
        try {
            int batchSize = indices.size();
            NDList[] data = new NDList[batchSize];
            NDList[] labels = new NDList[batchSize];
            for (int i = 0; i < batchSize; ++i) {
                Record record = dataset.get(subManager, indices.get(i));
                data[i] = record.getData();
                if (pipeline != null) {
                    data[i] = pipeline.transform(data[i]);
                }
                labels[i] = record.getLabels();
            }
            NDList batchData = batchifier.batchify(data);
            NDList batchLabels = batchifier.batchify(labels);
            // The per-record lists are no longer needed once they have been batchified.
            Arrays.stream(data).forEach(NDList::close);
            Arrays.stream(labels).forEach(NDList::close);
            if (targetPipeline != null) {
                batchLabels = targetPipeline.transform(batchLabels);
            }
            if (device != null) {
                batchData = batchData.asInDevice(device, false);
                batchLabels = batchLabels.asInDevice(device, false);
            }
            Batch batch =
                    new Batch(
                            subManager,
                            batchData,
                            batchLabels,
                            batchSize,
                            batchifier,
                            progress,
                            dataset.size());
            built = true;
            return batch;
        } finally {
            if (!built) {
                subManager.close();
            }
        }
    }

    /**
     * Submits the next sampled batch to the executor, if the sampler has one left.
     */
    private void preFetch() {
        if (!sample.hasNext()) {
            return;
        }
        List<Long> indices = sample.next();
        Future<Batch> result = executor.submit(new PreFetchCallable(indices));
        queue.offer(result);
    }

    /** Task that loads one batch on the executor via {@link DataIterable#fetch}. */
    class PreFetchCallable implements Callable<Batch> {

        private final List<Long> indices;
        private final int progress;

        /**
         * Creates a task for the given batch indices and reserves its progress value.
         *
         * @param indices dataset indices of the records in this batch
         */
        public PreFetchCallable(List<Long> indices) {
            this.indices = indices;
            // addAndGet (not getAndAdd) so the progress includes this batch, matching the
            // synchronous path in next().
            this.progress = progressCounter.addAndGet(indices.size());
        }

        @Override
        public Batch call() throws IOException, TranslateException {
            return fetch(indices, progress);
        }
    }
}

