/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.fasttext.zoo.nlp.word_embedding;

import ai.djl.Model;
import ai.djl.fasttext.FtAbstractBlock;
import ai.djl.fasttext.FtModel;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.embedding.WordEmbedding;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.ZooModel;
import java.util.Objects;

/**
 * A {@link WordEmbedding} whose word vectors are computed by a fastText block.
 *
 * <p>Indices are translated to and from words with the supplied {@link Vocabulary}. The vector for
 * a word is obtained from {@link FtAbstractBlock#embedWord(String)} and wrapped in an
 * {@link NDArray}.
 */
public class FtWord2VecWordEmbedding implements WordEmbedding {

    private final FtAbstractBlock embedding;
    private final Vocabulary vocabulary;

    /**
     * Constructs an embedding backed by the block of a fastText model.
     *
     * <p>If {@code model} is a {@link ZooModel}, its wrapped model is used instead.
     *
     * @param model the model to take the fastText block from; must be, or wrap, an {@link FtModel}
     * @param vocabulary the vocabulary used to map between words and indices
     * @throws IllegalArgumentException if {@code model} is not, and does not wrap, an {@link FtModel}
     * @throws NullPointerException if {@code vocabulary} is null or the model's block is null
     */
    public FtWord2VecWordEmbedding(Model model, Vocabulary vocabulary) {
        Model target = model;
        if (target instanceof ZooModel) {
            target = ((ZooModel<?, ?>) target).getWrappedModel();
        }
        if (!(target instanceof FtModel)) {
            throw new IllegalArgumentException("The FtWord2VecWordEmbedding requires an FtModel");
        }
        // The block is captured once here and can never be replaced later. A null block would
        // otherwise only surface as an NPE on the first embedWord call.
        // NOTE(review): presumably getBlock() is null when the FtModel has not been loaded — confirm.
        this.embedding =
                Objects.requireNonNull(
                        ((FtModel) target).getBlock(), "FtModel block must not be null");
        this.vocabulary = Objects.requireNonNull(vocabulary, "vocabulary");
    }

    /**
     * Constructs an embedding backed directly by a fastText block.
     *
     * @param embedding the fastText block that computes word vectors
     * @param vocabulary the vocabulary used to map between words and indices
     * @throws NullPointerException if {@code embedding} or {@code vocabulary} is null
     */
    public FtWord2VecWordEmbedding(FtAbstractBlock embedding, Vocabulary vocabulary) {
        this.embedding = Objects.requireNonNull(embedding, "embedding");
        this.vocabulary = Objects.requireNonNull(vocabulary, "vocabulary");
    }

    /**
     * Reports whether {@code word} can be embedded.
     *
     * <p>This implementation always returns {@code true} and does not consult the vocabulary.
     *
     * @param word the word to check (ignored)
     * @return always {@code true}
     */
    public boolean vocabularyContains(String word) {
        // NOTE(review): presumably because fastText can build vectors for out-of-vocabulary words
        // from subword information — confirm.
        return true;
    }

    /**
     * Converts a word into its vocabulary index.
     *
     * @param word the word to look up
     * @return the index the vocabulary assigns to {@code word}
     */
    public long preprocessWordToEmbed(String word) {
        return this.vocabulary.getIndex(word);
    }

    /**
     * Embeds the word whose index is stored in {@code index}.
     *
     * <p>Only the first element of {@code index} is used. The result is created with the manager
     * that owns {@code index}.
     *
     * @param index an array holding the vocabulary index of the word
     * @return a new array holding the word's vector
     */
    public NDArray embedWord(NDArray index) {
        return this.embedWord(index.getManager(), index.toLongArray()[0]);
    }

    /**
     * Embeds the word at the given vocabulary index.
     *
     * @param manager the manager that creates the returned array
     * @param index the vocabulary index of the word
     * @return a new array holding the vector computed by the fastText block for that word
     */
    public NDArray embedWord(NDManager manager, long index) {
        String word = this.vocabulary.getToken(index);
        float[] buf = this.embedding.embedWord(word);
        return manager.create(buf);
    }

    /**
     * Converts a scalar index array back into its word.
     *
     * @param word a scalar array holding a vocabulary index
     * @return the vocabulary token at that index
     * @throws IllegalArgumentException if {@code word} is not a scalar
     */
    public String unembedWord(NDArray word) {
        if (!word.isScalar()) {
            throw new IllegalArgumentException("NDArray word must be scalar index");
        }
        return this.vocabulary.getToken(word.toLongArray()[0]);
    }
}

