/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.language.sentencepiece;

import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.language.sentencepiece.Model;
import com.yahoo.language.sentencepiece.ResultBuilder;
import com.yahoo.language.sentencepiece.Scoring;
import com.yahoo.language.sentencepiece.SentencePieceAlgorithm;
import com.yahoo.language.sentencepiece.SentencePieceConfig;
import com.yahoo.language.sentencepiece.TokenType;
import com.yahoo.language.tools.Embed;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Segmenter and embedder backed by SentencePiece models.
 *
 * <p>One model is held per {@link Language}. When exactly one model is configured and it is
 * registered under {@link Language#UNKNOWN}, that model is used for every language; otherwise
 * the model registered for the requested language is used, and a request for a language with
 * no model fails with an {@link IllegalArgumentException}.
 *
 * <p>Input text is normalized before segmentation: runs of ASCII spaces are replaced by a
 * single U+2581 marker character which is prepended to the following word.
 */
@Beta
public class SentencePieceEmbedder
implements Segmenter,
Embedder {
    /** The character (U+2581) which marks the start of a word in normalized text. */
    private static final char SPACE_MARKER = '\u2581';

    private final Map<Language, Model> models;
    private final SentencePieceAlgorithm algorithm;

    /**
     * Creates an embedder from injected configuration.
     *
     * @param config the configuration supplying the models, scoring and unknown-collapsing settings
     * @throws IllegalArgumentException if the configuration contains no models
     */
    @Inject
    public SentencePieceEmbedder(SentencePieceConfig config) {
        this(new Builder(config));
    }

    /**
     * Creates an embedder from the settings collected in a builder.
     *
     * @param builder the builder holding the model paths and algorithm settings
     * @throws IllegalArgumentException if the builder contains no models
     */
    public SentencePieceEmbedder(Builder builder) {
        // Validate up front so that we fail before doing any other work.
        if (builder.getModels().isEmpty()) {
            throw new IllegalArgumentException("SentencePieceEmbedder requires at least one model configured");
        }
        this.algorithm = new SentencePieceAlgorithm(builder.getCollapseUnknowns(), builder.getScoring());
        this.models = builder.getModels().entrySet().stream()
                .map(entry -> new Model(entry.getKey(), entry.getValue()))
                .collect(Collectors.toUnmodifiableMap(model -> model.language, model -> model));
    }

    /**
     * Segments the given text into token strings using the model resolved for the language.
     *
     * @param rawInput the text to segment, before normalization
     * @param language the language used to select the model
     * @return the text of each segment of the normalized input, in input order
     * @throws IllegalArgumentException if no model can be resolved for the language
     */
    public List<String> segment(String rawInput, Language language) {
        final String input = this.normalize(rawInput);
        ResultBuilder<List<String>> resultBuilder = new ResultBuilder<List<String>>(new ArrayList<>()){

            @Override
            public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
                this.result().add(input.substring(segmentStart, segmentEnd));
            }
        };
        this.segment(input, language, resultBuilder);
        // The collected segments are reversed here to produce the returned order.
        Collections.reverse(resultBuilder.result());
        return resultBuilder.result();
    }

    /**
     * Segments the given text and returns the token id of each segment.
     *
     * @param rawInput the text to embed, before normalization
     * @param context the context supplying the language used to select the model
     * @return the token id of each segment of the normalized input
     * @throws IllegalArgumentException if no model can be resolved for the context language
     */
    public List<Integer> embed(String rawInput, Embedder.Context context) {
        ResultBuilder<List<Integer>> resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()){

            @Override
            public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
                this.result().add(segmentEnds[segmentEnd].id);
            }
        };
        this.segment(this.normalize(rawInput), context.getLanguage(), resultBuilder);
        // The collected ids are reversed here to produce the returned order.
        Collections.reverse(resultBuilder.result());
        return resultBuilder.result();
    }

    /**
     * Decodes token ids back to text, including control tokens.
     *
     * @param tokens the token ids to decode
     * @param context the context supplying the language used to select the model
     * @return the denormalized concatenation of the token texts
     * @throws IllegalArgumentException if no model can be resolved, or a token id is not in the model
     */
    public String decode(List<Integer> tokens, Embedder.Context context) {
        return this.decode(tokens, context, false);
    }

    /**
     * Decodes token ids back to text.
     *
     * @param tokens the token ids to decode
     * @param context the context supplying the language used to select the model
     * @param skipControl whether tokens of type {@link TokenType#control} are left out of the result
     * @return the denormalized concatenation of the token texts; empty if nothing was decoded
     * @throws IllegalArgumentException if no model can be resolved, or a token id is not in the model
     */
    public String decode(List<Integer> tokens, Embedder.Context context, boolean skipControl) {
        Model model = this.resolveModelFrom(context.getLanguage());
        StringBuilder sb = new StringBuilder();
        for (Integer tokenId : tokens) {
            Model.Token token = model.tokenId2Token.get(tokenId);
            // Assumes the lookup yields null for ids the model does not contain:
            // report that explicitly rather than failing with a NullPointerException below.
            if (token == null) {
                throw new IllegalArgumentException("No token with id " + tokenId + " in the SentencePiece model");
            }
            if (skipControl && token.type() == TokenType.control) {
                continue;
            }
            sb.append(token.text());
        }
        return this.denormalize(sb.toString());
    }

    /**
     * Embeds the given text as a tensor of the requested type by delegating to
     * {@link Embed#asTensor}.
     *
     * @param rawInput the text to embed
     * @param context the context supplying the language used to select the model
     * @param type the type of the tensor to produce
     * @return the tensor produced by {@link Embed#asTensor}
     */
    public Tensor embed(String rawInput, Embedder.Context context, TensorType type) {
        return Embed.asTensor(rawInput, this, context, type);
    }

    /** Runs the algorithm on already normalized input, using the model resolved for the language. */
    private <RESULTTYPE> void segment(String input, Language language, ResultBuilder<RESULTTYPE> resultBuilder) {
        this.algorithm.segment(input, resultBuilder, this.resolveModelFrom(language));
    }

    /**
     * Returns the model to use for a language.
     *
     * @throws IllegalArgumentException if there is neither a single default model nor a model
     *         registered for the language
     */
    private Model resolveModelFrom(Language language) {
        // A lone model registered under UNKNOWN acts as the default for all languages.
        if (this.models.size() == 1 && this.models.containsKey(Language.UNKNOWN)) {
            return this.models.get(Language.UNKNOWN);
        }
        if (this.models.containsKey(language)) {
            return this.models.get(language);
        }
        throw new IllegalArgumentException("No SentencePiece model for language " + language + " is configured");
    }

    /**
     * Converts text to the form the models are segmented on: every run of ASCII spaces is
     * dropped and the next non-space character is preceded by a single U+2581 marker.
     * The first word is also preceded by a marker, and trailing spaces are discarded.
     *
     * @param s the text to normalize
     * @return the normalized text; empty if the input is empty or consists of spaces only
     */
    public String normalize(String s) {
        StringBuilder b = new StringBuilder(s.length() + 1);
        boolean queuedSpace = true; // true initially so that the first word gets a marker too
        for (int i = 0; i < s.length(); ++i) {
            char c = s.charAt(i);
            if (c == ' ') {
                queuedSpace = true;
                continue;
            }
            if (queuedSpace) {
                b.append(SPACE_MARKER);
                queuedSpace = false;
            }
            b.append(c);
        }
        return b.toString();
    }

    /**
     * Reverses {@link #normalize}: replaces each U+2581 marker by a space and removes a
     * single leading space, if present.
     *
     * @param s the normalized text
     * @return the text with markers turned into spaces; empty if the input is empty
     */
    public String denormalize(String s) {
        String result = s.replace(SPACE_MARKER, ' ');
        // startsWith is safe on the empty string, which decoding zero tokens produces.
        return result.startsWith(" ") ? result.substring(1) : result;
    }

    /**
     * Collects the model paths and algorithm settings used to create a
     * {@link SentencePieceEmbedder}.
     */
    public static final class Builder {
        private final Map<Language, Path> models = new EnumMap<Language, Path>(Language.class);
        private boolean collapseUnknowns = true;
        private Scoring scoring = Scoring.fewestSegments;

        /** Creates a builder with no models and default settings. */
        public Builder() {
        }

        /**
         * Creates a builder with the given file registered as the default model.
         *
         * @param defaultModelFile the path of the model file to use for all languages
         */
        public Builder(String defaultModelFile) {
            this.addDefaultModel(new File(defaultModelFile).toPath());
        }

        /** Creates a builder holding the settings and models of the given configuration. */
        private Builder(SentencePieceConfig config) {
            this.collapseUnknowns = config.collapseUnknowns();
            // Any config scoring other than fewestSegments maps to highestScore.
            this.scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments : Scoring.highestScore;
            for (SentencePieceConfig.Model model : config.model()) {
                this.addModel(Language.fromLanguageTag(model.language()), model.path());
            }
        }

        /**
         * Registers the model to use for a language, replacing any model already registered for it.
         *
         * @param language the language the model applies to
         * @param model the path of the model file
         */
        public void addModel(Language language, Path model) {
            this.models.put(language, model);
        }

        /**
         * Registers a model under {@link Language#UNKNOWN}. If it is the only model added,
         * the embedder uses it for all languages.
         *
         * @param model the path of the model file
         * @return this, for chaining
         */
        public Builder addDefaultModel(Path model) {
            this.addModel(Language.UNKNOWN, model);
            return this;
        }

        /** Returns the live, mutable map of registered model paths by language. */
        public Map<Language, Path> getModels() {
            return this.models;
        }

        /**
         * Sets the collapse-unknowns flag which is passed on to the segmentation algorithm.
         * Default is true.
         *
         * @return this, for chaining
         */
        public Builder setCollapseUnknowns(boolean collapseUnknowns) {
            this.collapseUnknowns = collapseUnknowns;
            return this;
        }

        /** Returns the collapse-unknowns flag. */
        public boolean getCollapseUnknowns() {
            return this.collapseUnknowns;
        }

        /**
         * Sets the scoring strategy which is passed on to the segmentation algorithm.
         * Default is {@link Scoring#fewestSegments}.
         *
         * @return this, for chaining
         */
        public Builder setScoring(Scoring scoring) {
            this.scoring = scoring;
            return this;
        }

        /** Returns the scoring strategy. */
        public Scoring getScoring() {
            return this.scoring;
        }

        /**
         * Creates an embedder from the current state of this builder.
         *
         * @return a new embedder
         * @throws IllegalStateException if no model has been added
         */
        public SentencePieceEmbedder build() {
            if (this.models.isEmpty()) {
                throw new IllegalStateException("At least one model must be supplied");
            }
            return new SentencePieceEmbedder(this);
        }
    }
}

