/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.zoo.nlp.embedding;

import ai.djl.Model;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.core.Embedding;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

public class GloveWordEmbeddingTranslatorFactory
implements TranslatorFactory {
    public Set<Pair<Type, Type>> getSupportedTypes() {
        return Collections.singleton(new Pair(String.class, NDList.class));
    }

    public Translator<?, ?> newInstance(Class<?> input, Class<?> output, Model model, Map<String, ?> arguments) throws TranslateException {
        if (!this.isSupported(input, output)) {
            throw new IllegalArgumentException("Unsupported input/output types.");
        }
        String unknownToken = (String)arguments.get("unknownToken");
        return new GloveWordEmbeddingTranslator(unknownToken);
    }

    private static final class GloveWordEmbeddingTranslator
    implements Translator<String, NDList> {
        private String unknownToken;
        private Embedding<String> embedding;

        public GloveWordEmbeddingTranslator(String unknownToken) {
            this.unknownToken = unknownToken;
        }

        public void prepare(NDManager manager, Model model) {
            try {
                this.embedding = (Embedding)model.getBlock();
            }
            catch (ClassCastException e) {
                throw new IllegalArgumentException("The model was not an embedding", e);
            }
        }

        public NDList processOutput(TranslatorContext ctx, NDList list) {
            return list;
        }

        public NDList processInput(TranslatorContext ctx, String input) {
            if (this.embedding.hasItem((Object)input)) {
                return new NDList(new NDArray[]{ctx.getNDManager().create(this.embedding.embed((Object)input))});
            }
            return new NDList(new NDArray[]{ctx.getNDManager().create(this.embedding.embed((Object)this.unknownToken))});
        }

        public Batchifier getBatchifier() {
            return Batchifier.STACK;
        }
    }
}

