/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.embedding;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.wordpiece.WordPieceEmbedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Embedder that turns text into a single dense vector using a BERT-style transformer model
 * evaluated through ONNX.
 *
 * <p>How it works:
 * <ul>
 *   <li>Text is split into word pieces with a {@link WordPieceEmbedder}.</li>
 *   <li>The word pieces are wrapped in {@code [CLS]} / {@code [SEP]} tokens and truncated to the
 *       configured maximum token count.</li>
 *   <li>They are fed to the model together with an attention mask and token type ids.</li>
 *   <li>The per-token model output is pooled into one vector:
 *     <ul>
 *       <li>{@code mean} pooling: the attention-mask-weighted mean over all tokens.</li>
 *       <li>Any other strategy: the output of the first ({@code [CLS]}) token.</li>
 *     </ul>
 *   </li>
 * </ul>
 */
public class BertBaseEmbedder implements Embedder {

    // NOTE(review): fixed [CLS]/[SEP] ids; they are not looked up in the configured vocabulary,
    // so confirm the vocabulary uses these ids (they match the standard BERT vocabularies).
    private static final int TOKEN_CLS = 101;
    private static final int TOKEN_SEP = 102;

    private final int maxTokens;
    private final String inputIdsName;
    private final String attentionMaskName;
    private final String tokenTypeIdsName;
    private final String outputName;
    private final String poolingStrategy;
    private final WordPieceEmbedder tokenizer;
    private final OnnxEvaluator evaluator;

