/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.embedding.huggingface;

import ai.vespa.embedding.PoolingStrategy;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.huggingface.Encoding;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;

@Beta
public class HuggingFaceEmbedder
extends AbstractComponent
implements Embedder {
    private final String inputIdsName;
    private final String attentionMaskName;
    private final String tokenTypeIdsName;
    private final String outputName;
    private final boolean normalize;
    private final HuggingFaceTokenizer tokenizer;
    private final OnnxEvaluator evaluator;
    private final PoolingStrategy poolingStrategy;

    @Inject
    public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) {
        this.inputIdsName = config.transformerInputIds();
        this.attentionMaskName = config.transformerAttentionMask();
        this.tokenTypeIdsName = config.transformerTokenTypeIds();
        this.outputName = config.transformerOutput();
        this.normalize = config.normalize();
        this.tokenizer = new HuggingFaceTokenizer.Builder().addSpecialTokens(true).addDefaultModel(Paths.get(config.tokenizerPath().toString(), new String[0])).setTruncation(true).setMaxLength(config.transformerMaxTokens()).build();
        this.poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
        OnnxEvaluatorOptions onnxOpts = new OnnxEvaluatorOptions();
        if (config.transformerGpuDevice() >= 0) {
            onnxOpts.setGpuDevice(config.transformerGpuDevice());
        }
        onnxOpts.setExecutionMode(config.transformerExecutionMode().toString());
        onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
        this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
        this.validateModel();
    }

    public void validateModel() {
        Map<String, TensorType> inputs = this.evaluator.getInputInfo();
        this.validateName(inputs, this.inputIdsName, "input");
        this.validateName(inputs, this.attentionMaskName, "input");
        if (!this.tokenTypeIdsName.isEmpty()) {
            this.validateName(inputs, this.tokenTypeIdsName, "input");
        }
        Map<String, TensorType> outputs = this.evaluator.getOutputInfo();
        this.validateName(outputs, this.outputName, "output");
    }

    private void validateName(Map<String, TensorType> types, String name, String type) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. Model contains: " + String.join((CharSequence)",", types.keySet()));
        }
    }

    public List<Integer> embed(String s, Embedder.Context context) {
        return this.tokenizer.embed(s, context);
    }

    public void deconstruct() {
        this.evaluator.close();
        this.tokenizer.close();
    }

    public Tensor embed(String s, Embedder.Context context, TensorType tensorType) {
        Encoding encoding = this.tokenizer.encode(s, context.getLanguage());
        IndexedTensor inputSequence = this.createTensorRepresentation(encoding.ids(), "d1");
        IndexedTensor attentionMask = this.createTensorRepresentation(encoding.attentionMask(), "d1");
        IndexedTensor tokenTypeIds = this.tokenTypeIdsName.isEmpty() ? null : this.createTensorRepresentation(encoding.typeIds(), "d1");
        Map<String, Tensor> inputs = this.tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty() ? Map.of(this.inputIdsName, inputSequence.expand("d0"), this.attentionMaskName, attentionMask.expand("d0")) : Map.of(this.inputIdsName, inputSequence.expand("d0"), this.attentionMaskName, attentionMask.expand("d0"), this.tokenTypeIdsName, tokenTypeIds.expand("d0"));
        Map<String, Tensor> outputs = this.evaluator.evaluate(inputs);
        Tensor tokenEmbeddings = outputs.get(this.outputName);
        Tensor.Builder builder = Tensor.Builder.of((TensorType)tensorType);
        Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
        Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1");
        Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
        int i = 0;
        while ((long)i < (Long)((TensorType.Dimension)tensorType.dimensions().get(0)).size().get()) {
            builder.cell(averaged.get(TensorAddress.of((long[])new long[]{0L, i})), new long[]{i});
            ++i;
        }
        Tensor result = builder.build();
        return this.normalize ? this.normalize(result, tensorType) : result;
    }

    Tensor normalize(Tensor embedding, TensorType tensorType) {
        double sumOfSquares = 0.0;
        Tensor.Builder builder = Tensor.Builder.of((TensorType)tensorType);
        int i = 0;
        while ((long)i < (Long)((TensorType.Dimension)tensorType.dimensions().get(0)).size().get()) {
            double item = embedding.get(TensorAddress.of((long[])new long[]{i}));
            sumOfSquares += item * item;
            ++i;
        }
        double magnitude = Math.sqrt(sumOfSquares);
        int i2 = 0;
        while ((long)i2 < (Long)((TensorType.Dimension)tensorType.dimensions().get(0)).size().get()) {
            double value = embedding.get(TensorAddress.of((long[])new long[]{i2}));
            builder.cell(value / magnitude, new long[]{i2});
            ++i2;
        }
        return builder.build();
    }

    private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
        int size = input.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, (long)size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of((TensorType)type);
        for (int i = 0; i < size; ++i) {
            builder.cell((float)input.get(i).longValue(), new long[]{i});
        }
        return builder.build();
    }
}

