/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset.utils;

import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.embedding.EmbeddingException;
import ai.djl.modality.nlp.embedding.TextEmbedding;
import ai.djl.modality.nlp.embedding.TrainableTextEmbedding;
import ai.djl.modality.nlp.embedding.TrainableWordEmbedding;
import ai.djl.modality.nlp.preprocess.LowerCaseConvertor;
import ai.djl.modality.nlp.preprocess.PunctuationSeparator;
import ai.djl.modality.nlp.preprocess.SimpleTokenizer;
import ai.djl.modality.nlp.preprocess.TextProcessor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.AbstractBlock;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;

public class TextData {
    private List<NDArray> textEmbeddingList;
    private List<String> rawText;
    private List<TextProcessor> textProcessors;
    private List<String> reservedTokens;
    private TextEmbedding textEmbedding;
    private Vocabulary vocabulary;
    private String unknownToken;
    private int embeddingSize;
    private int size;

    /**
     * Creates a {@code TextData} from the given configuration.
     *
     * <p>The configuration's references are copied into this object; the configuration itself is
     * not retained. An unset (null) embedding size leaves the embedding size at 0. The embedding
     * size is only read when {@link #preprocess(NDManager, List)} has to create a {@link
     * TrainableWordEmbedding} because no text embedding was configured.
     *
     * @param config the configuration supplying text processors, text embedding, vocabulary,
     *     embedding size, unknown token and reserved tokens
     */
    public TextData(Configuration config) {
        this.textProcessors = config.textProcessors;
        this.textEmbedding = config.textEmbedding;
        this.vocabulary = config.vocabulary;
        // Configuration keeps the size as a nullable Integer; unboxing null would throw an NPE.
        if (config.embeddingSize != null) {
            this.embeddingSize = config.embeddingSize;
        }
        this.unknownToken = config.unknownToken;
        this.reservedTokens = config.reservedTokens;
    }

    /**
     * Returns a new default configuration.
     *
     * <p>The defaults are: embedding size 15; processors {@link SimpleTokenizer}, {@link
     * LowerCaseConvertor} (English locale) and {@link PunctuationSeparator}, applied in that
     * order; unknown token {@code "<unk>"}; reserved tokens {@code "<bos>"}, {@code "<eos>"} and
     * {@code "<pad>"}.
     *
     * @return a fresh, mutable {@link Configuration} holding the defaults
     */
    public static Configuration getDefaultConfiguration() {
        List<TextProcessor> defaultTextProcessors =
                Arrays.asList(
                        new SimpleTokenizer(),
                        new LowerCaseConvertor(Locale.ENGLISH),
                        new PunctuationSeparator());
        return new Configuration()
                .setEmbeddingSize(15)
                .setTextProcessors(defaultTextProcessors)
                .setUnknownToken("<unk>")
                .setReservedTokens(Arrays.asList("<bos>", "<eos>", "<pad>"));
    }

    /**
     * Tokenizes, vocabulary-maps and embeds the given raw text samples.
     *
     * <p>If no vocabulary was configured, one is built from the tokenized samples using a minimum
     * frequency of 3 and the configured reserved/unknown tokens. If no text embedding was
     * configured, a {@link TrainableTextEmbedding} over a {@link TrainableWordEmbedding} of the
     * configured embedding size is created. Calling this again replaces the previously stored
     * samples and embeddings.
     *
     * @param manager the manager used to create the embedding arrays
     * @param newTextData the raw text samples; a snapshot is taken, so later changes to this list
     *     do not affect this object
     * @throws EmbeddingException if the text embedding fails to embed a sample
     */
    public void preprocess(NDManager manager, List<String> newTextData) throws EmbeddingException {
        // Snapshot so caller-side mutation cannot desynchronize rawText from the embeddings below.
        this.rawText = new ArrayList<>(newTextData);
        List<List<String>> textData = new ArrayList<>(this.rawText.size());
        for (String text : this.rawText) {
            textData.add(applyTextProcessors(text));
        }
        if (this.vocabulary == null) {
            this.vocabulary = buildVocabulary(textData);
        }
        if (this.textEmbedding == null) {
            this.textEmbedding =
                    new TrainableTextEmbedding(
                            new TrainableWordEmbedding(this.vocabulary, this.embeddingSize));
        }
        this.size = textData.size();
        this.textEmbeddingList = new ArrayList<>(this.size);
        for (List<String> tokens : textData) {
            // Normalize every token through the vocabulary (token -> index -> token). For a
            // vocabulary with an unknown token this presumably replaces out-of-vocabulary tokens
            // with it. The list is a mutable copy, so set() is safe.
            for (int j = 0; j < tokens.size(); ++j) {
                tokens.set(j, this.vocabulary.getToken(this.vocabulary.getIndex(tokens.get(j))));
            }
            if (this.textEmbedding instanceof AbstractBlock) {
                // Block-based embeddings store the preprocessed index form instead of an eager
                // embedding (presumably so the block embeds them during training).
                this.textEmbeddingList.add(
                        manager.create(this.textEmbedding.preprocessTextToEmbed(tokens)));
            } else {
                this.textEmbeddingList.add(this.textEmbedding.embedText(manager, tokens));
            }
        }
    }

