/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.language.huggingface;

import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.Language;
import com.yahoo.language.huggingface.Encoding;
import com.yahoo.language.huggingface.HuggingFaceTokenizerConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.language.tools.Embed;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.yolean.Exceptions;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collection;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;

@Beta
public class HuggingFaceTokenizer
extends AbstractComponent
implements Embedder,
Segmenter,
AutoCloseable {
    private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = new EnumMap<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer>(Language.class);

    @Inject
    public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) {
        this(new Builder(cfg));
    }

    private HuggingFaceTokenizer(Builder b) {
        ClassLoader original = Thread.currentThread().getContextClassLoader();
        Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
        try {
            b.models.forEach((language, path) -> this.models.put((Language)language, (ai.djl.huggingface.tokenizers.HuggingFaceTokenizer)Exceptions.uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder().optTokenizerPath(path).optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true).build())));
        }
        finally {
            Thread.currentThread().setContextClassLoader(original);
        }
    }

    public List<Integer> embed(String text, Embedder.Context ctx) {
        ai.djl.huggingface.tokenizers.Encoding encoding = this.resolve(ctx.getLanguage()).encode(text);
        return Arrays.stream(encoding.getIds()).mapToInt(Math::toIntExact).boxed().toList();
    }

    public Tensor embed(String text, Embedder.Context ctx, TensorType type) {
        return Embed.asTensor(text, this, ctx, type);
    }

    public List<String> segment(String input, Language language) {
        return List.of(this.resolve(language).encode(input).getTokens());
    }

    public String decode(List<Integer> tokens, Embedder.Context ctx) {
        return this.resolve(ctx.getLanguage()).decode(HuggingFaceTokenizer.toArray(tokens));
    }

    public Encoding encode(String text) {
        return this.encode(text, Language.UNKNOWN);
    }

    public Encoding encode(String text, Language language) {
        return Encoding.from(this.resolve(language).encode(text));
    }

    public String decode(List<Long> tokens) {
        return this.decode(tokens, Language.UNKNOWN);
    }

    public String decode(List<Long> tokens, Language language) {
        return this.resolve(language).decode(HuggingFaceTokenizer.toArray(tokens));
    }

    @Override
    public void close() {
        this.models.forEach((__, model) -> model.close());
    }

    public void deconstruct() {
        this.close();
    }

    private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) {
        if (this.models.size() == 1 && this.models.containsKey(Language.UNKNOWN)) {
            return this.models.get(Language.UNKNOWN);
        }
        if (this.models.containsKey(language)) {
            return this.models.get(language);
        }
        throw new IllegalArgumentException("No model for language " + language);
    }

    private static long[] toArray(Collection<? extends Number> c) {
        return c.stream().mapToLong(Number::longValue).toArray();
    }

    public static final class Builder {
        private final Map<Language, Path> models = new EnumMap<Language, Path>(Language.class);
        private Boolean addSpecialTokens;

        public Builder() {
        }

        public Builder(HuggingFaceTokenizerConfig cfg) {
            for (HuggingFaceTokenizerConfig.Model model : cfg.model()) {
                this.addModel(Language.fromLanguageTag((String)model.language()), model.path());
            }
            this.addSpecialTokens(cfg.addSpecialTokens());
        }

        public Builder addModel(Language lang, Path path) {
            this.models.put(lang, path);
            return this;
        }

        public Builder addDefaultModel(Path path) {
            return this.addModel(Language.UNKNOWN, path);
        }

        public Builder addSpecialTokens(boolean enabled) {
            this.addSpecialTokens = enabled;
            return this;
        }

        public HuggingFaceTokenizer build() {
            return new HuggingFaceTokenizer(this);
        }
    }
}

