/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.embedding.huggingface;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HuggingFaceEmbedder
implements Embedder {
    private static final Logger LOG = LoggerFactory.getLogger((String)HuggingFaceEmbedder.class.getName());
    private final String inputIdsName;
    private final String attentionMaskName;
    private final String outputName;
    private final int maxTokens;
    private final HuggingFaceTokenizer tokenizer;
    private final OnnxEvaluator evaluator;

    @Inject
    public HuggingFaceEmbedder(HuggingFaceEmbedderConfig config) throws IOException {
        this.maxTokens = config.transformerMaxTokens();
        this.inputIdsName = config.transformerInputIds();
        this.attentionMaskName = config.transformerAttentionMask();
        this.outputName = config.transformerOutput();
        try {
            ClassLoader tccl = Thread.currentThread().getContextClassLoader();
            try {
                Thread.currentThread().setContextClassLoader(this.getClass().getClassLoader());
                this.tokenizer = HuggingFaceTokenizer.newInstance((Path)Paths.get(config.tokenizerPath().toString(), new String[0]));
            }
            finally {
                Thread.currentThread().setContextClassLoader(tccl);
            }
        }
        catch (IOException e) {
            LOG.info("Could not initialize the tokenizer");
            throw new IOException("Could not initialize the tokenizer.");
        }
        this.evaluator = new OnnxEvaluator(config.transformerModel().toString());
        this.validateModel();
    }

    public void validateModel() {
        Map<String, TensorType> inputs = this.evaluator.getInputInfo();
        this.validateName(inputs, this.inputIdsName, "input");
        this.validateName(inputs, this.attentionMaskName, "input");
        Map<String, TensorType> outputs = this.evaluator.getOutputInfo();
        this.validateName(outputs, this.outputName, "output");
    }

    private void validateName(Map<String, TensorType> types, String name, String type) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. Model contains: " + String.join((CharSequence)",", types.keySet()));
        }
    }

    public List<Integer> embed(String s, Embedder.Context context) {
        Encoding encoding = this.tokenizer.encode(s);
        List<Integer> tokenIds = this.longToInteger(encoding.getIds());
        int tokensSize = tokenIds.size();
        if (tokensSize > this.maxTokens) {
            Integer lastElement = tokenIds.get(tokensSize - 1);
            tokenIds = tokenIds.subList(0, this.maxTokens - 1);
            tokenIds.add(lastElement);
        }
        return tokenIds;
    }

    public List<Integer> longToInteger(long[] values) {
        return Arrays.stream(values).boxed().map(Long::intValue).toList();
    }

    public Tensor embed(String s, Embedder.Context context, TensorType tensorType) {
        List<Integer> tokenIds = this.embed(s.toLowerCase(), context);
        return this.embedTokens(tokenIds, tensorType);
    }

    Tensor embedTokens(List<Integer> tokenIds, TensorType tensorType) {
        IndexedTensor inputSequence = this.createTensorRepresentation(tokenIds, "d1");
        Tensor attentionMask = this.createAttentionMask((Tensor)inputSequence);
        Map<String, Tensor> inputs = Map.of(this.inputIdsName, inputSequence.expand("d0"), this.attentionMaskName, attentionMask.expand("d0"));
        Map<String, Tensor> outputs = this.evaluator.evaluate(inputs);
        Tensor tokenEmbeddings = outputs.get(this.outputName);
        Tensor.Builder builder = Tensor.Builder.of((TensorType)tensorType);
        Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
        Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1");
        Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
        int i = 0;
        while ((long)i < (Long)((TensorType.Dimension)tensorType.dimensions().get(0)).size().get()) {
            builder.cell(averaged.get(TensorAddress.of((long[])new long[]{0L, i})), new long[]{i});
            ++i;
        }
        return this.normalize(builder.build(), tensorType);
    }

    Tensor normalize(Tensor embedding, TensorType tensorType) {
        double sumOfSquares = 0.0;
        Tensor.Builder builder = Tensor.Builder.of((TensorType)tensorType);
        int i = 0;
        while ((long)i < (Long)((TensorType.Dimension)tensorType.dimensions().get(0)).size().get()) {
            double item = embedding.get(TensorAddress.of((long[])new long[]{i}));
            sumOfSquares += item * item;
            ++i;
        }
        double magnitude = Math.sqrt(sumOfSquares);
        int i2 = 0;
        while ((long)i2 < (Long)((TensorType.Dimension)tensorType.dimensions().get(0)).size().get()) {
            double value = embedding.get(TensorAddress.of((long[])new long[]{i2}));
            builder.cell(value / magnitude, new long[]{i2});
            ++i2;
        }
        return builder.build();
    }

    private IndexedTensor createTensorRepresentation(List<Integer> input, String dimension) {
        int size = input.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, (long)size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of((TensorType)type);
        for (int i = 0; i < size; ++i) {
            builder.cell((float)input.get(i).intValue(), new long[]{i});
        }
        return builder.build();
    }

    private Tensor createAttentionMask(Tensor inputSequence) {
        return inputSequence.map(x -> 1.0);
    }
}