    /**
     * Creates the embedder from its config.
     *
     * <p>Loads the tokenizer vocabulary and the ONNX model, then verifies that the model
     * exposes the configured input and output names.
     *
     * @throws IllegalArgumentException if the model lacks a configured input or output
     */
    @Inject
    public BertBaseEmbedder(BertBaseEmbedderConfig config) {
        maxTokens = config.transformerMaxTokens();
        inputIdsName = config.transformerInputIds();
        attentionMaskName = config.transformerAttentionMask();
        tokenTypeIdsName = config.transformerTokenTypeIds();
        outputName = config.transformerOutput();
        poolingStrategy = config.poolingStrategy().toString();

        OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
        options.setExecutionMode(config.onnxExecutionMode().toString());
        options.setInterOpThreads(modifyThreadCount(config.onnxInterOpThreads()));
        options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads()));

        tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().toString()).build();
        evaluator = new OnnxEvaluator(config.transformerModel().toString(), options);
        validateModel();
    }

    /** Fails fast if the loaded model does not have the configured inputs and output. */
    private void validateModel() {
        Map<String, TensorType> inputs = evaluator.getInputInfo();
        validateName(inputs, inputIdsName, "input");
        validateName(inputs, attentionMaskName, "input");
        validateName(inputs, tokenTypeIdsName, "input");
        Map<String, TensorType> outputs = evaluator.getOutputInfo();
        validateName(outputs, outputName, "output");
    }

    /**
     * Checks that {@code name} is one of the keys of {@code types}.
     *
     * @param types the model's inputs or outputs, keyed by name
     * @param name  the name the config expects to find
     * @param type  "input" or "output"; used only in the error message
     * @throws IllegalArgumentException if {@code name} is missing; the message lists available names
     */
    private void validateName(Map<String, TensorType> types, String name, String type) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. Model contains: " + String.join(",", types.keySet()));
        }
    }

    /**
     * Returns the word-piece token ids of {@code text}, without {@code [CLS]}/{@code [SEP]}
     * and without truncation.
     */
    public List<Integer> embed(String text, Embedder.Context context) {
        return tokenizer.embed(text, context);
    }

    /**
     * Embeds {@code text} into a tensor of the given type.
     *
     * @param text    the text to embed
     * @param context the embedding context, passed on to the tokenizer
     * @param type    the destination type; must have exactly one indexed dimension with a bound size
     * @return a tensor of {@code type} holding the pooled embedding
     * @throws IllegalArgumentException if {@code type} is not a single, bound, indexed dimension
     */
    public Tensor embed(String text, Embedder.Context context, TensorType type) {
        if (type.dimensions().size() != 1) {
            throw new IllegalArgumentException("Error in embedding to type '" + type + "': should only have one dimension.");
        }
        if (!type.dimensions().get(0).isIndexed()) {
            throw new IllegalArgumentException("Error in embedding to type '" + type + "': dimension should be indexed.");
        }
        List<Integer> tokens = embedWithSeparatorTokens(text, context, maxTokens);
        return embedTokens(tokens, type);
    }

    /**
     * Runs the model on an already-tokenized sequence and pools the result into {@code type}.
     *
     * <p>Only the first {@code size(type)} values of the pooled vector are copied into the result.
     *
     * @param tokens the token ids, including any special tokens
     * @param type   the destination type; its first dimension must have a bound size
     * @return a tensor of {@code type} holding the pooled embedding
     * @throws IllegalArgumentException if the first dimension of {@code type} is unbound
     */
    Tensor embedTokens(List<Integer> tokens, TensorType type) {
        // Resolve the destination size before running the (expensive) model so a bad type fails fast.
        long outputSize = boundSize(type);

        IndexedTensor inputSequence = createTensorRepresentation(tokens, "d1");
        Tensor attentionMask = createAttentionMask(inputSequence);
        Tensor tokenTypeIds = createTokenTypeIds(inputSequence);

        // Each input gets an extra "d0" dimension; the pooling below reads only index 0 of it.
        Tensor batchedMask = attentionMask.expand("d0");
        Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
                                            attentionMaskName, batchedMask,
                                            tokenTypeIdsName, tokenTypeIds.expand("d0"));
        Tensor tokenEmbeddings = evaluator.evaluate(inputs).get(outputName);

        Tensor.Builder builder = Tensor.Builder.of(type);
        if ("mean".equals(poolingStrategy)) {
            // Sum token vectors over the sequence dimension "d1" and divide by the number of
            // attended tokens (the sum of the mask).
            Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
            Tensor summedAttentionMask = batchedMask.sum("d1");
            Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
            for (long i = 0; i < outputSize; i++) {
                builder.cell(averaged.get(TensorAddress.of(new long[] { 0L, i })), new long[] { i });
            }
        } else {
            // Take the output at sequence position 0, which is the [CLS] token prepended during tokenization.
            for (long i = 0; i < outputSize; i++) {
                builder.cell(tokenEmbeddings.get(TensorAddress.of(new long[] { 0L, 0L, i })), new long[] { i });
            }
        }
        return builder.build();
    }

    /**
     * Returns the size of the first dimension of {@code type}.
     *
     * @throws IllegalArgumentException if that dimension has no bound size
     */
    private static long boundSize(TensorType type) {
        return type.dimensions().get(0).size()
                   .orElseThrow(() -> new IllegalArgumentException(
                           "Error in embedding to type '" + type + "': dimension should have a bound size."));
    }

    /**
     * Tokenizes {@code text} and wraps it as {@code [CLS] word-pieces... [SEP]}.
     *
     * <p>If the result would exceed {@code maxLength}, word pieces are dropped from the end so the
     * sequence is exactly {@code maxLength} long. The closing {@code [SEP]} is always kept.
     * Both special tokens are always present, even if {@code maxLength < 2}.
     */
    private List<Integer> embedWithSeparatorTokens(String text, Embedder.Context context, int maxLength) {
        List<Integer> wordPieces = embed(text, context);
        // Reserve two slots for [CLS] and [SEP]; keep as many word pieces as still fit.
        int keep = Math.min(wordPieces.size(), Math.max(0, maxLength - 2));
        List<Integer> tokens = new ArrayList<>(keep + 2);
        tokens.add(TOKEN_CLS);
        tokens.addAll(wordPieces.subList(0, keep));
        tokens.add(TOKEN_SEP);
        return tokens;
    }

    /** Builds a 1-d FLOAT tensor over {@code dimension} whose cell i is {@code input.get(i)}. */
    private IndexedTensor createTensorRepresentation(List<Integer> input, String dimension) {
        int size = input.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, (long) size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (int i = 0; i < size; ++i) {
            builder.cell((float) input.get(i).intValue(), new long[] { i });
        }
        return builder.build();
    }

    /** Returns 1.0 where the token id is positive and 0.0 otherwise (id 0 is treated as padding). */
    private static Tensor createAttentionMask(Tensor d) {
        return d.map(x -> x > 0.0 ? 1.0 : 0.0);
    }

    /** Returns a tensor shaped like {@code d} with every cell 0.0 (single-segment input). */
    private static Tensor createTokenTypeIds(Tensor d) {
        return d.map(x -> 0.0);
    }

    /**
     * Interprets a configured ONNX thread count.
     *
     * <p>A non-negative value is used as-is. A negative value {@code -n} means "available
     * processors divided by {@code n}", rounded up and never below 1.
     */
    private static int modifyThreadCount(int numThreads) {
        if (numThreads >= 0) {
            return numThreads;
        }
        return Math.max(1, (int) Math.ceil((double) Runtime.getRuntime().availableProcessors() / (double) (-1 * numThreads)));
    }
}