    /**
     * Runs every configured text processor, in order, over one raw sample.
     *
     * <p>Always returns a mutable copy. The list from the last processor (or the singleton input
     * list when there are no processors) may be immutable, and {@link #preprocess} edits tokens
     * in place.
     */
    private List<String> applyTextProcessors(String text) {
        List<String> tokens = Collections.singletonList(text);
        for (TextProcessor processor : this.textProcessors) {
            tokens = processor.preprocess(tokens);
        }
        return new ArrayList<>(tokens);
    }

    /** Builds a vocabulary from tokenized samples using the configured reserved/unknown tokens. */
    private Vocabulary buildVocabulary(List<List<String>> textData) {
        SimpleVocabulary.VocabularyBuilder vocabularyBuilder =
                new SimpleVocabulary.VocabularyBuilder();
        vocabularyBuilder
                .optMinFrequency(3)
                .optReservedTokens(this.reservedTokens)
                .optUnknownToken(this.unknownToken);
        for (List<String> tokens : textData) {
            vocabularyBuilder.add(tokens);
        }
        return vocabularyBuilder.build();
    }

    /**
     * Throws if {@code state} has not been populated yet (that is, before {@code preprocess} has
     * run); otherwise returns it unchanged.
     */
    private static <T> T requirePreprocessed(T state) {
        if (state == null) {
            throw new IllegalStateException(
                    "This method must be called after preprocess is called on this object");
        }
        return state;
    }

    /**
     * Replaces the text processors used by subsequent preprocessing and {@link
     * #getProcessedText(long)}.
     *
     * @param textProcessors the processors to apply, in order
     */
    public void setTextProcessors(List<TextProcessor> textProcessors) {
        this.textProcessors = textProcessors;
    }

    /**
     * Sets the text embedding used by subsequent preprocessing.
     *
     * @param textEmbedding the text embedding
     */
    public void setTextEmbedding(TextEmbedding textEmbedding) {
        this.textEmbedding = textEmbedding;
    }

    /**
     * Returns the text embedding. This may be null before {@code preprocess} if none was
     * configured.
     *
     * @return the text embedding
     */
    public TextEmbedding getTextEmbedding() {
        return this.textEmbedding;
    }

    /**
     * Sets the embedding size used when {@code preprocess} has to create a trainable embedding.
     *
     * @param embeddingSize the embedding size
     */
    public void setEmbeddingSize(int embeddingSize) {
        this.embeddingSize = embeddingSize;
    }

    /**
     * Returns the vocabulary, either the configured one or the one built by {@code preprocess}.
     *
     * @return the vocabulary
     * @throws IllegalStateException if no vocabulary is available yet
     */
    public Vocabulary getVocabulary() {
        return requirePreprocessed(this.vocabulary);
    }

