/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.language.wordpiece;

import com.yahoo.component.annotation.Inject;
import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.language.process.Tokenizer;
import com.yahoo.language.simple.SimpleLinguistics;
import com.yahoo.language.tools.Embed;
import com.yahoo.language.wordpiece.Model;
import com.yahoo.language.wordpiece.WordPieceConfig;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.io.File;
import java.nio.file.Path;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * An embedder and segmenter backed by one WordPiece {@link Model} per configured language.
 *
 * <p>Text is tokenized with the tokenizer of {@link SimpleLinguistics} and then handed to the
 * model selected for the requested language. If the only configured model is registered under
 * {@link Language#UNKNOWN}, that model is used for every language.
 */
public class WordPieceEmbedder
implements Embedder,
Segmenter {
    /** The configured models, keyed by the language each model reports. Never empty. */
    private final Map<Language, Model> models;
    private final Tokenizer tokenizer = new SimpleLinguistics().getTokenizer();

    /**
     * Creates an embedder from injected configuration.
     *
     * @param config supplies the subword prefix and the language/path pair of each model
     * @throws IllegalArgumentException if the configuration contains no models
     */
    @Inject
    public WordPieceEmbedder(WordPieceConfig config) {
        this(new Builder(config));
    }

    /**
     * Creates an embedder holding one {@link Model} per entry registered in the builder.
     *
     * @throws IllegalArgumentException if the builder contains no models
     */
    private WordPieceEmbedder(Builder builder) {
        String subwordPrefix = builder.getSubwordPrefix();
        this.models = builder.getModels().entrySet().stream()
                .map(entry -> new Model(subwordPrefix, entry.getKey(), entry.getValue()))
                .collect(Collectors.toUnmodifiableMap(model -> model.language(), model -> model));
        if (this.models.isEmpty()) {
            throw new IllegalArgumentException("WordPieceEmbedder requires at least one model configured");
        }
    }

    /**
     * Segments the text using the model resolved for the given language.
     *
     * @throws IllegalArgumentException if no model can be resolved for the language
     */
    public List<String> segment(String text, Language language) {
        return this.resolveModelFrom(language).segment(text, this.tokenizer);
    }

    /**
     * Embeds the text using the model resolved for the language of the given context.
     *
     * @throws IllegalArgumentException if no model can be resolved for the context language
     */
    public List<Integer> embed(String text, Embedder.Context context) {
        return this.resolveModelFrom(context.getLanguage()).embed(text, this.tokenizer);
    }

    /** Embeds the text as a tensor of the given type by delegating to {@link Embed#asTensor}. */
    public Tensor embed(String text, Embedder.Context context, TensorType type) {
        return Embed.asTensor(text, this, context, type);
    }

    /**
     * Returns the model to use for the given language.
     *
     * @throws IllegalArgumentException if no suitable model is configured
     */
    private Model resolveModelFrom(Language language) {
        // A single model registered under UNKNOWN acts as the default for all languages.
        if (this.models.size() == 1 && this.models.containsKey(Language.UNKNOWN)) {
            return this.models.get(Language.UNKNOWN);
        }
        // The map is built by Collectors.toUnmodifiableMap, which rejects null values,
        // so a null result here can only mean that the language has no entry.
        Model model = this.models.get(language);
        if (model == null) {
            throw new IllegalArgumentException("No WordPiece model for language " + language + " is configured");
        }
        return model;
    }

    /** Collects the subword prefix and the per-language model files used to create a {@link WordPieceEmbedder}. */
    public static final class Builder {
        private String subwordPrefix = "##";
        private final Map<Language, Path> models = new EnumMap<>(Language.class);

        /** Creates a builder with no models and the default subword prefix {@code "##"}. */
        public Builder() {
        }

        /** Creates a builder whose default model is read from the given file name. */
        public Builder(String defaultModelFile) {
            this.addDefaultModel(new File(defaultModelFile).toPath());
        }

        /** Creates a builder holding the subword prefix and every model listed in the configuration. */
        private Builder(WordPieceConfig config) {
            this.subwordPrefix = config.subwordPrefix();
            for (WordPieceConfig.Model model : config.model()) {
                this.addModel(Language.fromLanguageTag(model.language()), model.path());
            }
        }

        /**
         * Sets the subword prefix passed to every model created from this builder.
         *
         * @param prefix the subword prefix to use instead of the default {@code "##"}
         * @return this builder, for chaining
         */
        public Builder setSubwordPrefix(String prefix) {
            // Assign the argument; the previous self-assignment silently discarded it.
            this.subwordPrefix = prefix;
            return this;
        }

        /** Returns the subword prefix currently set on this builder. */
        public String getSubwordPrefix() {
            return this.subwordPrefix;
        }

        /** Registers a model file for a language, replacing any model previously registered for it. */
        public void addModel(Language language, Path model) {
            this.models.put(language, model);
        }

        /**
         * Registers a model file under {@link Language#UNKNOWN}.
         *
         * @return this builder, for chaining
         */
        public Builder addDefaultModel(Path model) {
            this.addModel(Language.UNKNOWN, model);
            return this;
        }

        /** Returns the live (not copied) map of registered model files, keyed by language. */
        public Map<Language, Path> getModels() {
            return this.models;
        }

        /**
         * Creates the embedder from the models registered so far.
         *
         * @throws IllegalStateException if no model has been registered
         */
        public WordPieceEmbedder build() {
            if (this.models.isEmpty()) {
                throw new IllegalStateException("At least one model must be supplied");
            }
            return new WordPieceEmbedder(this);
        }
    }
}

