/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.language.sentencepiece;

import com.google.inject.Inject;
import com.yahoo.api.annotations.Beta;
import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.language.sentencepiece.Model;
import com.yahoo.language.sentencepiece.ResultBuilder;
import com.yahoo.language.sentencepiece.Scoring;
import com.yahoo.language.sentencepiece.SentencePieceAlgorithm;
import com.yahoo.language.sentencepiece.SentencePieceConfig;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Segmenter and embedder backed by one or more SentencePiece models, keyed by {@link Language}.
 *
 * <p>Input text is first normalized by {@link #normalize(String)} and then segmented by a
 * {@link SentencePieceAlgorithm} using the model resolved for the requested language. The result is
 * exposed as piece strings ({@link #segment(String, Language)}), as integer ids
 * ({@link #embed(String, Embedder.Context)}), or packed into a one-dimensional tensor
 * ({@link #embed(String, Embedder.Context, TensorType)}).
 */
@Beta
public class SentencePieceEmbedder implements Segmenter, Embedder {

    private final Map<Language, Model> models;
    private final SentencePieceAlgorithm algorithm;

    /** Creates an embedder from injected configuration. */
    @Inject
    public SentencePieceEmbedder(SentencePieceConfig config) {
        this(new Builder(config));
    }

    /**
     * Creates an embedder from the models and settings held by the given builder.
     *
     * @throws IllegalArgumentException if the builder holds no models
     */
    public SentencePieceEmbedder(Builder builder) {
        this.algorithm = new SentencePieceAlgorithm(builder.getCollapseUnknowns(), builder.getScoring());
        this.models = builder.getModels().entrySet().stream()
                .map(entry -> new Model(entry.getKey(), entry.getValue()))
                .collect(Collectors.toUnmodifiableMap(model -> model.language, model -> model));
        if (models.isEmpty()) {
            throw new IllegalArgumentException("SentencePieceEmbedder requires at least one model configured");
        }
    }

    /**
     * Segments text into SentencePiece pieces.
     *
     * @param rawInput the text to segment; it is passed through {@link #normalize(String)} first,
     *                 so the returned pieces contain U+2581 where the input had spaces
     * @param language the language whose model should be used
     * @return the pieces as substrings of the normalized input, in input order
     * @throws IllegalArgumentException if no model is configured for the language
     */
    @Override
    public List<String> segment(String rawInput, Language language) {
        String input = normalize(rawInput);
        ResultBuilder<List<String>> resultBuilder = new ResultBuilder<List<String>>(new ArrayList<>()) {
            @Override
            public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
                result().add(input.substring(segmentStart, segmentEnd));
            }
        };
        segment(input, language, resultBuilder);
        // Results are reversed after collection; presumably the algorithm reports pieces
        // last-to-first (TODO confirm against SentencePieceAlgorithm).
        Collections.reverse(resultBuilder.result());
        return resultBuilder.result();
    }

    /**
     * Segments text and returns one integer per piece: the {@code id} of the
     * {@link SentencePieceAlgorithm.SegmentEnd} at each piece's end position
     * (presumably the model's token id for that piece).
     *
     * @param rawInput the text to embed; it is normalized first
     * @param context supplies the language used to select the model
     * @return the ids in input order
     * @throws IllegalArgumentException if no model is configured for the context's language
     */
    @Override
    public List<Integer> embed(String rawInput, Embedder.Context context) {
        ResultBuilder<List<Integer>> resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) {
            @Override
            public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
                result().add(segmentEnds[segmentEnd].id);
            }
        };
        segment(normalize(rawInput), context.getLanguage(), resultBuilder);
        // Same reversal as in segment(): collected order is the reverse of input order.
        Collections.reverse(resultBuilder.result());
        return resultBuilder.result();
    }

    /**
     * Embeds text into a tensor with exactly one dimension.
     *
     * <ul>
     *   <li>Indexed dimension: cell {@code i} holds the i-th id from
     *       {@link #embed(String, Embedder.Context)}. If the dimension has a size, ids beyond
     *       that size are dropped.</li>
     *   <li>Mapped dimension: each piece from {@link #segment(String, Language)} is a label
     *       whose value is the piece's position. How a repeated piece is handled is up to
     *       {@code Tensor.Builder} (TODO confirm: likely last position wins).</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the type does not have exactly one indexed or mapped
     *                                  dimension, or no model is configured for the language
     */
    @Override
    public Tensor embed(String rawInput, Embedder.Context context, TensorType type) {
        if (type.dimensions().size() == 1) {
            TensorType.Dimension dimension = type.dimensions().get(0);
            if (dimension.isIndexed()) {
                List<Integer> ids = embed(rawInput, context);
                long cellCount = ids.size();
                if (dimension.size().isPresent()) {
                    cellCount = Math.min(cellCount, dimension.size().get());
                }
                Tensor.Builder builder = Tensor.Builder.of(type);
                for (int i = 0; i < cellCount; i++) {
                    builder.cell((float) ids.get(i).intValue(), new long[] { i });
                }
                return builder.build();
            }
            if (dimension.isMapped()) {
                List<String> pieces = segment(rawInput, context.getLanguage());
                Tensor.Builder builder = Tensor.Builder.of(type);
                for (int i = 0; i < pieces.size(); i++) {
                    builder.cell(TensorAddress.ofLabels(pieces.get(i)), (float) i);
                }
                return builder.build();
            }
        }
        throw new IllegalArgumentException("Don't know how to embed with SentencePiece into " + type);
    }

    /** Runs the segmentation algorithm over already-normalized input with the model for the language. */
    private <RESULTTYPE> void segment(String input, Language language, ResultBuilder<RESULTTYPE> resultBuilder) {
        Model model = resolveFrom(language);
        algorithm.segment(input, resultBuilder, model);
    }

    /**
     * Picks the model to use for a language.
     *
     * <p>If the only configured model is the default ({@link Language#UNKNOWN}) model, it is
     * used for every language. Otherwise an exact language match is required. A default model
     * is NOT used as a fallback when other models are also configured.
     *
     * @throws IllegalArgumentException if no suitable model is configured
     */
    private Model resolveFrom(Language language) {
        if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) {
            return models.get(Language.UNKNOWN);
        }
        // models comes from Collectors.toUnmodifiableMap, which throws NullPointerException on
        // null queries. Guard explicitly so a null language gets the documented exception.
        Model model = (language == null) ? null : models.get(language);
        if (model != null) {
            return model;
        }
        throw new IllegalArgumentException("No SentencePiece model for language " + language + " is configured");
    }

    /**
     * Normalizes text before segmentation.
     *
     * <p>Each run of one or more ASCII spaces (U+0020) becomes a single U+2581 placed in front of
     * the next non-space character. A U+2581 is also always placed before the first non-space
     * character. Trailing spaces are dropped. Other whitespace characters (tabs, newlines) are
     * kept unchanged.
     *
     * @param s the text to normalize
     * @return the normalized text; empty if {@code s} is empty or consists only of spaces
     */
    public String normalize(String s) {
        StringBuilder b = new StringBuilder(s.length() + 1);
        boolean queuedSpace = true; // always start with one space marker
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            if (c == ' ') {
                queuedSpace = true;
            } else {
                if (queuedSpace) {
                    b.append('\u2581');
                    queuedSpace = false;
                }
                b.append(c);
            }
        }
        return b.toString();
    }

    /** Collects models and settings for a {@link SentencePieceEmbedder}. */
    public static class Builder {

        private final Map<Language, Path> models = new HashMap<>();
        private boolean collapseUnknowns = true;
        private Scoring scoring = Scoring.fewestSegments;

        /** Creates a builder with no models, {@code collapseUnknowns = true} and fewest-segments scoring. */
        public Builder() {
        }

        /**
         * Creates a builder from configuration. Any configured scoring other than
         * {@code fewestSegments} maps to {@link Scoring#highestScore}.
         */
        private Builder(SentencePieceConfig config) {
            this.collapseUnknowns = config.collapseUnknowns();
            this.scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments
                    ? Scoring.fewestSegments
                    : Scoring.highestScore;
            for (SentencePieceConfig.Model model : config.model()) {
                addModel(Language.fromLanguageTag(model.language()), model.path());
            }
        }

        /** Registers the model file for a language, replacing any earlier model for that language. */
        public void addModel(Language language, Path model) {
            models.put(language, model);
        }

        /** Registers the model file used for {@link Language#UNKNOWN}. */
        public Builder addDefaultModel(Path model) {
            addModel(Language.UNKNOWN, model);
            return this;
        }

        /** Returns the builder's own (mutable) language-to-model-path map, not a copy. */
        public Map<Language, Path> getModels() {
            return models;
        }

        public Builder setCollapseUnknowns(boolean collapseUnknowns) {
            this.collapseUnknowns = collapseUnknowns;
            return this;
        }

        public boolean getCollapseUnknowns() {
            return collapseUnknowns;
        }

        public Builder setScoring(Scoring scoring) {
            this.scoring = scoring;
            return this;
        }

        public Scoring getScoring() {
            return scoring;
        }

        /**
         * Builds the embedder.
         *
         * @throws IllegalStateException if no model has been added
         */
        public SentencePieceEmbedder build() {
            if (models.isEmpty()) {
                throw new IllegalStateException("At least one model must be supplied");
            }
            return new SentencePieceEmbedder(this);
        }
    }
}