    /**
     * Returns a duplicate of the stored embedding for one sample, attached to {@code manager}.
     *
     * @param manager the manager to attach the returned array to
     * @param index the sample index
     * @return a duplicate of the stored embedding array
     * @throws IllegalStateException if {@code preprocess} has not been called
     * @throws ArithmeticException if {@code index} does not fit in an {@code int}
     */
    public NDArray getEmbedding(NDManager manager, long index) {
        NDArray embedding =
                requirePreprocessed(this.textEmbeddingList).get(Math.toIntExact(index)).duplicate();
        embedding.attach(manager);
        return embedding;
    }

    /**
     * Returns the raw (unprocessed) text of one sample.
     *
     * @param index the sample index
     * @return the raw text
     * @throws IllegalStateException if {@code preprocess} has not been called
     * @throws ArithmeticException if {@code index} does not fit in an {@code int}
     */
    public String getRawText(long index) {
        return requirePreprocessed(this.rawText).get(Math.toIntExact(index));
    }

    /**
     * Re-runs the current text processors over one raw sample and returns the resulting tokens.
     *
     * <p>Unlike the tokens used during {@code preprocess}, these are not mapped through the
     * vocabulary.
     *
     * @param index the sample index
     * @return a new mutable list of processed tokens
     * @throws IllegalStateException if {@code preprocess} has not been called
     */
    public List<String> getProcessedText(long index) {
        return applyTextProcessors(getRawText(index));
    }

    /**
     * Returns the number of samples processed by the last {@code preprocess} call, or 0 before
     * the first call.
     *
     * @return the number of samples
     */
    public int getSize() {
        return this.size;
    }

    /**
     * Mutable, fluent configuration for {@link TextData}. Unset fields stay null so that {@link
     * #update(Configuration)} can overlay only the values a caller set explicitly.
     */
    public static final class Configuration {
        private List<TextProcessor> textProcessors;
        private TextEmbedding textEmbedding;
        private Vocabulary vocabulary;
        private Integer embeddingSize;
        private String unknownToken;
        private List<String> reservedTokens;

        /**
         * Sets the text processors.
         *
         * @param textProcessors the processors to apply, in order
         * @return this configuration
         */
        public Configuration setTextProcessors(List<TextProcessor> textProcessors) {
            this.textProcessors = textProcessors;
            return this;
        }

        /**
         * Sets the text embedding.
         *
         * @param textEmbedding the text embedding
         * @return this configuration
         */
        public Configuration setTextEmbedding(TextEmbedding textEmbedding) {
            this.textEmbedding = textEmbedding;
            return this;
        }

        /**
         * Sets the vocabulary.
         *
         * @param vocabulary the vocabulary
         * @return this configuration
         */
        public Configuration setVocabulary(Vocabulary vocabulary) {
            this.vocabulary = vocabulary;
            return this;
        }

        /**
         * Sets the embedding size.
         *
         * @param embeddingSize the embedding size
         * @return this configuration
         */
        public Configuration setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this;
        }

        /**
         * Sets the unknown token.
         *
         * @param unknownToken the unknown token
         * @return this configuration
         */
        public Configuration setUnknownToken(String unknownToken) {
            this.unknownToken = unknownToken;
            return this;
        }

        /**
         * Sets the reserved tokens.
         *
         * @param reservedTokens the reserved tokens
         * @return this configuration
         */
        public Configuration setReservedTokens(List<String> reservedTokens) {
            this.reservedTokens = reservedTokens;
            return this;
        }

        /**
         * Overlays {@code other} onto this configuration. Each field of {@code other} that is
         * non-null replaces the corresponding field here; null fields leave this configuration's
         * value untouched.
         *
         * @param other the configuration whose set values take precedence
         * @return this configuration
         */
        public Configuration update(Configuration other) {
            this.textProcessors =
                    other.textProcessors != null ? other.textProcessors : this.textProcessors;
            this.textEmbedding =
                    other.textEmbedding != null ? other.textEmbedding : this.textEmbedding;
            this.vocabulary = other.vocabulary != null ? other.vocabulary : this.vocabulary;
            this.embeddingSize =
                    other.embeddingSize != null ? other.embeddingSize : this.embeddingSize;
            this.unknownToken = other.unknownToken != null ? other.unknownToken : this.unknownToken;
            this.reservedTokens =
                    other.reservedTokens != null ? other.reservedTokens : this.reservedTokens;
            return this;
        }
    }
}

