/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.fasttext;

import ai.djl.fasttext.FtModel;
import ai.djl.fasttext.FtVocabulary;
import ai.djl.modality.nlp.embedding.WordEmbedding;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import com.github.jfasttext.FastTextWrapper;
import java.util.Objects;

public class FtWord2VecWordEmbedding
implements WordEmbedding {
    private FtModel model;
    private FtVocabulary vocabulary;

    public FtWord2VecWordEmbedding(FtModel model, FtVocabulary vocabulary) {
        this.model = model;
        this.vocabulary = vocabulary;
    }

    public boolean vocabularyContains(String word) {
        return true;
    }

    public long preprocessWordToEmbed(String word) {
        return this.vocabulary.getIndex(word);
    }

    public NDArray embedWord(NDArray index) {
        return this.embedWord(index.getManager(), index.toLongArray()[0]);
    }

    public NDArray embedWord(NDManager manager, long index) {
        String word = this.vocabulary.getToken(index);
        FastTextWrapper.RealVector rv = this.model.fta.getVector(word);
        int size = (int)rv.size();
        float[] vec = new float[size];
        for (int i = 0; i < size; ++i) {
            vec[i] = rv.get((long)i);
        }
        return manager.create(vec);
    }

    public String unembedWord(NDArray word) {
        if (!word.isScalar()) {
            throw new IllegalArgumentException("NDArray word must be scalar index");
        }
        return this.vocabulary.getToken(word.toLongArray()[0]);
    }
}

