/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.language.huggingface;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.huggingface.tokenizers.jni.LibUtils;
import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.io.IOUtils;
import com.yahoo.language.Language;
import com.yahoo.language.huggingface.Encoding;
import com.yahoo.language.huggingface.ModelInfo;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.language.tools.Embed;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.yolean.Exceptions;
import java.io.File;
import java.nio.file.CopyOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.Arrays;
import java.util.Collection;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

@Beta
public class HuggingFaceTokenizer
extends AbstractComponent
implements Embedder,
Segmenter,
AutoCloseable {
    private final Path tmpDirectory = (Path)Exceptions.uncheck(() -> Files.createTempDirectory("hf-tokenizer-", new FileAttribute[0]));
    private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = HuggingFaceTokenizer.withContextClassloader(() -> {
        EnumMap models = new EnumMap(Language.class);
        b.models.forEach((language, path) -> models.put((Language)language, (ai.djl.huggingface.tokenizers.HuggingFaceTokenizer)Exceptions.uncheck(() -> {
            Path tokenizerDir;
            if (Files.isDirectory(path, new LinkOption[0])) {
                tokenizerDir = path;
            } else {
                tokenizerDir = Files.createDirectory(this.tmpDirectory.resolve(language.languageCode()), new FileAttribute[0]);
                Files.copy(path, tokenizerDir.resolve("tokenizer.json"), new CopyOption[0]);
            }
            HuggingFaceTokenizer.Builder hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder().optTokenizerPath(tokenizerDir).optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true);
            if (b.maxLength != null) {
                hfb.optMaxLength(b.maxLength.intValue());
                hfb.configure(Map.of("modelMaxLength", b.maxLength > 0 ? b.maxLength : Integer.MAX_VALUE));
            }
            if (b.padding != null) {
                if (b.padding.booleanValue()) {
                    hfb.optPadToMaxLength();
                } else {
                    hfb.optPadding(false);
                }
            }
            if (b.truncation != null) {
                hfb.optTruncation(b.truncation.booleanValue());
            }
            return hfb.build();
        })));
        return models;
    });

    @Inject
    public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) {
        this(new Builder(cfg));
    }

    private HuggingFaceTokenizer(Builder b) {
    }

    public List<Integer> embed(String text, Embedder.Context ctx) {
        ai.djl.huggingface.tokenizers.Encoding encoding = this.resolve(ctx.getLanguage()).encode(text);
        return Arrays.stream(encoding.getIds()).mapToInt(Math::toIntExact).boxed().toList();
    }

    public Tensor embed(String text, Embedder.Context ctx, TensorType type) {
        return Embed.asTensor(text, this, ctx, type);
    }

    public List<String> segment(String input, Language language) {
        return List.of(this.resolve(language).encode(input).getTokens());
    }

    public String decode(List<Integer> tokens, Embedder.Context ctx) {
        return this.resolve(ctx.getLanguage()).decode(HuggingFaceTokenizer.toArray(tokens));
    }

    public Encoding encode(String text) {
        return this.encode(text, Language.UNKNOWN);
    }

    public Encoding encode(String text, Language language) {
        return Encoding.from(this.resolve(language).encode(text));
    }

    public String decode(List<Long> tokens) {
        return this.decode(tokens, Language.UNKNOWN);
    }

    public String decode(List<Long> tokens, Language language) {
        return this.resolve(language).decode(HuggingFaceTokenizer.toArray(tokens));
    }

    @Override
    public void close() {
        this.models.forEach((__, model) -> model.close());
        IOUtils.recursiveDeleteDir((File)this.tmpDirectory.toFile());
    }

    public void deconstruct() {
        this.close();
    }

    public static ModelInfo getModelInfo(Path path) {
        return HuggingFaceTokenizer.withContextClassloader(() -> {
            LibUtils.checkStatus();
            long handle = TokenizersLibrary.LIB.createTokenizerFromString((String)Exceptions.uncheck(() -> Files.readString(path)));
            try {
                ModelInfo modelInfo = new ModelInfo(ModelInfo.TruncationStrategy.fromString(TokenizersLibrary.LIB.getTruncationStrategy(handle)), ModelInfo.PaddingStrategy.fromString(TokenizersLibrary.LIB.getPaddingStrategy(handle)), TokenizersLibrary.LIB.getMaxLength(handle), TokenizersLibrary.LIB.getStride(handle), TokenizersLibrary.LIB.getPadToMultipleOf(handle));
                return modelInfo;
            }
            finally {
                TokenizersLibrary.LIB.deleteTokenizer(handle);
            }
        });
    }

    private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) {
        if (this.models.size() == 1 && this.models.containsKey(Language.UNKNOWN)) {
            return this.models.get(Language.UNKNOWN);
        }
        if (this.models.containsKey(language)) {
            return this.models.get(language);
        }
        throw new IllegalArgumentException("No model for language " + language);
    }

    private static <R> R withContextClassloader(Supplier<R> r) {
        ClassLoader original = Thread.currentThread().getContextClassLoader();
        Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
        try {
            R r2 = r.get();
            return r2;
        }
        finally {
            Thread.currentThread().setContextClassLoader(original);
        }
    }

    private static long[] toArray(Collection<? extends Number> c) {
        return c.stream().mapToLong(Number::longValue).toArray();
    }

    static {
        System.setProperty("OPT_OUT_TRACKING", "true");
    }

    public static final class Builder {
        private final Map<Language, Path> models = new EnumMap<Language, Path>(Language.class);
        private Boolean addSpecialTokens;
        private Integer maxLength;
        private Boolean truncation;
        private Boolean padding;

        public Builder() {
        }

        public Builder(HuggingFaceTokenizerConfig cfg) {
            for (HuggingFaceTokenizerConfig.Model model : cfg.model()) {
                this.addModel(Language.fromLanguageTag((String)model.language()), model.path());
            }
            this.addSpecialTokens(cfg.addSpecialTokens());
            if (cfg.maxLength() != -1) {
                this.setMaxLength(cfg.maxLength());
            }
            switch (cfg.truncation()) {
                case ON: {
                    this.setTruncation(true);
                    break;
                }
                case OFF: {
                    this.setTruncation(false);
                }
            }
            switch (cfg.padding()) {
                case ON: {
                    this.setPadding(true);
                    break;
                }
                case OFF: {
                    this.setPadding(false);
                }
            }
        }

        public Builder addModel(Language lang, Path path) {
            this.models.put(lang, path);
            return this;
        }

        public Builder addDefaultModel(Path path) {
            return this.addModel(Language.UNKNOWN, path);
        }

        public Builder addSpecialTokens(boolean enabled) {
            this.addSpecialTokens = enabled;
            return this;
        }

        public Builder setMaxLength(int length) {
            this.maxLength = length;
            return this;
        }

        public Builder setTruncation(boolean enabled) {
            this.truncation = enabled;
            return this;
        }

        public Builder setPadding(boolean enabled) {
            this.padding = enabled;
            return this;
        }

        public HuggingFaceTokenizer build() {
            return new HuggingFaceTokenizer(this);
        }
    }
}

