/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset.cv.classification;

import ai.djl.basicdataset.cv.ImageDataset;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.Record;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslatorOptions;
import java.io.IOException;
import java.util.List;
import java.util.Optional;

/**
 * Base class for image datasets in which each record pairs one image with one
 * numeric class label.
 *
 * <p>Subclasses supply the class number for a record index and the list of
 * class names; this class assembles the {@link Record} and the translator
 * options built from them.
 *
 * <p>NOTE(review): {@code get} and {@code matchingTranslatorOptions} look like
 * implementations of superclass/interface methods whose {@code @Override}
 * annotations were lost in decompilation — confirm against the parent types
 * before restoring them.
 */
public abstract class ImageClassificationDataset extends ImageDataset {

    /**
     * Creates the dataset from the given builder by delegating to the
     * {@link ImageDataset} constructor.
     *
     * @param builder the builder forwarded unchanged to the superclass
     */
    public ImageClassificationDataset(ImageDataset.BaseBuilder<?> builder) {
        super(builder);
    }

    /**
     * Returns the class number of a record.
     *
     * @param var1 the record index ({@link #get(NDManager, long)} passes its
     *     own {@code index} argument here)
     * @return the class number used as the record's label value
     * @throws IOException declared for implementations that need to read data
     */
    protected abstract long getClassNumber(long var1) throws IOException;

    /**
     * Builds the record at {@code index}: a single-element data list holding
     * the record image and a single-element label list holding the class
     * number as an array created from {@code manager}.
     *
     * @param manager the manager used to load the image and create the label
     * @param index the index of the requested record
     * @return a record containing the image as data and the class number as label
     * @throws IOException if loading the image or resolving the class number fails
     */
    public Record get(NDManager manager, long index) throws IOException {
        // Keep the original ordering: the image is loaded and wrapped before
        // the class number is looked up.
        NDArray image = this.getRecordImage(manager, index);
        NDList imageList = new NDList(image);

        NDArray classId = manager.create(this.getClassNumber(index));
        NDList labelList = new NDList(classId);

        return new Record(imageList, labelList);
    }

    /**
     * Returns the translator options of an {@link ImageClassificationTranslator}
     * configured for this dataset.
     *
     * <p>The translator's pipeline contains a {@link Resize} step only when both
     * the image width and the image height are present, followed in every case
     * by {@link ToTensor}. Its synset is the list from {@link #getClasses()}.
     *
     * @return the expansions of the translator built for this dataset
     */
    public TranslatorOptions matchingTranslatorOptions() {
        Pipeline steps = new Pipeline();

        // Both dimensions are queried up front, width first, before either is inspected.
        Optional<Integer> maybeWidth = this.getImageWidth();
        Optional<Integer> maybeHeight = this.getImageHeight();
        boolean hasFixedSize = maybeWidth.isPresent() && maybeHeight.isPresent();
        if (hasFixedSize) {
            int targetWidth = maybeWidth.get();
            int targetHeight = maybeHeight.get();
            // Typed as Transform so the same Pipeline.add overload is selected.
            Transform resize = new Resize(targetWidth, targetHeight);
            steps.add(resize);
        }

        Transform toTensor = new ToTensor();
        steps.add(toTensor);

        // The casts mirror the original chain; the builder methods' declared
        // return types are not visible from this file.
        ImageClassificationTranslator.Builder withSynset =
                (ImageClassificationTranslator.Builder)
                        ImageClassificationTranslator.builder().optSynset(this.getClasses());
        ImageClassificationTranslator.Builder withPipeline =
                (ImageClassificationTranslator.Builder) withSynset.setPipeline(steps);
        return withPipeline.build().getExpansions();
    }

    /**
     * Returns the class names of this dataset; the list is passed as the synset
     * of the translator built in {@link #matchingTranslatorOptions()}.
     *
     * @return the list of class names
     */
    public abstract List<String> getClasses();
}

