/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset.tabular;

import ai.djl.basicdataset.tabular.TabularDataset;
import ai.djl.util.Progress;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.Collections;
import java.util.List;
import java.util.zip.GZIPInputStream;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;

/**
 * A {@link TabularDataset} whose rows are read from a CSV source identified by a URL.
 *
 * <p>{@link #prepare(Progress)} reads the entire source into memory as a list of records; cells
 * are then looked up by row index and column name. A source whose URL path ends in {@code .gz}
 * (case-insensitive) is decompressed with gzip while reading.
 */
public class CsvDataset extends TabularDataset {

    /** Location of the CSV data (set via {@code optCsvFile} or {@code optCsvUrl}). */
    protected URL csvUrl;
    /** Format handed to the CSV parser. */
    protected CSVFormat csvFormat;
    /** Parsed records; {@code null} until {@link #prepare(Progress)} has run. */
    protected List<CSVRecord> csvRecords;

    /**
     * Creates a dataset from the settings collected by a builder.
     *
     * @param builder the builder holding the CSV location and format
     */
    protected CsvDataset(CsvBuilder<?> builder) {
        super(builder);
        csvUrl = builder.csvUrl;
        csvFormat = builder.csvFormat;
    }

    /**
     * Returns the raw text of one cell.
     *
     * @param rowIndex zero-based index into the parsed records
     * @param featureName the column name passed to {@code CSVRecord#get(String)}
     * @return the cell text
     * @throws ArithmeticException if {@code rowIndex} does not fit in an {@code int}
     */
    @Override
    protected String getCell(long rowIndex, String featureName) {
        CSVRecord record = csvRecords.get(Math.toIntExact(rowIndex));
        return record.get(featureName);
    }

    /**
     * Returns the number of parsed records.
     *
     * @return the record count
     */
    protected long availableSize() {
        return csvRecords.size();
    }

    /**
     * Reads and parses the whole CSV source into memory, then calls {@code prepareFeaturizers()}.
     *
     * @param progress progress reporter; not used by this implementation
     * @throws IOException if the source cannot be opened, decompressed, or parsed
     * @throws IllegalStateException if no CSV location has been configured
     */
    public void prepare(Progress progress) throws IOException {
        // NOTE(review): the parser is closed here; getColumnNames() later reads header names
        // through records' parser reference, which assumes CSVParser keeps its header mapping
        // after close() (the case in commons-csv, where close() only closes the lexer).
        try (Reader reader = new InputStreamReader(getCsvStream(), StandardCharsets.UTF_8);
                CSVParser csvParser = new CSVParser(reader, csvFormat)) {
            csvRecords = csvParser.getRecords();
        }
        prepareFeaturizers();
    }

    /**
     * Opens the CSV source, wrapping it in a gzip decoder when the URL path ends in ".gz".
     *
     * @return a stream over the (decompressed) CSV bytes; the caller must close it
     * @throws IOException if the source cannot be opened or the gzip header is invalid
     * @throws IllegalStateException if {@link #csvUrl} is not set
     */
    private InputStream getCsvStream() throws IOException {
        if (csvUrl == null) {
            throw new IllegalStateException(
                    "csvUrl is not set; configure it with optCsvFile or optCsvUrl before prepare()");
        }
        // Use getPath() rather than getFile(): getFile() includes the query string, so a URL like
        // ".../data.csv.gz?token=x" would otherwise not be recognized as gzip.
        String path = csvUrl.getPath();
        boolean gzipped = path != null && path.regionMatches(true, path.length() - 3, ".gz", 0, 3);

        InputStream raw = csvUrl.openStream();
        try {
            return gzipped ? new GZIPInputStream(raw) : new BufferedInputStream(raw);
        } catch (IOException | RuntimeException e) {
            // The GZIPInputStream constructor reads the header and may fail; do not leak the
            // underlying file/connection in that case.
            try {
                raw.close();
            } catch (IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Creates a new builder for a {@link CsvDataset}.
     *
     * @return a fresh builder
     */
    public static CsvBuilder<?> builder() {
        return new CsvBuilder<>();
    }

    /**
     * Returns the header names reported by the parser of the first record.
     *
     * @return the column names, or an empty list when no records were parsed
     * @throws IllegalStateException if {@link #prepare(Progress)} has not been called
     */
    public List<String> getColumnNames() {
        if (csvRecords == null) {
            throw new IllegalStateException("CSV data has not been loaded; call prepare() first");
        }
        if (csvRecords.isEmpty()) {
            return Collections.emptyList();
        }
        return csvRecords.get(0).getParser().getHeaderNames();
    }

    /**
     * Builder for {@link CsvDataset}. Subclass builders pass their own type as {@code T} so that
     * chained setter calls keep returning the subclass type.
     *
     * @param <T> the concrete builder type
     */
    public static class CsvBuilder<T extends CsvBuilder<T>>
            extends TabularDataset.BaseBuilder<T> {

        /** CSV location collected by {@link #optCsvFile} / {@link #optCsvUrl}. */
        protected URL csvUrl;
        /** CSV format collected by {@link #setCsvFormat}. */
        protected CSVFormat csvFormat;

        /**
         * Returns this builder typed as {@code T}.
         *
         * @return this builder
         */
        @SuppressWarnings("unchecked") // sound as long as T is the builder's own type
        protected T self() {
            return (T) this;
        }

        /**
         * Sets the CSV source to a local file.
         *
         * @param csvFile path to the CSV file; made absolute before conversion to a URL
         * @return this builder
         * @throws IllegalArgumentException if the path cannot be converted to a URL
         */
        public T optCsvFile(Path csvFile) {
            try {
                csvUrl = csvFile.toAbsolutePath().toUri().toURL();
            } catch (MalformedURLException e) {
                throw new IllegalArgumentException("Invalid file path: " + csvFile, e);
            }
            return self();
        }

        /**
         * Sets the CSV source to a URL.
         *
         * @param csvUrl the URL string of the CSV data
         * @return this builder
         * @throws IllegalArgumentException if the string is not a valid URL
         */
        public T optCsvUrl(String csvUrl) {
            try {
                // NOTE(review): new URL(String) is deprecated since Java 20; URI.create(..).toURL()
                // is stricter about unescaped characters, so switching needs a compatibility check.
                this.csvUrl = new URL(csvUrl);
            } catch (MalformedURLException e) {
                throw new IllegalArgumentException("Invalid url: " + csvUrl, e);
            }
            return self();
        }

        /**
         * Sets the format used to parse the CSV data.
         *
         * @param csvFormat the CSV format
         * @return this builder
         */
        public T setCsvFormat(CSVFormat csvFormat) {
            this.csvFormat = csvFormat;
            return self();
        }

        /**
         * Builds a {@link CsvDataset} from the collected settings.
         *
         * @return the new dataset (not yet prepared)
         */
        public CsvDataset build() {
            return new CsvDataset(this);
        }
    }
}

