/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset.nlp;

import ai.djl.Application;
import ai.djl.basicdataset.RawDataset;
import ai.djl.basicdataset.nlp.TextDataset;
import ai.djl.basicdataset.utils.TextData;
import ai.djl.modality.nlp.embedding.EmbeddingException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.training.dataset.Record;
import ai.djl.util.JsonUtils;
import ai.djl.util.Progress;
import com.google.gson.reflect.TypeToken;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class StanfordQuestionAnsweringDataset
extends TextDataset
implements RawDataset<Object> {
    private static final String VERSION = "2.0";
    private static final String ARTIFACT_ID = "stanford-question-answer";
    private List<QuestionInfo> questionInfoList;

    protected StanfordQuestionAnsweringDataset(Builder builder) {
        super(builder);
        this.usage = builder.usage;
        this.mrl = builder.getMrl();
    }

    public static Builder builder() {
        return new Builder();
    }

    private Path prepareUsagePath(Progress progress) throws IOException {
        Path usagePath;
        Artifact artifact = this.mrl.getDefaultArtifact();
        this.mrl.prepare(artifact, progress);
        Path root = this.mrl.getRepository().getResourceDirectory(artifact);
        switch (this.usage) {
            case TRAIN: {
                usagePath = Paths.get("train-v2.0.json", new String[0]);
                break;
            }
            case TEST: {
                usagePath = Paths.get("dev-v2.0.json", new String[0]);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Validation data not available.");
            }
        }
        return root.resolve(usagePath);
    }

    public void prepare(Progress progress) throws IOException, EmbeddingException {
        Map data;
        if (this.prepared) {
            return;
        }
        Path usagePath = this.prepareUsagePath(progress);
        try (BufferedReader reader = Files.newBufferedReader(usagePath);){
            data = (Map)JsonUtils.GSON_PRETTY.fromJson((Reader)reader, new TypeToken<Map<String, Object>>(){}.getType());
        }
        List articles = (List)data.get("data");
        this.questionInfoList = new ArrayList<QuestionInfo>();
        ArrayList<String> sourceTextData = new ArrayList<String>();
        ArrayList<String> targetTextData = new ArrayList<String>();
        for (Map article : articles) {
            int titleIndex = sourceTextData.size();
            sourceTextData.add(article.get("title").toString());
            List paragraphs = (List)article.get("paragraphs");
            for (Map paragraph : paragraphs) {
                int contextIndex = sourceTextData.size();
                sourceTextData.add(paragraph.get("context").toString());
                List questions = (List)paragraph.get("qas");
                for (Map question : questions) {
                    int questionIndex = sourceTextData.size();
                    sourceTextData.add(question.get("question").toString());
                    QuestionInfo questionInfo = new QuestionInfo(questionIndex, titleIndex, contextIndex);
                    this.questionInfoList.add(questionInfo);
                    List answers = (List)question.get("answers");
                    for (Map answer : answers) {
                        int answerIndex = targetTextData.size();
                        targetTextData.add(answer.get("text").toString());
                        questionInfo.addAnswer(answerIndex);
                    }
                }
            }
        }
        this.preprocess(sourceTextData, true);
        this.preprocess(targetTextData, false);
        this.prepared = true;
    }

    public Record get(NDManager manager, long index) {
        NDList data = new NDList();
        NDList labels = new NDList();
        QuestionInfo questionInfo = this.questionInfoList.get(Math.toIntExact(index));
        NDArray title = this.sourceTextData.getEmbedding(manager, questionInfo.titleIndex.intValue());
        title.setName("title");
        NDArray context = this.sourceTextData.getEmbedding(manager, questionInfo.contextIndex.intValue());
        context.setName("context");
        NDArray question = this.sourceTextData.getEmbedding(manager, questionInfo.questionIndex.intValue());
        question.setName("question");
        data.add((Object)title);
        data.add((Object)context);
        data.add((Object)question);
        for (Integer answerIndex : questionInfo.answerIndexList) {
            labels.add((Object)this.targetTextData.getEmbedding(manager, answerIndex.intValue()));
        }
        return new Record(data, labels);
    }

    protected long availableSize() {
        return this.questionInfoList.size();
    }

    @Override
    public Object getData() throws IOException {
        Object data;
        Path usagePath = this.prepareUsagePath(null);
        try (BufferedReader reader = Files.newBufferedReader(usagePath);){
            data = JsonUtils.GSON_PRETTY.fromJson((Reader)reader, new TypeToken<Object>(){}.getType());
        }
        return data;
    }

    private int getLastAnswerIndex(int questionInfoIndex) {
        while (questionInfoIndex >= 0) {
            QuestionInfo questionInfo = this.questionInfoList.get(questionInfoIndex);
            if (!questionInfo.answerIndexList.isEmpty()) {
                return questionInfo.answerIndexList.get(questionInfo.answerIndexList.size() - 1);
            }
            --questionInfoIndex;
        }
        return 0;
    }

    @Override
    protected void preprocess(List<String> newTextData, boolean source) throws EmbeddingException {
        TextData textData = source ? this.sourceTextData : this.targetTextData;
        int index = (int)Math.min(this.limit, (long)this.questionInfoList.size()) - 1;
        int lastIndex = source ? this.questionInfoList.get((int)index).questionIndex.intValue() : this.getLastAnswerIndex(index);
        textData.preprocess(this.manager, newTextData.subList(0, lastIndex + 1));
    }

    private static class QuestionInfo {
        Integer questionIndex;
        Integer titleIndex;
        Integer contextIndex;
        List<Integer> answerIndexList;

        QuestionInfo(Integer questionIndex, Integer titleIndex, Integer contextIndex) {
            this.questionIndex = questionIndex;
            this.titleIndex = titleIndex;
            this.contextIndex = contextIndex;
            this.answerIndexList = new ArrayList<Integer>();
        }

        void addAnswer(Integer answerIndex) {
            this.answerIndexList.add(answerIndex);
        }
    }

    public static class Builder
    extends TextDataset.Builder<Builder> {
        public Builder() {
            this.artifactId = StanfordQuestionAnsweringDataset.ARTIFACT_ID;
        }

        public Builder self() {
            return this;
        }

        public StanfordQuestionAnsweringDataset build() {
            return new StanfordQuestionAnsweringDataset(this);
        }

        MRL getMrl() {
            return this.repository.dataset(Application.NLP.ANY, this.groupId, this.artifactId, StanfordQuestionAnsweringDataset.VERSION);
        }
    }
}

