/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset;

import ai.djl.basicdataset.utils.DynamicBuffer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.util.Progress;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.Buffer;
import java.nio.FloatBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.GZIPInputStream;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;

/**
 * A {@link RandomAccessDataset} backed by the rows of a CSV source.
 *
 * <p>Each parsed row is turned into a {@link Record} whose data and label are produced by running
 * the configured feature and label columns through their featurizers.
 */
public class CsvDataset
extends RandomAccessDataset {
    /** Shared featurizer handed to every {@link Feature} that is created as numeric. */
    private static final Featurizer NUMERIC_FEATURIZER = new NumericFeaturizer();
    /** Location of the CSV source; opened when the dataset is prepared. */
    protected URL csvUrl;
    /** Format passed to the CSV parser when the source is read. */
    protected CSVFormat csvFormat;
    /** Columns that make up the data part of each record, in output order. */
    protected List<Feature> features;
    /** Columns that make up the label part of each record, in output order. */
    protected List<Feature> labels;
    /** All parsed rows; assigned by {@code prepare} in this class and {@code null} before that. */
    protected List<CSVRecord> csvRecords;

    /**
     * Creates a dataset from the settings collected by the given builder.
     *
     * <p>Only configuration is taken over here; no CSV data is read in the constructor. The
     * feature and label lists are the builder's own list instances, not copies.
     *
     * @param builder the builder supplying the CSV location, format, features and labels
     */
    protected CsvDataset(CsvBuilder<?> builder) {
        super(builder);
        features = builder.features;
        labels = builder.labels;
        csvFormat = builder.csvFormat;
        csvUrl = builder.csvUrl;
    }

    /**
     * Builds the record for the row at the given position.
     *
     * @param manager the manager used to create the arrays of the record
     * @param index the zero-based row position; must fit in an {@code int}
     * @return a record whose data comes from the feature columns and whose label comes from the
     *     label columns of that row
     * @throws ArithmeticException if {@code index} does not fit in an {@code int}
     */
    public Record get(NDManager manager, long index) {
        int position = Math.toIntExact(index);
        CSVRecord row = csvRecords.get(position);
        // Arguments are evaluated left to right, so the data list is still built before the label.
        return new Record(toNDList(manager, row, features), toNDList(manager, row, labels));
    }

    /**
     * Returns how many rows have been parsed.
     *
     * @return the size of the parsed record list
     */
    protected long availableSize() {
        int rowCount = csvRecords.size();
        return rowCount;
    }

    /**
     * Reads the whole CSV source, decoded as UTF-8, and keeps every parsed row in memory.
     *
     * <p>Both the reader and the parser are managed by try-with-resources, so they are closed
     * whether parsing succeeds or fails.
     *
     * @param progress not used by this implementation
     * @throws IOException if the source cannot be opened or read
     */
    public void prepare(Progress progress) throws IOException {
        try (Reader reader = new InputStreamReader(getCsvStream(), StandardCharsets.UTF_8);
                CSVParser parser = new CSVParser(reader, csvFormat)) {
            csvRecords = parser.getRecords();
        }
    }

    /**
     * Opens the CSV source, transparently decompressing it when the URL's file part ends with
     * {@code .gz}.
     *
     * @return a stream over the (decompressed) CSV bytes; the caller is responsible for closing it
     * @throws IOException if the URL cannot be opened or the gzip header cannot be read
     */
    private InputStream getCsvStream() throws IOException {
        boolean gzipped = csvUrl.getFile().endsWith(".gz");
        InputStream raw = csvUrl.openStream();
        if (!gzipped) {
            return raw;
        }
        try {
            return new GZIPInputStream(raw);
        } catch (IOException e) {
            // The GZIPInputStream constructor reads the header eagerly; if that fails nobody else
            // holds a reference to the raw stream, so close it here instead of leaking it.
            try {
                raw.close();
            } catch (IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Creates a new builder for a {@code CsvDataset}.
     *
     * @return a fresh builder with no features or labels configured
     */
    public static CsvBuilder<?> builder() {
        // Diamond instead of the raw type, so the expression is checked generically.
        return new CsvBuilder<>();
    }

    /**
     * Returns the header names reported by the parser of the first parsed row.
     *
     * @return the header names, or an empty list when no rows were parsed
     */
    public List<String> getColumnNames() {
        return csvRecords.isEmpty()
                ? Collections.<String>emptyList()
                : csvRecords.get(0).getParser().getHeaderNames();
    }

    /**
     * Converts the selected columns of one row into a single one-dimensional array.
     *
     * <p>Each column value is looked up by name and appended to a shared buffer by that column's
     * featurizer, in the order the columns appear in {@code selected}.
     *
     * @param manager the manager used to create the resulting array
     * @param record the CSV row to read values from
     * @param selected the columns to encode
     * @return a list holding one array whose length is the number of values written to the buffer
     */
    protected NDList toNDList(NDManager manager, CSVRecord record, List<Feature> selected) {
        DynamicBuffer encoded = new DynamicBuffer();
        for (Feature column : selected) {
            column.getFeaturizer().featurize(encoded, record.get(column.getName()));
        }
        FloatBuffer values = encoded.getBuffer();
        Shape shape = new Shape(encoded.getLength());
        return new NDList(manager.create(values, shape));
    }

    /**
     * Encodes a string value either as its integer index or as a one-hot vector.
     *
     * <p>Created without arguments, it assigns the next free index to every value it has not seen
     * before. Created with an explicit map, only values present in that map are accepted.
     */
    private static final class StringFeaturizer implements Featurizer {
        /** Value-to-index mapping; grows on the fly only when {@link #autoMap} is set. */
        private final Map<String, Integer> map;
        /** Whether to emit a one-hot vector instead of the bare index. */
        private final boolean onehotEncode;
        /** Whether unseen values are added to {@link #map} instead of being rejected. */
        private final boolean autoMap;

        /** Creates a featurizer that emits indices and learns new values as they appear. */
        StringFeaturizer() {
            this.map = new HashMap<>();
            this.onehotEncode = false;
            this.autoMap = true;
        }

        /**
         * Creates a featurizer restricted to the values of a caller-supplied map.
         *
         * @param map the value-to-index mapping to use
         * @param onehotEncode {@code true} to emit one-hot vectors, {@code false} to emit indices
         */
        StringFeaturizer(Map<String, Integer> map, boolean onehotEncode) {
            this.map = map;
            this.onehotEncode = onehotEncode;
            this.autoMap = false;
        }

        /**
         * Appends the encoding of {@code input} to the buffer.
         *
         * @param buf the buffer to append to
         * @param input the raw column value
         * @throws IllegalArgumentException if the value is not in the map and cannot be added
         */
        @Override
        public void featurize(DynamicBuffer buf, String input) {
            Integer index = map.get(input);
            if (onehotEncode) {
                if (index == null) {
                    // Previously this surfaced as a bare NullPointerException from unboxing.
                    throw new IllegalArgumentException("Value: " + input + " not found in the map.");
                }
                int hot = index;
                int width = map.size();
                for (int i = 0; i < width; ++i) {
                    buf.put(i == hot ? 1.0f : 0.0f);
                }
                return;
            }
            if (index != null) {
                buf.put(index.intValue());
                return;
            }
            if (!autoMap) {
                throw new IllegalArgumentException("Value: " + input + " not found in the map.");
            }
            // Unseen value in auto-map mode: give it the next free index.
            int value = map.size();
            map.put(input, value);
            buf.put(value);
        }
    }

    /** Encodes a column value by parsing it as a {@code float}. */
    private static final class NumericFeaturizer implements Featurizer {
        private NumericFeaturizer() {
        }

        /**
         * Parses {@code input} and appends the resulting number to the buffer.
         *
         * @param buf the buffer to append to
         * @param input the raw column value
         * @throws NumberFormatException if the value is not a parsable float
         */
        @Override
        public void featurize(DynamicBuffer buf, String input) {
            float parsed = Float.parseFloat(input);
            buf.put(parsed);
        }
    }

    /** Pairs a CSV column name with the featurizer that encodes its values. */
    public static final class Feature {
        String name;
        Featurizer featurizer;

        /**
         * Creates a feature with an explicit featurizer.
         *
         * @param name the column name
         * @param featurizer the featurizer used to encode the column's values
         */
        public Feature(String name, Featurizer featurizer) {
            this.featurizer = featurizer;
            this.name = name;
        }

        /**
         * Creates a numeric feature or an auto-mapped string feature.
         *
         * @param name the column name
         * @param numeric {@code true} to use the shared numeric featurizer, {@code false} to use
         *     a new string featurizer
         */
        public Feature(String name, boolean numeric) {
            this(name, numeric ? NUMERIC_FEATURIZER : new StringFeaturizer());
        }

        /**
         * Creates a string feature backed by a caller-supplied value-to-index map.
         *
         * @param name the column name
         * @param map the value-to-index mapping
         * @param onehotEncode whether values are emitted as one-hot vectors
         */
        public Feature(String name, Map<String, Integer> map, boolean onehotEncode) {
            this(name, new StringFeaturizer(map, onehotEncode));
        }

        /**
         * Returns the column name.
         *
         * @return the column name
         */
        public String getName() {
            return name;
        }

        /**
         * Returns the featurizer used for this column.
         *
         * @return the featurizer
         */
        public Featurizer getFeaturizer() {
            return featurizer;
        }
    }

    public static interface Featurizer {
        public void featurize(DynamicBuffer var1, String var2);
    }

    /**
     * Builder for {@link CsvDataset}.
     *
     * @param <T> the concrete builder type, returned from every fluent method
     */
    public static class CsvBuilder<T extends CsvBuilder<T>>
            extends RandomAccessDataset.BaseBuilder<T> {
        protected URL csvUrl;
        protected CSVFormat csvFormat;
        protected List<Feature> features = new ArrayList<>();
        protected List<Feature> labels = new ArrayList<>();

        protected CsvBuilder() {
        }

        /**
         * Returns this builder typed as {@code T}.
         *
         * @return this builder
         */
        @SuppressWarnings("unchecked") // relies on T being the concrete builder type
        protected T self() {
            return (T) this;
        }

        /**
         * Sets the CSV source to a local file.
         *
         * @param csvFile the path of the CSV file
         * @return this builder
         * @throws IllegalArgumentException if the path cannot be converted to a URL
         */
        public T optCsvFile(Path csvFile) {
            try {
                this.csvUrl = csvFile.toAbsolutePath().toUri().toURL();
            } catch (MalformedURLException e) {
                throw new IllegalArgumentException("Invalid file path: " + csvFile, e);
            }
            return self();
        }

        /**
         * Sets the CSV source to a URL.
         *
         * @param csvUrl the URL of the CSV data
         * @return this builder
         * @throws IllegalArgumentException if the string is not a valid URL
         */
        public T optCsvUrl(String csvUrl) {
            try {
                this.csvUrl = new URL(csvUrl);
            } catch (MalformedURLException e) {
                throw new IllegalArgumentException("Invalid url: " + csvUrl, e);
            }
            return self();
        }

        /**
         * Sets the format used to parse the CSV source.
         *
         * @param csvFormat the CSV format
         * @return this builder
         */
        public T setCsvFormat(CSVFormat csvFormat) {
            this.csvFormat = csvFormat;
            return self();
        }

        /**
         * Adds one or more data columns.
         *
         * @param features the features to add
         * @return this builder
         */
        public T addFeature(Feature... features) {
            Collections.addAll(this.features, features);
            return self();
        }

        /**
         * Adds a numeric data column.
         *
         * @param name the column name
         * @return this builder
         */
        public T addNumericFeature(String name) {
            features.add(new Feature(name, true));
            return self();
        }

        /**
         * Adds a categorical data column whose values are mapped to indices automatically.
         *
         * @param name the column name
         * @return this builder
         */
        public T addCategoricalFeature(String name) {
            features.add(new Feature(name, false));
            return self();
        }

        /**
         * Adds a categorical data column with an explicit value-to-index map.
         *
         * @param name the column name
         * @param map the value-to-index mapping
         * @param onehotEncode whether values are emitted as one-hot vectors
         * @return this builder
         */
        public T addCategoricalFeature(String name, Map<String, Integer> map, boolean onehotEncode) {
            features.add(new Feature(name, map, onehotEncode));
            return self();
        }

        /**
         * Adds one or more label columns.
         *
         * @param labels the labels to add
         * @return this builder
         */
        public T addLabel(Feature... labels) {
            Collections.addAll(this.labels, labels);
            return self();
        }

        /**
         * Adds a numeric label column.
         *
         * @param name the column name
         * @return this builder
         */
        public T addNumericLabel(String name) {
            labels.add(new Feature(name, true));
            return self();
        }

        /**
         * Adds a categorical label column whose values are mapped to indices automatically.
         *
         * @param name the column name
         * @return this builder
         */
        public T addCategoricalLabel(String name) {
            // Must be non-numeric: a numeric feature would run the text through Float.parseFloat.
            labels.add(new Feature(name, false));
            return self();
        }

        /**
         * Adds a categorical label column with an explicit value-to-index map.
         *
         * @param name the column name
         * @param map the value-to-index mapping
         * @param onehotEncode whether values are emitted as one-hot vectors
         * @return this builder
         */
        public T addCategoricalLabel(String name, Map<String, Integer> map, boolean onehotEncode) {
            labels.add(new Feature(name, map, onehotEncode));
            return self();
        }

        /**
         * Builds the dataset.
         *
         * @return a new {@link CsvDataset} configured from this builder
         * @throws IllegalArgumentException if no features or no labels were added
         */
        public CsvDataset build() {
            if (features.isEmpty()) {
                throw new IllegalArgumentException("Missing features.");
            }
            if (labels.isEmpty()) {
                throw new IllegalArgumentException("Missing labels.");
            }
            return new CsvDataset(this);
        }
    }
}

