/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset.cv.classification;

import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import java.io.IOException;
import java.util.List;
import java.util.Optional;

public abstract class ImageClassificationDataset
extends RandomAccessDataset {
    Image.Flag flag;

    public ImageClassificationDataset(BaseBuilder<?> builder) {
        super(builder);
        this.flag = builder.flag;
    }

    protected abstract Image getImage(long var1) throws IOException;

    protected abstract long getClassNumber(long var1) throws IOException;

    public Record get(NDManager manager, long index) throws IOException {
        NDArray image = this.getImage(index).toNDArray(manager, this.flag);
        Optional<Integer> width = this.getImageWidth();
        Optional<Integer> height = this.getImageHeight();
        if (width.isPresent() && height.isPresent()) {
            image = NDImageUtils.resize((NDArray)image, (int)width.get(), (int)height.get());
        }
        NDList data = new NDList(new NDArray[]{image});
        NDList label = new NDList(new NDArray[]{manager.create(this.getClassNumber(index))});
        return new Record(data, label);
    }

    public Translator<Image, Classifications> makeTranslator() {
        Pipeline pipeline = new Pipeline();
        Optional<Integer> width = this.getImageWidth();
        Optional<Integer> height = this.getImageHeight();
        if (width.isPresent() && height.isPresent()) {
            pipeline.add((Transform)new Resize(width.get().intValue(), height.get().intValue()));
        }
        pipeline.add((Transform)new ToTensor());
        return ((ImageClassificationTranslator.Builder)((ImageClassificationTranslator.Builder)ImageClassificationTranslator.builder().optSynset(this.getClasses())).setPipeline(pipeline)).build();
    }

    public int getImageChannels() {
        return this.flag.numChannels();
    }

    public abstract Optional<Integer> getImageWidth();

    public abstract Optional<Integer> getImageHeight();

    public abstract List<String> getClasses();

    public static abstract class BaseBuilder<T extends BaseBuilder<T>>
    extends RandomAccessDataset.BaseBuilder<T> {
        Image.Flag flag = Image.Flag.COLOR;

        protected BaseBuilder() {
        }

        public T optFlag(Image.Flag flag) {
            this.flag = flag;
            return (T)((Object)((BaseBuilder)this.self()));
        }
    }
}

