/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset.utils;

import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.embedding.EmbeddingException;
import ai.djl.modality.nlp.embedding.TextEmbedding;
import ai.djl.modality.nlp.embedding.TrainableTextEmbedding;
import ai.djl.modality.nlp.embedding.TrainableWordEmbedding;
import ai.djl.modality.nlp.preprocess.LowerCaseConvertor;
import ai.djl.modality.nlp.preprocess.PunctuationSeparator;
import ai.djl.modality.nlp.preprocess.SimpleTokenizer;
import ai.djl.modality.nlp.preprocess.TextProcessor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;

/**
 * Holds a list of raw text samples together with their tokenized, vocabulary-mapped embeddings.
 *
 * <p>Create an instance from a {@link Configuration}, then call {@link #preprocess(NDManager,
 * List)} before using the accessors that depend on processed data ({@link #getVocabulary()},
 * {@link #getEmbedding(NDManager, long)}, {@link #getRawText(long)}).
 */
public class TextData {
    /** Embedding size used by {@link #getDefaultConfiguration()} and when a config leaves it unset. */
    private static final int DEFAULT_EMBEDDING_SIZE = 15;

    /** Value passed to {@code VocabularyBuilder.optMinFrequency} when building the vocabulary. */
    private static final int MIN_FREQUENCY = 3;

    private static final String NOT_PREPROCESSED_MESSAGE =
            "This method must be called after preprocess is called on this object";

    /** One entry per input sample; populated by {@link #preprocess}. */
    private List<NDArray> textEmbeddingList;
    /** The raw text samples passed to the last {@link #preprocess} call. */
    private List<String> rawText;
    private List<TextProcessor> textProcessors;
    private List<String> reservedTokens;
    private TextEmbedding textEmbedding;
    private SimpleVocabulary vocabulary;
    private String unknownToken;
    private boolean trainEmbedding;
    private int embeddingSize;
    private int size;

    /**
     * Constructs a new {@code TextData} from the given configuration.
     *
     * <p>If {@code trainEmbedding} is unset it defaults to {@code false}; if {@code embeddingSize}
     * is unset it defaults to {@value #DEFAULT_EMBEDDING_SIZE}. All other settings are copied as-is.
     *
     * @param config the configuration to read settings from
     */
    public TextData(Configuration config) {
        this.textProcessors = config.textProcessors;
        this.textEmbedding = config.textEmbedding;
        // Configuration stores these boxed so update() can tell "unset" apart; unboxing a null
        // here would throw a NullPointerException, so fall back to defaults instead.
        this.trainEmbedding = Boolean.TRUE.equals(config.trainEmbedding);
        this.embeddingSize =
                config.embeddingSize != null ? config.embeddingSize : DEFAULT_EMBEDDING_SIZE;
        this.unknownToken = config.unknownToken;
        this.reservedTokens = config.reservedTokens;
    }

    /**
     * Returns a configuration with the default settings.
     *
     * <p>The defaults are: embedding size {@value #DEFAULT_EMBEDDING_SIZE}, {@code trainEmbedding}
     * false, text processors {@code SimpleTokenizer}, {@code LowerCaseConvertor(Locale.ENGLISH)}
     * and {@code PunctuationSeparator} (applied in that order), unknown token {@code "<unk>"}, and
     * reserved tokens {@code "<bos>"}, {@code "<eos>"} and {@code "<pad>"}.
     *
     * @return a new default {@link Configuration}
     */
    public static Configuration getDefaultConfiguration() {
        List<TextProcessor> defaultTextProcessors =
                Arrays.asList(
                        new SimpleTokenizer(),
                        new LowerCaseConvertor(Locale.ENGLISH),
                        new PunctuationSeparator());
        return new Configuration()
                .setEmbeddingSize(DEFAULT_EMBEDDING_SIZE)
                .setTrainEmbedding(false)
                .setTextProcessors(defaultTextProcessors)
                .setUnknownToken("<unk>")
                .setReservedTokens(Arrays.asList("<bos>", "<eos>", "<pad>"));
    }

