/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.dataset;

import ai.djl.Device;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.BatchSampler;
import ai.djl.training.dataset.DataIterable;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomSampler;
import ai.djl.training.dataset.Record;
import ai.djl.training.dataset.Sampler;
import ai.djl.training.dataset.SequenceSampler;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import java.io.IOException;
import java.util.RandomAccess;
import java.util.concurrent.ExecutorService;

/**
 * A {@link Dataset} whose records are fetched by position via {@link #get(NDManager, long)},
 * which lets {@link #getData(NDManager)} drive iteration with a {@link Sampler}.
 *
 * <p>All iteration settings (sampler, batchifier, pipelines, executor/prefetch, iteration
 * limit, device) are copied from a {@link BaseBuilder} when the dataset is constructed.
 */
public abstract class RandomAccessDataset implements Dataset, RandomAccess {

    protected Sampler sampler;
    protected Batchifier batchifier;
    protected Pipeline pipeline;
    protected Pipeline targetPipeline;
    protected ExecutorService executor;
    protected int prefetchNumber;
    private long maxIteration;
    protected Device device;

    /**
     * Creates a dataset configured from the given builder.
     *
     * @param builder the builder whose settings are copied into this dataset
     * @throws IllegalArgumentException if no sampler was set on {@code builder}
     */
    public RandomAccessDataset(BaseBuilder<?> builder) {
        // getSampler() validates that a sampler was configured; it is the only required setting.
        this.sampler = builder.getSampler();
        this.batchifier = builder.batchifier;
        this.pipeline = builder.pipeline;
        this.targetPipeline = builder.targetPipeline;
        this.executor = builder.executor;
        this.prefetchNumber = builder.prefetchNumber;
        this.maxIteration = builder.maxIteration;
        this.device = builder.device;
    }

    /**
     * Returns the {@link Record} at the given position.
     *
     * @param manager the manager supplied by the caller (presumably used by implementations to
     *     allocate the record's arrays — confirm in subclasses)
     * @param index the position of the record (presumably in {@code [0, size())} — confirm)
     * @return the record at {@code index}
     * @throws IOException if the record cannot be read
     */
    public abstract Record get(NDManager manager, long index) throws IOException;

    /**
     * Returns an {@link Iterable} of batches over this dataset.
     *
     * <p>The iterable is a {@link DataIterable} built from this dataset's sampler, batchifier,
     * pipelines, executor/prefetch number, iteration limit and device.
     *
     * @param manager the manager handed to the underlying {@link DataIterable}
     * @return an iterable of {@link Batch}es
     */
    @Override
    public Iterable<Batch> getData(NDManager manager) {
        return new DataIterable(
                this,
                manager,
                this.sampler,
                this.batchifier,
                this.pipeline,
                this.targetPipeline,
                this.executor,
                this.prefetchNumber,
                this.maxIteration,
                this.device);
    }

    /**
     * Returns the size of this dataset (presumably the number of records reachable through
     * {@link #get(NDManager, long)}).
     *
     * @return the dataset size
     */
    public abstract long size();

    /**
     * Base builder for {@link RandomAccessDataset} subclasses.
     *
     * <p>Defaults: batchifier {@link Batchifier#STACK}, no pipelines, no executor, prefetch
     * number {@code 0}, no device, and an iteration limit of {@link Long#MAX_VALUE}. A sampler
     * has no default and must be set via one of the {@code setSampling} methods.
     *
     * <p>NOTE(review): the bound on {@code T} is raw; {@code T extends BaseBuilder<T>} would
     * be tighter but could break existing subclasses, so it is left unchanged here.
     *
     * @param <T> the concrete builder type returned by the fluent setters
     */
    public static abstract class BaseBuilder<T extends BaseBuilder> {
        protected Sampler sampler;
        protected Batchifier batchifier = Batchifier.STACK;
        protected Pipeline pipeline;
        protected Pipeline targetPipeline;
        protected ExecutorService executor;
        protected int prefetchNumber;
        protected long maxIteration = Long.MAX_VALUE;
        protected Device device;

        /**
         * Returns the configured sampler.
         *
         * @return the sampler, never {@code null}
         * @throws IllegalArgumentException if no sampler has been set
         */
        public Sampler getSampler() {
            if (this.sampler == null) {
                throw new IllegalArgumentException("The sampler must be set");
            }
            return this.sampler;
        }

