/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.huggingface.tokenizers.jni.CharSpan;
import ai.djl.huggingface.translator.PretrainedConfig;
import ai.djl.modality.nlp.translator.NamedEntity;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import ai.djl.util.Pair;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

public class TokenClassificationTranslator
implements Translator<String, NamedEntity[]> {
    private HuggingFaceTokenizer tokenizer;
    private boolean includeTokenTypes;
    private boolean int32;
    private boolean softmax;
    private String aggregationStrategy;
    private Batchifier batchifier;
    private PretrainedConfig config;

    TokenClassificationTranslator(Builder builder) {
        this.tokenizer = builder.tokenizer;
        this.includeTokenTypes = builder.includeTokenTypes;
        this.int32 = builder.int32;
        this.softmax = builder.softmax;
        this.aggregationStrategy = builder.aggregationStrategy;
        this.batchifier = builder.batchifier;
    }

    public Batchifier getBatchifier() {
        return this.batchifier;
    }

    public void prepare(TranslatorContext ctx) throws IOException {
        Path path = ctx.getModel().getModelPath();
        Path file = path.resolve("config.json");
        try (BufferedReader reader = Files.newBufferedReader(file);){
            this.config = (PretrainedConfig)JsonUtils.GSON.fromJson((Reader)reader, PretrainedConfig.class);
        }
    }

    public NDList processInput(TranslatorContext ctx, String input) {
        Encoding encoding = this.tokenizer.encode(input);
        ctx.setAttachment("encoding", (Object)encoding);
        ctx.setAttachment("sentence", (Object)input);
        return encoding.toNDList(ctx.getNDManager(), this.includeTokenTypes, this.int32);
    }

    public NDList batchProcessInput(TranslatorContext ctx, List<String> inputs) {
        NDManager manager = ctx.getNDManager();
        Encoding[] encodings = this.tokenizer.batchEncode(inputs);
        ctx.setAttachment("encodings", (Object)encodings);
        ctx.setAttachment("sentences", inputs);
        NDList[] batch = new NDList[encodings.length];
        for (int i = 0; i < encodings.length; ++i) {
            batch[i] = encodings[i].toNDList(manager, this.includeTokenTypes, this.int32);
        }
        return this.batchifier.batchify(batch);
    }

    public NamedEntity[] processOutput(TranslatorContext ctx, NDList list) {
        Encoding encoding = (Encoding)ctx.getAttachment("encoding");
        String sentence = (String)ctx.getAttachment("sentence");
        return this.toNamedEntities(encoding, list, sentence);
    }

    public List<NamedEntity[]> batchProcessOutput(TranslatorContext ctx, NDList list) {
        NDList[] batch = this.batchifier.unbatchify(list);
        Encoding[] encodings = (Encoding[])ctx.getAttachment("encodings");
        List sentences = (List)ctx.getAttachment("sentences");
        ArrayList<NamedEntity[]> ret = new ArrayList<NamedEntity[]>(batch.length);
        for (int i = 0; i < batch.length; ++i) {
            ret.add(this.toNamedEntities(encodings[i], batch[i], (String)sentences.get(i)));
        }
        return ret;
    }

    public static Builder builder(HuggingFaceTokenizer tokenizer) {
        return new Builder(tokenizer);
    }

    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
        Builder builder = TokenClassificationTranslator.builder(tokenizer);
        builder.configure(arguments);
        return builder;
    }

    private NamedEntity[] toNamedEntities(Encoding encoding, NDList list, String sentence) {
        long[] inputIds = encoding.getIds();
        CharSpan[] offsetMapping = encoding.getCharTokenSpans();
        long[] specialTokenMasks = encoding.getSpecialTokenMask();
        String[] words = encoding.getTokens();
        long[] tokenIds = encoding.getIds();
        NDArray probabilities = (NDArray)list.get(0);
        if (this.softmax) {
            probabilities = probabilities.softmax(1);
        }
        List<NamedEntityEx> entities = new ArrayList<NamedEntityEx>();
        for (int i = 0; i < inputIds.length; ++i) {
            int pos;
            if (specialTokenMasks[i] != 0L) continue;
            NDArray prob = probabilities.get(new long[]{i});
            int start = offsetMapping[i].getStart();
            int end = offsetMapping[i].getEnd();
            boolean isSubWord = false;
            if (start > 0 && ("first".equals(this.aggregationStrategy) || "average".equals(this.aggregationStrategy) || "max".equals(this.aggregationStrategy)) && ((pos = sentence.indexOf(32, start - 1)) < 0 || pos > start)) {
                isSubWord = true;
            }
            NamedEntityEx item = new NamedEntityEx(prob, i, words[i], start, end, tokenIds[i], isSubWord);
            entities.add(item);
        }
        if ("first".equals(this.aggregationStrategy) || "average".equals(this.aggregationStrategy) || "max".equals(this.aggregationStrategy)) {
            entities = this.aggregateWords(entities);
            entities = this.groupEntities(entities);
        } else if ("simple".equals(this.aggregationStrategy)) {
            entities = this.groupEntities(entities);
        }
        return (NamedEntity[])entities.stream().filter(o -> !"O".equals(o.getEntity())).map(NamedEntityEx::toNamedEntity).toArray(NamedEntity[]::new);
    }

    private List<NamedEntityEx> aggregateWords(List<NamedEntityEx> entities) {
        ArrayList<NamedEntityEx> agg = new ArrayList<NamedEntityEx>();
        ArrayList<NamedEntityEx> group = new ArrayList<NamedEntityEx>();
        for (NamedEntityEx entity : entities) {
            if (!entity.isSubWord && !group.isEmpty()) {
                agg.add(this.aggregateWord(group));
                group.clear();
            }
            group.add(entity);
        }
        if (!group.isEmpty()) {
            agg.add(this.aggregateWord(group));
        }
        return agg;
    }

    private NamedEntityEx aggregateWord(List<NamedEntityEx> entities) {
        float score;
        String entityName;
        if (entities.size() == 1) {
            return entities.get(0);
        }
        ArrayList<Long> tokenIds = new ArrayList<Long>();
        for (NamedEntityEx entity : entities) {
            tokenIds.addAll(entity.tokenIds);
        }
        NamedEntityEx first = entities.get(0);
        NamedEntityEx last = entities.get(entities.size() - 1);
        if ("first".equals(this.aggregationStrategy)) {
            entityName = first.getEntity();
            score = first.getScore();
        } else if ("max".equals(this.aggregationStrategy)) {
            NamedEntityEx max = entities.stream().max(Comparator.comparingDouble(NamedEntityEx::getScore)).get();
            entityName = max.getEntity();
            score = max.getScore();
        } else {
            NDArray[] arrays = (NDArray[])entities.stream().map(o -> o.prob).toArray(NDArray[]::new);
            NDList list = new NDList(arrays);
            NDArray array = NDArrays.stack((NDList)list).mean(new int[]{0});
            int entityIdx = (int)array.argMax().getLong(new long[0]);
            entityName = this.config.id2label.get(String.valueOf(entityIdx));
            score = array.getFloat(new long[]{entityIdx});
        }
        return new NamedEntityEx(entityName, score, first.start, last.end, tokenIds);
    }

    private List<NamedEntityEx> groupEntities(List<NamedEntityEx> entities) {
        ArrayList<NamedEntityEx> disaggregateGroup = new ArrayList<NamedEntityEx>();
        ArrayList<NamedEntityEx> entityGroups = new ArrayList<NamedEntityEx>();
        for (NamedEntityEx entity : entities) {
            if (disaggregateGroup.isEmpty()) {
                disaggregateGroup.add(entity);
                continue;
            }
            Pair<String, String> tag = this.getTag(entity.getEntity());
            NamedEntityEx lastEntity = (NamedEntityEx)disaggregateGroup.get(disaggregateGroup.size() - 1);
            Pair<String, String> lastTag = this.getTag(lastEntity.getEntity());
            if (!((String)tag.getValue()).equals(lastTag.getValue()) || "B".equals(tag.getKey())) {
                entityGroups.add(this.groupSubEntities(disaggregateGroup));
                disaggregateGroup.clear();
            }
            disaggregateGroup.add(entity);
        }
        if (!disaggregateGroup.isEmpty()) {
            entityGroups.add(this.groupSubEntities(disaggregateGroup));
        }
        return entityGroups;
    }

    private Pair<String, String> getTag(String entityName) {
        if (entityName.startsWith("B-")) {
            return new Pair((Object)"B", (Object)entityName.substring(2));
        }
        if (entityName.startsWith("I-")) {
            return new Pair((Object)"I", (Object)entityName.substring(2));
        }
        return new Pair((Object)"I", (Object)entityName);
    }

    private NamedEntityEx groupSubEntities(List<NamedEntityEx> entities) {
        ArrayList<Long> tokens = new ArrayList<Long>();
        double[] scores = new double[entities.size()];
        for (int i = 0; i < scores.length; ++i) {
            NamedEntityEx entity = entities.get(i);
            tokens.addAll(entity.tokenIds);
            scores[i] = entity.getScore();
        }
        long[] tokenIds = tokens.stream().mapToLong(Long::longValue).toArray();
        String aggWord = this.tokenizer.decode(tokenIds);
        float aggScore = (float)Arrays.stream(scores).sum() / (float)scores.length;
        NamedEntityEx first = entities.get(0);
        NamedEntityEx last = entities.get(entities.size() - 1);
        String entityName = first.getEntity();
        int pos = entityName.indexOf(45);
        if (pos > 0) {
            entityName = entityName.substring(pos + 1);
        }
        return new NamedEntityEx(entityName, aggScore, aggWord, first.start, last.end);
    }

    public static final class Builder {
        HuggingFaceTokenizer tokenizer;
        boolean includeTokenTypes;
        boolean int32;
        boolean softmax = true;
        String aggregationStrategy;
        Batchifier batchifier = Batchifier.STACK;

        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
            this.includeTokenTypes = includeTokenTypes;
            return this;
        }

        public Builder optInt32(boolean int32) {
            this.int32 = int32;
            return this;
        }

        public Builder optSoftmax(boolean softmax) {
            this.softmax = softmax;
            return this;
        }

        public Builder optBatchifier(Batchifier batchifier) {
            this.batchifier = batchifier;
            return this;
        }

        public Builder optAggregationStrategy(String aggregationStrategy) {
            this.aggregationStrategy = aggregationStrategy;
            return this;
        }

        public void configure(Map<String, ?> arguments) {
            this.optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, (String)"includeTokenTypes"));
            this.optInt32(ArgumentsUtil.booleanValue(arguments, (String)"int32"));
            this.optSoftmax(ArgumentsUtil.booleanValue(arguments, (String)"softmax", (boolean)true));
            this.optAggregationStrategy(ArgumentsUtil.stringValue(arguments, (String)"aggregation_strategy", (String)"none"));
            String batchifierStr = ArgumentsUtil.stringValue(arguments, (String)"batchifier", (String)"stack");
            this.optBatchifier(Batchifier.fromString((String)batchifierStr));
        }

        public TokenClassificationTranslator build() {
            return new TokenClassificationTranslator(this);
        }
    }

    private class NamedEntityEx {
        String entity;
        float score;
        int index;
        String word;
        int start;
        int end;
        List<Long> tokenIds;
        boolean isSubWord;
        NDArray prob;
        private boolean initialized;

        NamedEntityEx(String entity, float score, String word, int start, int end) {
            this.entity = entity;
            this.score = score;
            this.index = -1;
            this.word = word;
            this.start = start;
            this.end = end;
            this.initialized = true;
        }

        NamedEntityEx(String entity, float score, int start, int end, List<Long> tokenIds) {
            this.entity = entity;
            this.score = score;
            this.index = -1;
            this.start = start;
            this.end = end;
            this.tokenIds = tokenIds;
            this.initialized = true;
        }

        NamedEntityEx(NDArray prob, int index, String word, int start, int end, long tokenId, boolean isSubWord) {
            this.prob = prob;
            this.index = index;
            this.word = word;
            this.start = start;
            this.end = end;
            this.tokenIds = Collections.singletonList(tokenId);
            this.isSubWord = isSubWord;
        }

        private void init() {
            if (!this.initialized) {
                int entityIdx = (int)this.prob.argMax().getLong(new long[0]);
                this.entity = ((TokenClassificationTranslator)TokenClassificationTranslator.this).config.id2label.get(String.valueOf(entityIdx));
                this.score = this.prob.getFloat(new long[]{entityIdx});
                this.initialized = true;
            }
        }

        String getEntity() {
            this.init();
            return this.entity;
        }

        float getScore() {
            this.init();
            return this.score;
        }

        NamedEntity toNamedEntity() {
            this.init();
            return new NamedEntity(this.entity, this.score, this.index, this.word, this.start, this.end);
        }
    }
}