    /**
     * Tokenizes the given text samples, builds a vocabulary from them, and computes one embedding
     * per sample.
     *
     * <p>Each sample is wrapped in a single-element list and passed through the configured text
     * processors in order. A vocabulary is then built from all resulting token lists, using the
     * configured reserved and unknown tokens and a minimum frequency of {@value #MIN_FREQUENCY}.
     * Tokens the vocabulary does not know are replaced by its unknown token.
     *
     * <p>If no text embedding was configured, a {@code TrainableTextEmbedding} is created from the
     * vocabulary with the configured embedding size. NOTE(review): on a later call that
     * auto-created embedding is reused even though the vocabulary is rebuilt. Confirm that this
     * is intended.
     *
     * <p>When {@code trainEmbedding} is true, each stored array is created from
     * {@code preprocessTextToEmbed} (presumably token indices; confirm). Otherwise it is the
     * result of {@code embedText}.
     *
     * @param manager the manager used to create the embedding arrays
     * @param newTextData the raw text samples to process
     * @throws EmbeddingException if the text embedding fails to embed a sample
     */
    public void preprocess(NDManager manager, List<String> newTextData) throws EmbeddingException {
        // Defensive copy so later changes to the caller's list do not affect getRawText().
        this.rawText = new ArrayList<>(newTextData);

        SimpleVocabulary.VocabularyBuilder vocabularyBuilder =
                new SimpleVocabulary.VocabularyBuilder();
        vocabularyBuilder
                .optMinFrequency(MIN_FREQUENCY)
                .optReservedTokens(this.reservedTokens)
                .optUnknownToken(this.unknownToken);

        List<List<String>> textData = new ArrayList<>(newTextData.size());
        for (String textDatum : newTextData) {
            List<String> tokens = Collections.singletonList(textDatum);
            for (TextProcessor processor : this.textProcessors) {
                tokens = processor.preprocess(tokens);
            }
            vocabularyBuilder.add(tokens);
            // Copy into a mutable list: unknown tokens are replaced in place below, and the
            // singletonList (no processors) or a processor's result may be unmodifiable.
            textData.add(new ArrayList<>(tokens));
        }

        this.vocabulary = vocabularyBuilder.build();
        if (this.textEmbedding == null) {
            this.textEmbedding =
                    new TrainableTextEmbedding(
                            new TrainableWordEmbedding(this.vocabulary, this.embeddingSize));
        }

        // Build locally so a failure part-way does not leave size and the list inconsistent.
        List<NDArray> embeddings = new ArrayList<>(textData.size());
        for (List<String> tokens : textData) {
            tokens.replaceAll(
                    token ->
                            this.vocabulary.isKnownToken(token)
                                    ? token
                                    : this.vocabulary.getUnknownToken());
            if (this.trainEmbedding) {
                embeddings.add(manager.create(this.textEmbedding.preprocessTextToEmbed(tokens)));
            } else {
                embeddings.add(this.textEmbedding.embedText(manager, tokens));
            }
        }
        this.textEmbeddingList = embeddings;
        this.size = textData.size();
    }

    /**
     * Sets the text processors that {@link #preprocess} applies to each sample.
     *
     * @param textProcessors the processors, applied in list order
     */
    public void setTextProcessors(List<TextProcessor> textProcessors) {
        this.textProcessors = textProcessors;
    }

    /**
     * Sets the text embedding that {@link #preprocess} uses.
     *
     * @param textEmbedding the text embedding; if {@code null}, one is created during preprocessing
     */
    public void setTextEmbedding(TextEmbedding textEmbedding) {
        this.textEmbedding = textEmbedding;
    }

    /**
     * Returns the text embedding, which may have been created by {@link #preprocess}.
     *
     * @return the text embedding, or {@code null} if none is configured and preprocess has not run
     */
    public TextEmbedding getTextEmbedding() {
        return this.textEmbedding;
    }

    /**
     * Sets the embedding size used when {@link #preprocess} creates a text embedding.
     *
     * @param embeddingSize the embedding size
     */
    public void setEmbeddingSize(int embeddingSize) {
        this.embeddingSize = embeddingSize;
    }

    /**
     * Returns the vocabulary built by {@link #preprocess}.
     *
     * @return the vocabulary
     * @throws IllegalStateException if preprocess has not been called
     */
    public SimpleVocabulary getVocabulary() {
        if (this.vocabulary == null) {
            throw new IllegalStateException(NOT_PREPROCESSED_MESSAGE);
        }
        return this.vocabulary;
    }