        /**
         * Uses a {@link BatchSampler} with the given batch size, keeping the last partial
         * batch ({@code dropLast = false}).
         *
         * @param batchSize the batch size
         * @param random {@code true} to wrap a {@link RandomSampler}, {@code false} to wrap a
         *     {@link SequenceSampler}
         * @return this builder
         */
        public T setSampling(long batchSize, boolean random) {
            return this.setSampling(batchSize, random, false);
        }

        /**
         * Uses a {@link BatchSampler} with the given batch size.
         *
         * @param batchSize the batch size
         * @param random {@code true} to wrap a {@link RandomSampler}, {@code false} to wrap a
         *     {@link SequenceSampler}
         * @param dropLast passed through to {@link BatchSampler}
         * @return this builder
         */
        public T setSampling(long batchSize, boolean random, boolean dropLast) {
            if (random) {
                this.sampler = new BatchSampler(new RandomSampler(), batchSize, dropLast);
            } else {
                this.sampler = new BatchSampler(new SequenceSampler(), batchSize, dropLast);
            }
            return this.self();
        }

        /**
         * Uses the given sampler as-is.
         *
         * @param sampler the sampler to use; {@code null} is accepted here but causes
         *     {@link #getSampler()} to throw when the dataset is constructed
         * @return this builder
         */
        public T setSampling(Sampler sampler) {
            this.sampler = sampler;
            return this.self();
        }

        /**
         * Sets the {@link Batchifier} (default {@link Batchifier#STACK}).
         *
         * @param batchifier the batchifier to use
         * @return this builder
         */
        public T optBatchifier(Batchifier batchifier) {
            this.batchifier = batchifier;
            return this.self();
        }

        /**
         * Misspelled alias of {@link #optBatchifier(Batchifier)}, retained for backward
         * compatibility. New code should call {@code optBatchifier}.
         *
         * @param batchier the batchifier to use
         * @return this builder
         */
        public T optBatchier(Batchifier batchier) {
            return this.optBatchifier(batchier);
        }

        /**
         * Sets the pipeline (none by default).
         *
         * @param pipeline the pipeline to use
         * @return this builder
         */
        public T optPipeline(Pipeline pipeline) {
            this.pipeline = pipeline;
            return this.self();
        }

        /**
         * Sets the target pipeline (none by default).
         *
         * @param targetPipeline the target pipeline to use
         * @return this builder
         */
        public T optTargetPipeline(Pipeline targetPipeline) {
            this.targetPipeline = targetPipeline;
            return this.self();
        }

        /**
         * Sets the executor and prefetch number handed to {@link DataIterable} (defaults: no
         * executor, prefetch number {@code 0}).
         *
         * @param executor the executor service to use
         * @param prefetchNumber the prefetch number to use
         * @return this builder
         */
        public T optExecutor(ExecutorService executor, int prefetchNumber) {
            this.executor = executor;
            this.prefetchNumber = prefetchNumber;
            return this.self();
        }

        /**
         * Misspelled alias of {@link #optExecutor(ExecutorService, int)}, retained for backward
         * compatibility. New code should call {@code optExecutor}.
         *
         * @param executor the executor service to use
         * @param prefetchNumber the prefetch number to use
         * @return this builder
         */
        public T optExcutor(ExecutorService executor, int prefetchNumber) {
            return this.optExecutor(executor, prefetchNumber);
        }

        /**
         * Sets the device handed to {@link DataIterable} (none by default).
         *
         * @param device the device to use
         * @return this builder
         */
        public T optDevice(Device device) {
            this.device = device;
            return this.self();
        }

        /**
         * Sets the iteration limit handed to {@link DataIterable} (default
         * {@link Long#MAX_VALUE}).
         *
         * @param maxIteration the iteration limit
         * @return this builder
         */
        public T optMaxIteration(long maxIteration) {
            this.maxIteration = maxIteration;
            return this.self();
        }

        /**
         * Returns this builder as its concrete type so the fluent setters can return {@code T}.
         *
         * @return this builder
         */
        protected abstract T self();
    }
}

