/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset;

import ai.djl.Application;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.Resource;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.util.JsonUtils;
import ai.djl.util.Progress;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AmesRandomAccess
extends RandomAccessDataset {
    private static final Logger logger = LoggerFactory.getLogger(AmesRandomAccess.class);
    private static final String ARTIFACT_ID = "ames";
    private Set<String> enabledFeatures;
    private Set<String> categoricalFeatures;
    private Set<String> disabledFeatures;
    private Set<String> oneHotEncode;
    private Map<String, Map<String, Integer>> featureToMap;
    private String label;
    private Map<String, FeatureType> featureType;
    private List<CSVRecord> csvRecords;
    private Dataset.Usage usage;
    private Resource resource;
    private boolean prepared;

    AmesRandomAccess(Builder builder) {
        super((RandomAccessDataset.BaseBuilder)builder);
        this.usage = builder.usage;
        MRL mrl = MRL.dataset((Application)Application.Tabular.LINEAR_REGRESSION, (String)builder.groupId, (String)builder.artifactId);
        this.resource = new Resource(builder.repository, mrl, "1.0");
        this.label = "saleprice";
        this.categoricalFeatures = builder.af.categorical;
        this.enabledFeatures = new HashSet<String>(builder.af.featureArray);
        this.featureToMap = new ConcurrentHashMap<String, Map<String, Integer>>(builder.af.featureToMap);
        this.disabledFeatures = new HashSet<String>();
        this.featureType = new ConcurrentHashMap<String, FeatureType>();
        this.oneHotEncode = new HashSet<String>();
    }

    public static Builder builder() {
        return new Builder();
    }

    public void setLabel(String feature) {
        if (this.disabledFeatures.remove(feature = feature.toLowerCase())) {
            this.label = feature;
        }
    }

    public Set<String> getEnabledFeatures() {
        return this.enabledFeatures;
    }

    public Set<String> getDisabledFeatures() {
        return this.disabledFeatures;
    }

    public Set<String> getCategoricalFeatures() {
        return this.categoricalFeatures;
    }

    public float[] getLabel(int index) {
        return new float[]{Float.parseFloat(this.csvRecords.get(index).get(this.label))};
    }

    public Record get(NDManager manager, long index) {
        int idx = Math.toIntExact(index);
        NDList d = new NDList(new NDArray[]{this.getFeatureNDArray(manager, idx)});
        NDList l = new NDList(new NDArray[]{manager.create(this.getLabel(idx))});
        return new Record(d, l);
    }

    public CSVRecord getCSVRecord(int index) {
        return this.csvRecords.get(index);
    }

    public void setOneHotEncode(String feature, boolean enable) {
        if (this.featureType.get(feature) == FeatureType.CATEGORICAL) {
            if (enable) {
                this.oneHotEncode.add(feature);
            } else {
                this.oneHotEncode.remove(feature);
            }
        }
    }

    public float[] getValueFloat(CSVRecord record, String feature) {
        if (this.featureType.get(feature) == FeatureType.NUMERIC) {
            return new float[]{Float.parseFloat(record.get(feature))};
        }
        String value = record.get(feature);
        if (this.featureToMap.containsKey(feature)) {
            Map<String, Integer> categoryTypeToInteger = this.featureToMap.get(feature);
            if (this.oneHotEncode.contains(feature)) {
                int categoryTypeCount = categoryTypeToInteger.size();
                float[] oneHotVector = new float[categoryTypeCount];
                oneHotVector[categoryTypeToInteger.get((Object)value).intValue()] = 1.0f;
                return oneHotVector;
            }
            if (categoryTypeToInteger.containsKey(value)) {
                return new float[]{categoryTypeToInteger.get(value).intValue()};
            }
            categoryTypeToInteger.put(value, categoryTypeToInteger.size());
            return new float[]{categoryTypeToInteger.size() - 1};
        }
        ConcurrentHashMap<String, Integer> map = new ConcurrentHashMap<String, Integer>();
        this.featureToMap.put(feature, map);
        map.put(value, 0);
        return new float[]{0.0f};
    }

    public int getFeatureArraySize() {
        int size = this.enabledFeatures.size();
        for (String feature : this.oneHotEncode) {
            if (!this.enabledFeatures.contains(feature)) continue;
            size += this.featureToMap.get(feature).size() - 1;
        }
        return size;
    }

    public NDArray getFeatureNDArray(NDManager manager, int index) {
        CSVRecord record = this.getCSVRecord(index);
        float[] featureArray = new float[this.getFeatureArraySize()];
        int i = 0;
        for (String feature : this.enabledFeatures) {
            float[] values;
            float[] fArray = values = this.getValueFloat(record, feature);
            int n = fArray.length;
            for (int j = 0; j < n; ++j) {
                float value;
                featureArray[i] = value = fArray[j];
                ++i;
            }
        }
        return manager.create(featureArray);
    }

    public void removeAllFeatures() {
        this.disabledFeatures.addAll(this.enabledFeatures);
        this.enabledFeatures.clear();
    }

    public void addAllFeatures() {
        for (String feature : this.disabledFeatures) {
            this.addFeature(feature);
        }
    }

    public void addFeature(String feature, FeatureType type) {
        if (this.disabledFeatures.contains(feature = feature.toLowerCase(Locale.getDefault()))) {
            this.featureType.put(feature, type);
            this.enabledFeatures.add(feature);
            this.disabledFeatures.remove(feature);
        } else {
            logger.warn("Unsupported feature: {}", (Object)feature);
        }
    }

    public void addFeature(String feature) {
        feature = feature.toLowerCase(Locale.getDefault());
        this.addFeature(feature, this.getFeatureType(feature));
    }

    public void setFeatureType(String feature, FeatureType type) {
        this.featureType.put(feature, type);
    }

    public void removeFeature(String feature) {
        if (this.enabledFeatures.contains(feature = feature.toLowerCase())) {
            this.disabledFeatures.add(feature);
            this.enabledFeatures.remove(feature);
        }
    }

    public FeatureType getFeatureType(String feature) {
        if (this.categoricalFeatures.contains(feature)) {
            return FeatureType.CATEGORICAL;
        }
        return FeatureType.NUMERIC;
    }

    public void prepare(Progress progress) throws IOException {
        Path csvFile;
        if (this.prepared) {
            return;
        }
        Artifact artifact = this.resource.getDefaultArtifact();
        this.resource.prepare(artifact, progress);
        Path root = this.resource.getRepository().getResourceDirectory(artifact).resolve("house-prices-advanced-regression-techniques");
        switch (this.usage) {
            case TRAIN: {
                csvFile = root.resolve("train.csv");
                break;
            }
            case TEST: {
                csvFile = root.resolve("test.csv");
                break;
            }
            default: {
                throw new UnsupportedOperationException("Validation data not available.");
            }
        }
        try (BufferedReader reader = Files.newBufferedReader(csvFile);
             CSVParser csvParser = new CSVParser((Reader)reader, CSVFormat.DEFAULT.withFirstRecordAsHeader().withIgnoreHeaderCase().withTrim());){
            this.csvRecords = csvParser.getRecords();
        }
        this.prepared = true;
    }

    protected long availableSize() {
        return this.csvRecords.size();
    }

    public static enum FeatureType {
        NUMERIC,
        CATEGORICAL;

    }

    private static final class AmesFeatures {
        List<String> featureArray;
        Set<String> categorical;
        Map<String, Map<String, Integer>> featureToMap;

        private AmesFeatures() {
        }
    }

    public static final class Builder
    extends RandomAccessDataset.BaseBuilder<Builder> {
        Repository repository = BasicDatasets.REPOSITORY;
        String groupId = "ai.djl.basicdataset";
        String artifactId = "ames";
        Dataset.Usage usage = Dataset.Usage.TRAIN;
        AmesFeatures af;

        Builder() {
        }

        public Builder self() {
            return this;
        }

        public Builder optUsage(Dataset.Usage usage) {
            this.usage = usage;
            return this.self();
        }

        public Builder optRepository(Repository repository) {
            this.repository = repository;
            return this.self();
        }

        public Builder optGroupId(String groupId) {
            this.groupId = groupId;
            return this.self();
        }

        public Builder optArtifactId(String artifactId) {
            if (artifactId.contains(":")) {
                String[] tokens = artifactId.split(":");
                this.groupId = tokens[0];
                this.artifactId = tokens[1];
            } else {
                this.artifactId = artifactId;
            }
            return this.self();
        }

        public AmesRandomAccess build() throws IOException {
            try (InputStreamReader reader = new InputStreamReader(AmesRandomAccess.class.getResourceAsStream("ames.json"), StandardCharsets.UTF_8);){
                this.af = (AmesFeatures)JsonUtils.GSON.fromJson((Reader)reader, AmesFeatures.class);
            }
            return new AmesRandomAccess(this);
        }
    }
}