    /**
     * Returns a duplicate of the embedding for the sample at {@code index}, attached to
     * {@code manager}.
     *
     * @param manager the manager to attach the returned array to
     * @param index the sample index
     * @return a duplicate of the stored embedding
     * @throws IllegalStateException if preprocess has not been called
     * @throws ArithmeticException if {@code index} does not fit in an {@code int}
     */
    public NDArray getEmbedding(NDManager manager, long index) {
        if (this.textEmbeddingList == null) {
            throw new IllegalStateException(NOT_PREPROCESSED_MESSAGE);
        }
        NDArray embedding = this.textEmbeddingList.get(Math.toIntExact(index)).duplicate();
        embedding.attach(manager);
        return embedding;
    }

    /**
     * Returns the raw text of the sample at {@code index}.
     *
     * @param index the sample index
     * @return the raw text passed to {@link #preprocess}
     * @throws IllegalStateException if preprocess has not been called
     * @throws ArithmeticException if {@code index} does not fit in an {@code int}
     */
    public String getRawText(long index) {
        if (this.rawText == null) {
            throw new IllegalStateException(NOT_PREPROCESSED_MESSAGE);
        }
        return this.rawText.get(Math.toIntExact(index));
    }

    /**
     * Returns whether embeddings are stored in trainable (preprocessed) form.
     *
     * @return the {@code trainEmbedding} setting
     */
    public boolean getTrainEmbedding() {
        return this.trainEmbedding;
    }

    /**
     * Returns the number of samples processed by the last {@link #preprocess} call.
     *
     * @return the number of samples, or 0 if preprocess has not completed
     */
    public int getSize() {
        return this.size;
    }

    /**
     * Builder-style settings for {@link TextData}.
     *
     * <p>Fields are nullable so that {@link #update(Configuration)} can tell which settings were
     * explicitly set.
     */
    public static final class Configuration {
        private List<TextProcessor> textProcessors;
        private TextEmbedding textEmbedding;
        private Boolean trainEmbedding;
        private Integer embeddingSize;
        private String unknownToken;
        private List<String> reservedTokens;

        /**
         * Sets the text processors.
         *
         * @param textProcessors the processors, applied in list order
         * @return this configuration
         */
        public Configuration setTextProcessors(List<TextProcessor> textProcessors) {
            this.textProcessors = textProcessors;
            return this;
        }

        /**
         * Sets the text embedding.
         *
         * @param textEmbedding the text embedding
         * @return this configuration
         */
        public Configuration setTextEmbedding(TextEmbedding textEmbedding) {
            this.textEmbedding = textEmbedding;
            return this;
        }

        /**
         * Sets whether embeddings are stored in trainable (preprocessed) form.
         *
         * @param trainEmbedding the flag
         * @return this configuration
         */
        public Configuration setTrainEmbedding(boolean trainEmbedding) {
            this.trainEmbedding = trainEmbedding;
            return this;
        }

        /**
         * Sets the embedding size used when a text embedding is created.
         *
         * @param embeddingSize the embedding size
         * @return this configuration
         */
        public Configuration setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this;
        }

        /**
         * Sets the unknown token passed to the vocabulary builder.
         *
         * @param unknownToken the unknown token
         * @return this configuration
         */
        public Configuration setUnknownToken(String unknownToken) {
            this.unknownToken = unknownToken;
            return this;
        }

        /**
         * Sets the reserved tokens passed to the vocabulary builder.
         *
         * @param reservedTokens the reserved tokens
         * @return this configuration
         */
        public Configuration setReservedTokens(List<String> reservedTokens) {
            this.reservedTokens = reservedTokens;
            return this;
        }

        /**
         * Overwrites this configuration's settings with every non-null setting of {@code other}.
         *
         * @param other the configuration whose set values take precedence
         * @return this configuration
         */
        public Configuration update(Configuration other) {
            this.textProcessors = other.textProcessors != null ? other.textProcessors : this.textProcessors;
            this.textEmbedding = other.textEmbedding != null ? other.textEmbedding : this.textEmbedding;
            this.trainEmbedding = other.trainEmbedding != null ? other.trainEmbedding : this.trainEmbedding;
            this.embeddingSize = other.embeddingSize != null ? other.embeddingSize : this.embeddingSize;
            this.unknownToken = other.unknownToken != null ? other.unknownToken : this.unknownToken;
            this.reservedTokens = other.reservedTokens != null ? other.reservedTokens : this.reservedTokens;
            return this;
        }
    }
}

