/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset;

import ai.djl.Application;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.modality.nlp.EmbeddingException;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.WordEmbedding;
import ai.djl.modality.nlp.preprocess.LowerCaseConvertor;
import ai.djl.modality.nlp.preprocess.PunctuationSeparator;
import ai.djl.modality.nlp.preprocess.SentenceLengthNormalizer;
import ai.djl.modality.nlp.preprocess.TextProcessor;
import ai.djl.modality.nlp.preprocess.Tokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.dataset.ZooDataset;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import java.io.BufferedReader;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

public class TatoebaEnglishFrenchDataset
extends RandomAccessDataset
implements ZooDataset {
    private static final String VERSION = "1.0";
    private static final String ARTIFACT_ID = "tatoeba-en-fr";
    private Repository repository;
    private Artifact artifact;
    private Dataset.Usage usage;
    private boolean prepared;
    private List<List<String>> sourceSentences;
    private List<Integer> sourceValidLength;
    private List<List<String>> targetSentences;
    private List<Integer> targetValidLength;
    private List<TextProcessor> sourceTextProcessors;
    private List<TextProcessor> targetTextProcessors;
    private WordEmbedding wordEmbedding;
    private boolean trainEmbedding;
    private boolean includeValidLength;
    private Tokenizer tokenizer;

    protected TatoebaEnglishFrenchDataset(Builder builder) {
        super((RandomAccessDataset.BaseBuilder)builder);
        this.repository = builder.repository;
        this.artifact = builder.artifact;
        this.usage = builder.usage;
        this.wordEmbedding = builder.wordEmbedding;
        this.trainEmbedding = builder.trainEmbedding;
        this.includeValidLength = builder.includeValidLength;
        this.sourceTextProcessors = builder.sourceTextProcessors;
        this.targetTextProcessors = builder.targetTextProcessors;
        this.tokenizer = builder.tokenizer;
        this.sourceSentences = new ArrayList<List<String>>();
        this.sourceValidLength = new ArrayList<Integer>();
        this.targetSentences = new ArrayList<List<String>>();
        this.targetValidLength = new ArrayList<Integer>();
    }

    public static Builder builder() {
        return new Builder();
    }

    public MRL getMrl() {
        return MRL.dataset((Application)Application.NLP.MACHINE_TRANSLATION, (String)"ai.djl.basicdataset", (String)ARTIFACT_ID);
    }

    public Repository getRepository() {
        return this.repository;
    }

    public Artifact getArtifact() {
        return this.artifact;
    }

    public Dataset.Usage getUsage() {
        return this.usage;
    }

    public boolean isPrepared() {
        return this.prepared;
    }

    public void setPrepared(boolean prepared) {
        this.prepared = prepared;
    }

    public void useDefaultArtifact() throws IOException {
        this.artifact = this.repository.resolve(this.getMrl(), VERSION, null);
    }

    public void prepareData(Dataset.Usage usage) throws IOException {
        List<String> sentence;
        int i;
        Path usagePath;
        Path cacheDir = this.repository.getCacheDirectory();
        URI resourceUri = this.artifact.getResourceUri();
        Path root = cacheDir.resolve(resourceUri.getPath());
        switch (usage) {
            case TRAIN: {
                usagePath = Paths.get("fra-eng-train.txt", new String[0]);
                break;
            }
            case TEST: {
                usagePath = Paths.get("fra-eng-test.txt", new String[0]);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Validation data not available.");
            }
        }
        usagePath = root.resolve(usagePath);
        Vocabulary.VocabularyBuilder sourceVocabularyBuilder = new Vocabulary.VocabularyBuilder();
        sourceVocabularyBuilder.optMinFrequency(3);
        sourceVocabularyBuilder.optReservedTokens(Arrays.asList("<pad>", "<bos>", "<eos>"));
        Vocabulary.VocabularyBuilder targetVocabularyBuilder = new Vocabulary.VocabularyBuilder();
        targetVocabularyBuilder.optMinFrequency(3);
        targetVocabularyBuilder.optReservedTokens(Arrays.asList("<pad>", "<bos>", "<eos>"));
        try (BufferedReader reader = Files.newBufferedReader(usagePath);){
            String row;
            while ((row = reader.readLine()) != null) {
                String[] sentences = row.split("\t");
                List sourceSentence = this.tokenizer.tokenize(sentences[0]);
                for (TextProcessor processor : this.sourceTextProcessors) {
                    sourceSentence = processor.preprocess(sourceSentence);
                    if (!(processor instanceof SentenceLengthNormalizer)) continue;
                    this.sourceValidLength.add(((SentenceLengthNormalizer)processor).getLastValidLength());
                }
                List targetSentence = this.tokenizer.tokenize(sentences[1]);
                for (TextProcessor processor : this.targetTextProcessors) {
                    targetSentence = processor.preprocess(targetSentence);
                    if (!(processor instanceof SentenceLengthNormalizer)) continue;
                    this.targetValidLength.add(((SentenceLengthNormalizer)processor).getLastValidLength());
                }
                sourceVocabularyBuilder.add(sourceSentence);
                targetVocabularyBuilder.add(targetSentence);
                this.sourceSentences.add(sourceSentence);
                this.targetSentences.add(targetSentence);
            }
        }
        Vocabulary sourceVocabulary = sourceVocabularyBuilder.build();
        Vocabulary targetVocabulary = targetVocabularyBuilder.build();
        for (i = 0; i < this.sourceSentences.size(); ++i) {
            sentence = this.sourceSentences.get(i);
            for (int j = 0; j < sentence.size(); ++j) {
                if (sourceVocabulary.isKnownToken(sentence.get(j))) continue;
                sentence.set(j, sourceVocabulary.getUnknownToken());
            }
            this.sourceSentences.set(i, sentence);
        }
        for (i = 0; i < this.targetSentences.size(); ++i) {
            sentence = this.targetSentences.get(i);
            for (int j = 0; j < sentence.size(); ++j) {
                if (targetVocabulary.isKnownToken(sentence.get(j))) continue;
                sentence.set(j, targetVocabulary.getUnknownToken());
            }
            this.targetSentences.set(i, sentence);
        }
    }

    public Record get(NDManager manager, long index) throws EmbeddingException {
        NDList data = new NDList();
        NDList dataLengths = new NDList();
        NDList target = new NDList();
        NDList targetLengths = new NDList();
        List<String> sourceSentence = this.sourceSentences.get((int)index);
        List<String> targetSentence = this.targetSentences.get((int)index);
        for (String token : sourceSentence) {
            if (this.trainEmbedding) {
                data.add((Object)this.wordEmbedding.preprocessWordToEmbed(manager, token));
            } else {
                data.add((Object)this.wordEmbedding.embedWord(manager, token));
            }
            if (!this.includeValidLength) continue;
            dataLengths.add((Object)manager.create((Number)this.sourceValidLength.get((int)index)));
        }
        for (String token : targetSentence) {
            if (this.trainEmbedding) {
                target.add((Object)this.wordEmbedding.preprocessWordToEmbed(manager, token));
            } else {
                target.add((Object)this.wordEmbedding.embedWord(manager, token));
            }
            if (!this.includeValidLength) continue;
            targetLengths.add((Object)manager.create((Number)this.targetValidLength.get((int)index)));
        }
        if (this.includeValidLength) {
            return new Record(new NDList(new NDArray[]{NDArrays.stack((NDList)data), NDArrays.stack((NDList)dataLengths)}), new NDList(new NDArray[]{NDArrays.stack((NDList)target), NDArrays.stack((NDList)targetLengths)}));
        }
        return new Record(new NDList(new NDArray[]{NDArrays.stack((NDList)data)}), new NDList(new NDArray[]{NDArrays.stack((NDList)target)}));
    }

    protected long availableSize() {
        return this.sourceSentences.size();
    }

    public static class Builder
    extends RandomAccessDataset.BaseBuilder<Builder> {
        private Repository repository;
        private Artifact artifact;
        private Dataset.Usage usage;
        protected List<TextProcessor> sourceTextProcessors = Arrays.asList(new LowerCaseConvertor(Locale.ENGLISH), new PunctuationSeparator(), new SentenceLengthNormalizer(10, false));
        protected List<TextProcessor> targetTextProcessors = Arrays.asList(new LowerCaseConvertor(Locale.FRENCH), new PunctuationSeparator(), new SentenceLengthNormalizer(12, true));
        protected WordEmbedding wordEmbedding;
        protected boolean trainEmbedding;
        protected boolean includeValidLength;
        protected Tokenizer tokenizer;

        Builder() {
            this.repository = BasicDatasets.REPOSITORY;
            this.usage = Dataset.Usage.TRAIN;
        }

        public Builder self() {
            return this;
        }

        public Builder optUsage(Dataset.Usage usage) {
            this.usage = usage;
            return this.self();
        }

        public Builder optRepository(Repository repository) {
            this.repository = repository;
            return this.self();
        }

        public Builder optArtifact(Artifact artifact) {
            this.artifact = artifact;
            return this.self();
        }

        public Builder setEmbedding(WordEmbedding wordEmbedding, boolean trainEmbedding) {
            this.wordEmbedding = wordEmbedding;
            this.trainEmbedding = trainEmbedding;
            return this.self();
        }

        public Builder setValidLength(boolean includeValidLength) {
            this.includeValidLength = includeValidLength;
            return this.self();
        }

        public Builder setTokenizer(Tokenizer tokenizer) {
            this.tokenizer = tokenizer;
            return this.self();
        }

        public Builder optSourceTextProcessors(List<TextProcessor> sourceTextProcessors) {
            this.sourceTextProcessors = sourceTextProcessors;
            return this.self();
        }

        public Builder optSourceTextProcessor(TextProcessor sourceTextProcessor) {
            this.sourceTextProcessors.add(sourceTextProcessor);
            return this.self();
        }

        public Builder optTargetTextProcessors(List<TextProcessor> targetTextProcessors) {
            this.targetTextProcessors = targetTextProcessors;
            return this.self();
        }

        public Builder optTargetTextProcessor(TextProcessor targetTextProcessor) {
            this.targetTextProcessors.add(targetTextProcessor);
            return this.self();
        }

        public TatoebaEnglishFrenchDataset build() {
            return new TatoebaEnglishFrenchDataset(this);
        }
    }
}

