/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.embedding;

import ai.vespa.embedding.PoolingStrategy;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.wordpiece.WordPieceEmbedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class BertBaseEmbedder
extends AbstractComponent
implements Embedder {
    private final int maxTokens;
    private final int startSequenceToken;
    private final int endSequenceToken;
    private final String inputIdsName;
    private final String attentionMaskName;
    private final String tokenTypeIdsName;
    private final String outputName;
    private final PoolingStrategy poolingStrategy;
    private final Embedder.Runtime runtime;
    private final WordPieceEmbedder tokenizer;
    private final OnnxEvaluator evaluator;

    @Inject
    public BertBaseEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, BertBaseEmbedderConfig config) {
        this.runtime = runtime;
        this.maxTokens = config.transformerMaxTokens();
        this.startSequenceToken = config.transformerStartSequenceToken();
        this.endSequenceToken = config.transformerEndSequenceToken();
        this.inputIdsName = config.transformerInputIds();
        this.attentionMaskName = config.transformerAttentionMask();
        this.tokenTypeIdsName = config.transformerTokenTypeIds();
        this.outputName = config.transformerOutput();
        this.poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
        OnnxEvaluatorOptions.Builder optionsBuilder = new OnnxEvaluatorOptions.Builder().setExecutionMode(config.onnxExecutionMode().toString()).setThreads(config.onnxInterOpThreads(), config.onnxIntraOpThreads());
        if (config.onnxGpuDevice() >= 0) {
            optionsBuilder.setGpuDevice(config.onnxGpuDevice());
        }
        OnnxEvaluatorOptions options = optionsBuilder.build();
        this.tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().toString()).build();
        this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), options);
        this.validateModel();
    }

    private void validateModel() {
        Map<String, TensorType> inputs = this.evaluator.getInputInfo();
        this.validateName(inputs, this.inputIdsName, "input");
        this.validateName(inputs, this.attentionMaskName, "input");
        if (!"".equals(this.tokenTypeIdsName)) {
            this.validateName(inputs, this.tokenTypeIdsName, "input");
        }
        Map<String, TensorType> outputs = this.evaluator.getOutputInfo();
        this.validateName(outputs, this.outputName, "output");
    }

    private void validateName(Map<String, TensorType> types, String name, String type) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. Model contains: " + String.join((CharSequence)",", types.keySet()));
        }
    }

    public List<Integer> embed(String text, Embedder.Context context) {
        long start = System.nanoTime();
        List<Integer> tokens = this.tokenize(text, context);
        this.runtime.sampleSequenceLength((long)tokens.size(), context);
        this.runtime.sampleEmbeddingLatency((double)(System.nanoTime() - start) / 1000000.0, context);
        return tokens;
    }

    public Tensor embed(String text, Embedder.Context context, TensorType type) {
        long start = System.nanoTime();
        if (type.dimensions().size() != 1) {
            throw new IllegalArgumentException("Error in embedding to type '" + String.valueOf(type) + "': should only have one dimension.");
        }
        if (!((TensorType.Dimension)type.dimensions().get(0)).isIndexed()) {
            throw new IllegalArgumentException("Error in embedding to type '" + String.valueOf(type) + "': dimension should be indexed.");
        }
        List<Integer> tokens = this.embedWithSeparatorTokens(text, context, this.maxTokens);
        this.runtime.sampleSequenceLength((long)tokens.size(), context);
        Tensor embedding = this.embedTokens(tokens, type);
        this.runtime.sampleEmbeddingLatency((double)(System.nanoTime() - start) / 1000000.0, context);
        return embedding;
    }

    public void deconstruct() {
        this.evaluator.close();
    }

    private List<Integer> tokenize(String text, Embedder.Context ctx) {
        return this.tokenizer.embed(text, ctx);
    }

    Tensor embedTokens(List<Integer> tokens, TensorType type) {
        IndexedTensor inputSequence = this.createTensorRepresentation(tokens, "d1");
        Tensor attentionMask = BertBaseEmbedder.createAttentionMask((Tensor)inputSequence);
        Tensor tokenTypeIds = BertBaseEmbedder.createTokenTypeIds((Tensor)inputSequence);
        Map<String, Tensor> inputs = !"".equals(this.tokenTypeIdsName) ? Map.of(this.inputIdsName, inputSequence.expand("d0"), this.attentionMaskName, attentionMask.expand("d0"), this.tokenTypeIdsName, tokenTypeIds.expand("d0")) : Map.of(this.inputIdsName, inputSequence.expand("d0"), this.attentionMaskName, attentionMask.expand("d0"));
        Map<String, Tensor> outputs = this.evaluator.evaluate(inputs);
        Tensor tokenEmbeddings = outputs.get(this.outputName);
        return this.poolingStrategy.toSentenceEmbedding(type, tokenEmbeddings, attentionMask);
    }

    private List<Integer> embedWithSeparatorTokens(String text, Embedder.Context context, int maxLength) {
        List<Integer> tokens = new ArrayList<Integer>();
        tokens.add(this.startSequenceToken);
        tokens.addAll(this.tokenize(text, context));
        tokens.add(this.endSequenceToken);
        if (tokens.size() > maxLength) {
            tokens = tokens.subList(0, maxLength - 1);
            tokens.add(this.endSequenceToken);
        }
        return tokens;
    }

    private IndexedTensor createTensorRepresentation(List<Integer> input, String dimension) {
        int size = input.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, (long)size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of((TensorType)type);
        for (int i = 0; i < size; ++i) {
            builder.cell((float)input.get(i).intValue(), new long[]{i});
        }
        return builder.build();
    }

    private static Tensor createAttentionMask(Tensor d) {
        return d.map(x -> x > 0.0 ? 1.0 : 0.0);
    }

    private static Tensor createTokenTypeIds(Tensor d) {
        return d.map(x -> 0.0);
    }
}

