/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp.embedding;

import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.embedding.EmbeddingException;
import ai.djl.modality.nlp.embedding.WordEmbedding;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.core.Embedding;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Optional;

/**
 * A {@link WordEmbedding} backed by an {@link Embedding} of {@code String} items whose
 * word-to-index mapping is provided by a {@link Vocabulary}.
 *
 * <p>Words that are not in the vocabulary are delegated to the optional
 * {@code fallthroughEmbedding} inherited from {@link Embedding}; that field may be
 * {@code null}, and every lookup in this class guards against that.
 */
public class TrainableWordEmbedding extends Embedding<String> implements WordEmbedding {
    /** Token used as the default (unknown) item when none is configured explicitly. */
    private static final String DEFAULT_UNKNOWN_TOKEN = "<unk>";

    /** Maps words to indices and back; assigned exactly once by every constructor. */
    private final Vocabulary vocabulary;

    /**
     * Constructs the embedding from a configured {@link Builder}.
     *
     * @param builder the builder carrying the vocabulary and the base embedding settings
     */
    public TrainableWordEmbedding(Builder builder) {
        super(builder);
        this.vocabulary = builder.vocabulary;
    }

    /**
     * Constructs the embedding for the given vocabulary and embedding size, using
     * {@code "<unk>"} as the default item and with {@code optUseDefault(false)}.
     *
     * @param vocabulary the vocabulary providing the word/index mapping
     * @param embeddingSize the embedding size passed to the builder
     */
    public TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize) {
        // The casts are kept from the decompiled output; they are no-ops when the
        // base builder methods already return Builder.
        this(
                (Builder) ((Builder) ((Builder) TrainableWordEmbedding.builder()
                                        .setVocabulary(vocabulary)
                                        .setEmbeddingSize(embeddingSize))
                                .optDefaultItem(DEFAULT_UNKNOWN_TOKEN))
                        .optUseDefault(false));
    }

    /**
     * Wraps a pre-trained embedding array; the vocabulary is built from {@code items}
     * and the fallthrough is a default item for {@code "<unk>"}.
     *
     * @param embedding the pre-trained embedding array
     * @param items the words used to build a {@link DefaultVocabulary}
     */
    private TrainableWordEmbedding(NDArray embedding, List<String> items) {
        super(embedding);
        this.fallthroughEmbedding = new Embedding.DefaultItem(DEFAULT_UNKNOWN_TOKEN);
        this.vocabulary = new DefaultVocabulary(items);
    }

    /**
     * Same as {@link #TrainableWordEmbedding(NDArray, List)} but additionally forwards a
     * {@link SparseFormat} to the base class.
     *
     * @param embedding the pre-trained embedding array
     * @param items the words used to build a {@link DefaultVocabulary}
     * @param sparseFormat the sparse format passed to the base constructor
     */
    private TrainableWordEmbedding(
            NDArray embedding, List<String> items, SparseFormat sparseFormat) {
        super(embedding, sparseFormat);
        this.fallthroughEmbedding = new Embedding.DefaultItem(DEFAULT_UNKNOWN_TOKEN);
        this.vocabulary = new DefaultVocabulary(items);
    }

    /**
     * Creates a word embedding from a pre-trained embedding array.
     *
     * @param embedding the pre-trained embedding array
     * @param items the words used to build the vocabulary
     * @return the new {@code TrainableWordEmbedding}
     */
    public static TrainableWordEmbedding fromPretrained(NDArray embedding, List<String> items) {
        return new TrainableWordEmbedding(embedding, items);
    }

    /**
     * Creates a word embedding from a pre-trained embedding array with a sparse format.
     *
     * @param embedding the pre-trained embedding array
     * @param items the words used to build the vocabulary
     * @param sparseFormat the sparse format passed to the base embedding
     * @return the new {@code TrainableWordEmbedding}
     */
    public static TrainableWordEmbedding fromPretrained(
            NDArray embedding, List<String> items, SparseFormat sparseFormat) {
        return new TrainableWordEmbedding(embedding, items, sparseFormat);
    }

    /**
     * Reports whether the vocabulary yields a non-negative index for the word.
     *
     * @param word the word to look up
     * @return {@code true} if {@code vocabulary.getIndex(word)} is {@code >= 0}
     */
    @Override
    public boolean vocabularyContains(String word) {
        return vocabulary.getIndex(word) >= 0L;
    }

    /**
     * Converts a word to its index; identical to {@link #embed(String)}.
     *
     * @param word the word to convert
     * @return the index of the word
     */
    @Override
    public long preprocessWordToEmbed(String word) {
        return embed(word);
    }

    /**
     * Not supported by this class.
     *
     * @param index ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray embedWord(NDArray index) throws EmbeddingException {
        throw new UnsupportedOperationException(
                "EmbedWord operation is not supported by this class.");
    }

    /**
     * Converts a scalar index array back into a word.
     *
     * <p>The index is first resolved through {@link #unembed(long)}; if that yields nothing
     * and a fallthrough embedding is configured, the fallthrough is consulted.
     *
     * @param word a scalar array holding the word index
     * @return the word for the index
     * @throws IllegalArgumentException if {@code word} is not scalar, or if neither the
     *     vocabulary nor the fallthrough embedding can resolve the index
     */
    @Override
    public String unembedWord(NDArray word) {
        if (!word.isScalar()) {
            throw new IllegalArgumentException("NDArray word must be scalar index");
        }
        long wordIndex = word.toLongArray()[0];

        Optional<String> result = unembed(wordIndex);
        if (result.isPresent()) {
            return result.get();
        }

        // The fallthrough is optional (embed and unembed both treat it as nullable), so
        // guard here as well instead of failing with a NullPointerException.
        if (fallthroughEmbedding != null) {
            result = fallthroughEmbedding.unembed(wordIndex);
            if (result.isPresent()) {
                return result.get();
            }
        }
        throw new IllegalArgumentException("Failed to unembed word");
    }

    /**
     * Encodes a word as UTF-8 bytes.
     *
     * @param input the word to encode
     * @return the UTF-8 bytes of {@code input}
     */
    @Override
    public byte[] encode(String input) {
        return input.getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Decodes UTF-8 bytes into a word.
     *
     * @param byteArray the UTF-8 bytes to decode
     * @return the decoded string
     */
    @Override
    public String decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }

    /**
     * Returns the index for an item: the vocabulary index when the vocabulary contains it,
     * otherwise whatever the fallthrough embedding returns.
     *
     * @param item the word to embed
     * @return the index of the item
     * @throws IllegalArgumentException if the item is not in the vocabulary and no
     *     fallthrough embedding is configured
     */
    @Override
    public long embed(String item) {
        if (vocabularyContains(item)) {
            return vocabulary.getIndex(item);
        }
        if (fallthroughEmbedding != null) {
            return fallthroughEmbedding.embed(item);
        }
        throw new IllegalArgumentException("The provided item was not found");
    }

    /**
     * Returns the word for an index. Index {@code -1} is reserved for the fallthrough
     * embedding; every other index is looked up in the vocabulary.
     *
     * @param index the index to convert
     * @return the word, or an empty {@code Optional} if the vocabulary returns {@code null}
     * @throws IllegalArgumentException if {@code index} is {@code -1} and no fallthrough
     *     embedding is configured
     */
    @Override
    public Optional<String> unembed(long index) {
        if (index == -1L) {
            if (fallthroughEmbedding == null) {
                throw new IllegalArgumentException(
                        "Index -1 is reserved for the fallThrough but no fallThrough is found");
            }
            return fallthroughEmbedding.unembed(index);
        }
        return Optional.ofNullable(vocabulary.getToken(index));
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Always returns {@code false}.
     *
     * <p>NOTE(review): this does not consult the vocabulary, unlike
     * {@link #vocabularyContains(String)}; confirm against the base class contract whether
     * that is intentional before relying on it.
     *
     * @param item ignored
     * @return {@code false}
     */
    @Override
    public boolean hasItem(String item) {
        return false;
    }

    /** Builder for {@link TrainableWordEmbedding}. */
    public static class Builder extends Embedding.BaseBuilder<String, Builder> {
        /** Vocabulary handed to the embedding; {@code null} until {@link #setVocabulary}. */
        private Vocabulary vocabulary;

        /** Initializes the embedding type to {@code String} and the default item to "<unk>". */
        Builder() {
            this.embeddingType = String.class;
            this.defaultItem = TrainableWordEmbedding.DEFAULT_UNKNOWN_TOKEN;
        }

        /**
         * Sets the vocabulary and derives {@code numEmbeddings} from its size.
         *
         * @param vocabulary the vocabulary to use; must not be {@code null}
         * @return this builder
         * @throws ArithmeticException if the vocabulary size does not fit in an {@code int}
         */
        public Builder setVocabulary(Vocabulary vocabulary) {
            this.vocabulary = vocabulary;
            this.numEmbeddings = Math.toIntExact(vocabulary.size());
            return self();
        }

        /**
         * Ignores the argument; the embedding type is fixed to {@code String} by the
         * constructor.
         *
         * @param embeddingType ignored
         * @return this builder
         */
        @Override
        protected Builder setType(Class<String> embeddingType) {
            return self();
        }

        /**
         * Returns this builder.
         *
         * @return this builder
         */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Sets the token used for unknown words by forwarding to {@code optDefaultItem}.
         *
         * @param unknownToken the token to use for unknown words
         * @return this builder
         */
        public Builder optUnknownToken(String unknownToken) {
            return (Builder) optDefaultItem(unknownToken);
        }

        /**
         * Builds the {@link TrainableWordEmbedding}.
         *
         * @return the new embedding
         * @throws IllegalArgumentException if no vocabulary was set, or if
         *     {@code numEmbeddings} differs from the vocabulary size
         */
        public TrainableWordEmbedding build() {
            // Fail with a descriptive message rather than a bare NullPointerException.
            if (vocabulary == null) {
                throw new IllegalArgumentException(
                        "The vocabulary must be set before building a TrainableWordEmbedding.");
            }
            long vocabularySize = vocabulary.size();
            if (numEmbeddings != vocabularySize) {
                throw new IllegalArgumentException(
                        "The numEmbeddings is "
                                + numEmbeddings
                                + " and the vocabulary has size "
                                + vocabularySize
                                + " but they should be equal.");
            }
            return new TrainableWordEmbedding(this);
        }
    }
}

