/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.embedding;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.language.huggingface.Encoding;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.huggingface.ModelInfo;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

@Beta
/**
 * Embedder that produces one embedding vector per token (ColBERT style) into a 2-d tensor
 * with one mapped dimension (the token index) and one indexed dimension (the vector).
 *
 * <p>Text is tokenized and wrapped as {@code [start, marker, tokens..., end]}. The marker is
 * {@value #QUERY_TOKEN_ID} for queries and {@value #DOCUMENT_TOKEN_ID} for documents. The
 * sequence is then evaluated by an ONNX transformer model.
 *
 * <ul>
 *   <li>Queries (destination starting with {@code "query"}) are padded with the mask token up
 *       to {@code maxQueryTokens}. The padded positions are also emitted.</li>
 *   <li>Documents have the ids in {@code PUNCTUATION_TOKEN_IDS} removed. They may be emitted
 *       as float vectors or bit-packed into int8.</li>
 * </ul>
 */
public class ColBertEmbedder extends AbstractComponent implements Embedder {

    /** Marker token id placed right after the start token of every query sequence. */
    private static final long QUERY_TOKEN_ID = 1L;
    /** Marker token id placed right after the start token of every document sequence. */
    private static final long DOCUMENT_TOKEN_ID = 2L;
    /** Non-content positions in every sequence: start token, marker token and end token. */
    private static final int RESERVED_TOKENS = 3;
    /**
     * Token ids dropped from document input before evaluation.
     * NOTE(review): by name these are punctuation tokens; the ids presumably refer to a
     * BERT-style vocabulary — confirm against the configured tokenizer.
     */
    private static final Set<Long> PUNCTUATION_TOKEN_IDS = new HashSet<>(Arrays.asList(
            999L, 1000L, 1001L, 1002L, 1003L, 1004L, 1005L, 1006L, 1007L, 1008L, 1009L, 1010L,
            1011L, 1012L, 1013L, 1024L, 1025L, 1026L, 1027L, 1028L, 1029L, 1030L, 1031L, 1032L,
            1033L, 1034L, 1035L, 1036L, 1063L, 1064L, 1065L, 1066L));

    private final Embedder.Runtime runtime;
    private final String inputIdsName;
    private final String attentionMaskName;
    private final String outputName;
    private final HuggingFaceTokenizer tokenizer;
    private final OnnxEvaluator evaluator;
    private final int maxTransformerTokens;
    /** Query sequence length after padding; never larger than {@code maxTransformerTokens}. */
    private final int maxQueryTokens;
    /** Maximum document sequence length; never larger than {@code maxTransformerTokens}. */
    private final int maxDocumentTokens;
    private final long startSequenceToken;
    private final long endSequenceToken;
    private final long maskSequenceToken;

    /**
     * Creates the tokenizer and the ONNX evaluator from the config, then validates that the
     * model exposes the configured input and output names.
     *
     * @throws IllegalArgumentException if the model lacks a configured input or output;
     *         resources created so far are closed before the exception propagates
     */
    @Inject
    public ColBertEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, ColBertEmbedderConfig config) {
        this.runtime = runtime;
        this.inputIdsName = config.transformerInputIds();
        this.attentionMaskName = config.transformerAttentionMask();
        this.outputName = config.transformerOutput();
        this.maxTransformerTokens = config.transformerMaxTokens();
        this.maxDocumentTokens = Math.min(config.maxDocumentTokens(), this.maxTransformerTokens);
        this.maxQueryTokens = Math.min(config.maxQueryTokens(), this.maxTransformerTokens);
        this.startSequenceToken = config.transformerStartSequenceToken();
        this.endSequenceToken = config.transformerEndSequenceToken();
        this.maskSequenceToken = config.transformerMaskToken();

        HuggingFaceTokenizer createdTokenizer = createTokenizer(config);
        OnnxEvaluator createdEvaluator;
        try {
            createdEvaluator = onnx.evaluatorOf(config.transformerModel().toString(), createOnnxOptions(config));
        } catch (RuntimeException e) {
            // Don't leak the already-built tokenizer when the model cannot be loaded.
            closeAfterFailure(e, createdTokenizer::close);
            throw e;
        }
        this.tokenizer = createdTokenizer;
        this.evaluator = createdEvaluator;
        try {
            this.validateModel();
        } catch (RuntimeException e) {
            // The component will never be deconstructed if construction fails, so release here.
            closeAfterFailure(e, this::closeResources);
            throw e;
        }
    }

    /**
     * Builds the tokenizer. Special tokens and padding are disabled because this class adds
     * them itself. Truncation is enabled unless the tokenizer's model info already declares a
     * max length together with LONGEST_FIRST truncation.
     */
    private static HuggingFaceTokenizer createTokenizer(ColBertEmbedderConfig config) {
        Path tokenizerPath = Paths.get(config.tokenizerPath().toString());
        HuggingFaceTokenizer.Builder builder = new HuggingFaceTokenizer.Builder()
                .addSpecialTokens(false)
                .addDefaultModel(tokenizerPath)
                .setPadding(false);
        ModelInfo info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
        if (info.maxLength() == -1 || info.truncation() != ModelInfo.TruncationStrategy.LONGEST_FIRST) {
            // Use the tokenizer's own limit when it is positive and within the transformer limit.
            int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens()
                    ? info.maxLength()
                    : config.transformerMaxTokens();
            builder.setTruncation(true).setMaxLength(maxLength);
        }
        return builder.build();
    }

    /** Maps the transformer GPU, execution-mode and threading settings onto ONNX evaluator options. */
    private static OnnxEvaluatorOptions createOnnxOptions(ColBertEmbedderConfig config) {
        OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
        if (config.transformerGpuDevice() >= 0) {
            options.setGpuDevice(config.transformerGpuDevice());
        }
        options.setExecutionMode(config.transformerExecutionMode().toString());
        options.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
        return options;
    }

    /**
     * Checks that the loaded model declares the configured input-ids and attention-mask inputs
     * and the configured output.
     *
     * @throws IllegalArgumentException naming the missing input/output and listing what the model has
     */
    public void validateModel() {
        Map<String, TensorType> inputs = this.evaluator.getInputInfo();
        this.validateName(inputs, this.inputIdsName, "input");
        this.validateName(inputs, this.attentionMaskName, "input");
        Map<String, TensorType> outputs = this.evaluator.getOutputInfo();
        this.validateName(outputs, this.outputName, "output");
    }

    /** Throws if {@code name} is not a key of {@code types}; {@code type} is "input" or "output" for the message. */
    private void validateName(Map<String, TensorType> types, String name, String type) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name
                    + "'. Model contains: " + String.join(",", types.keySet()));
        }
    }

    /** Not supported: this embedder only produces tensors. */
    @Override
    public List<Integer> embed(String text, Embedder.Context context) {
        throw new UnsupportedOperationException("This embedder only supports embed with tensor type");
    }

    /**
     * Embeds {@code text} as a query when the context destination starts with "query",
     * otherwise as a document.
     *
     * @throws IllegalArgumentException if {@code tensorType} is not a 2-d tensor with exactly
     *         one mapped and one indexed dimension
     */
    @Override
    public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
        if (!this.verifyTensorType(tensorType)) {
            throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination. Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType);
        }
        if (context.getDestination().startsWith("query")) {
            return this.embedQuery(text, context, tensorType);
        }
        return this.embedDocument(text, context, tensorType);
    }

    /** Closes the ONNX evaluator and the tokenizer. */
    @Override
    public void deconstruct() {
        closeResources();
    }

    /** Closes the evaluator, then the tokenizer; the tokenizer is closed even if the evaluator's close throws. */
    private void closeResources() {
        try {
            this.evaluator.close();
        } finally {
            this.tokenizer.close();
        }
    }

    /** Runs {@code closer}, attaching any exception it throws to {@code failure} as suppressed instead of losing it. */
    private static void closeAfterFailure(RuntimeException failure, Runnable closer) {
        try {
            closer.run();
        } catch (RuntimeException closeFailure) {
            failure.addSuppressed(closeFailure);
        }
    }

    /**
     * Embeds a query. The sequence {@code [start, QUERY_TOKEN_ID, tokens..., end]} is padded with
     * the mask token up to {@code maxQueryTokens}. Padded positions get attention mask 0. Every
     * position, padding included, is emitted as a float vector.
     *
     * @throws IllegalArgumentException for an int8 destination, or if the model's vector size
     *         does not match the destination's indexed dimension
     */
    protected Tensor embedQuery(String text, Embedder.Context context, TensorType tensorType) {
        if (tensorType.valueType() == TensorType.Value.INT8) {
            throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type");
        }
        long start = System.nanoTime();
        List<Long> inputIds = wrapSequence(QUERY_TOKEN_ID,
                contentTokenIds(text, context, this.maxQueryTokens, false), this.maxQueryTokens);
        int length = inputIds.size();
        int padding = this.maxQueryTokens - length;
        for (int i = 0; i < padding; ++i) {
            inputIds.add(this.maskSequenceToken);
        }
        IndexedTensor result = evaluateTokens(inputIds, attentionMask(length, padding));
        int dims = indexedDimensionSize(tensorType);
        if (dims != result.shape()[1]) {
            throw new IllegalArgumentException("Token dimensionality does not match indexed dimensionality of " + dims);
        }
        Tensor resultTensor = toFloatTensor(result, tensorType, inputIds.size());
        this.runtime.sampleEmbeddingLatency((double) (System.nanoTime() - start) / 1000000.0, context);
        return resultTensor;
    }

    /**
     * Embeds a document. The ids in {@code PUNCTUATION_TOKEN_IDS} are removed and the sequence
     * {@code [start, DOCUMENT_TOKEN_ID, tokens..., end]} is evaluated without padding. Every
     * position except the final end token is emitted. An int8 destination gets bit-packed
     * vectors ({@link #toBitTensor}); any other value type gets float vectors.
     */
    protected Tensor embedDocument(String text, Embedder.Context context, TensorType tensorType) {
        long start = System.nanoTime();
        List<Long> inputIds = wrapSequence(DOCUMENT_TOKEN_ID,
                contentTokenIds(text, context, this.maxDocumentTokens, true), this.maxDocumentTokens);
        IndexedTensor result = evaluateTokens(inputIds, attentionMask(inputIds.size(), 0));
        // Drop the last position (the end token added by wrapSequence).
        int retainedTokens = inputIds.size() - 1;
        Tensor contextualEmbeddings = tensorType.valueType() == TensorType.Value.INT8
                ? toBitTensor(result, tensorType, retainedTokens)
                : toFloatTensor(result, tensorType, retainedTokens);
        this.runtime.sampleEmbeddingLatency((double) (System.nanoTime() - start) / 1000000.0, context);
        return contextualEmbeddings;
    }

    /**
     * Tokenizes {@code text} and reports the raw (pre-filter, pre-truncation) sequence length to
     * the runtime. Optionally removes punctuation ids, then truncates so that the content plus
     * the reserved tokens fit in {@code maxSequenceTokens}.
     */
    private List<Long> contentTokenIds(String text, Embedder.Context context, int maxSequenceTokens,
                                       boolean dropPunctuation) {
        Encoding encoding = this.tokenizer.encode(text, context.getLanguage());
        this.runtime.sampleSequenceLength((long) encoding.ids().size(), context);
        List<Long> ids = encoding.ids();
        if (dropPunctuation) {
            ids = ids.stream().filter(token -> !PUNCTUATION_TOKEN_IDS.contains(token)).toList();
        }
        int limit = maxSequenceTokens - RESERVED_TOKENS;
        return ids.size() > limit ? ids.subList(0, limit) : ids;
    }

    /** Returns a new mutable list {@code [start, markerTokenId, contentIds..., end]} presized to {@code capacity}. */
    private List<Long> wrapSequence(long markerTokenId, List<Long> contentIds, int capacity) {
        List<Long> sequence = new ArrayList<>(capacity);
        sequence.add(this.startSequenceToken);
        sequence.add(markerTokenId);
        sequence.addAll(contentIds);
        sequence.add(this.endSequenceToken);
        return sequence;
    }

    /** Returns {@code attended} ones followed by {@code padded} zeros (non-positive counts add nothing). */
    private static List<Long> attentionMask(int attended, int padded) {
        List<Long> mask = new ArrayList<>();
        for (int i = 0; i < attended; ++i) {
            mask.add(1L);
        }
        for (int i = 0; i < padded; ++i) {
            mask.add(0L);
        }
        return mask;
    }

    /**
     * Evaluates one sequence. Both inputs are given a leading "d0" dimension, and the configured
     * output is reduced over "d0" with {@code min}. The callers then index the result as
     * {@code [token, vector dimension]}.
     */
    private IndexedTensor evaluateTokens(List<Long> inputIds, List<Long> attentionMask) {
        IndexedTensor inputIdsTensor = this.createTensorRepresentation(inputIds, "d1");
        IndexedTensor attentionMaskTensor = this.createTensorRepresentation(attentionMask, "d1");
        Map<String, Tensor> inputs = Map.of(
                this.inputIdsName, inputIdsTensor.expand("d0"),
                this.attentionMaskName, attentionMaskTensor.expand("d0"));
        Tensor tokenEmbeddings = this.evaluator.evaluate(inputs).get(this.outputName);
        return (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, new String[]{"d0"});
    }

    /** Size of the first indexed dimension of {@code type}; assumes that dimension is bound. */
    private static int indexedDimensionSize(TensorType type) {
        return type.indexedSubtype().dimensions().get(0).size().get().intValue();
    }

    /**
     * Copies the first {@code nTokens} token vectors of {@code result} into a tensor of
     * {@code type}. Cell {@code (token, d)} of the output equals cell {@code (token, d)} of the input.
     *
     * @throws IllegalArgumentException if {@code type} does not have exactly one indexed
     *         dimension, or its size differs from the vector size in {@code result}
     */
    public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
        int size = type.indexedSubtype().dimensions().size();
        if (size != 1) {
            throw new IllegalArgumentException("Indexed tensor must have one dimension");
        }
        int wantedDimensionality = indexedDimensionSize(type);
        int resultDimensionality = (int) result.shape()[1];
        if (resultDimensionality != wantedDimensionality) {
            throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality
                    + " dimensions into tensor with " + wantedDimensionality);
        }
        Tensor.Builder builder = Tensor.Builder.of(type);
        for (int token = 0; token < nTokens; ++token) {
            for (int d = 0; d < resultDimensionality; ++d) {
                TensorAddress address = TensorAddress.of(new long[]{token, d});
                builder.cell(address, result.get(address));
            }
        }
        return builder.build();
    }

    /**
     * Binarizes and packs the first {@code nTokens} token vectors of {@code result} into int8
     * cells. Each group of 8 consecutive vector dimensions becomes one cell. A dimension
     * contributes a 1 bit when its value is {@code > 0}. The first dimension of a group is the
     * most significant bit, so cells are signed bytes in the range -128..127.
     *
     * @throws IllegalArgumentException if {@code type} is not int8, does not have exactly one
     *         indexed dimension, or the vector size is not exactly 8 times that dimension's size
     */
    public static Tensor toBitTensor(IndexedTensor result, TensorType type, int nTokens) {
        if (type.valueType() != TensorType.Value.INT8) {
            throw new IllegalArgumentException("Only a int8 tensor type can be the destination of bit packing");
        }
        int size = type.indexedSubtype().dimensions().size();
        if (size != 1) {
            throw new IllegalArgumentException("Indexed tensor must have one dimension");
        }
        int wantedDimensionality = indexedDimensionSize(type);
        int resultDimensionality = (int) result.shape()[1];
        // Require an exact 8:1 ratio; otherwise trailing dimensions would be silently dropped.
        if (resultDimensionality != 8 * wantedDimensionality) {
            throw new IllegalArgumentException("Not possible to pack " + resultDimensionality
                    + " dimensions into " + wantedDimensionality + " dimensions");
        }
        Tensor.Builder builder = Tensor.Builder.of(type);
        for (int token = 0; token < nTokens; ++token) {
            for (int key = 0; key < wantedDimensionality; ++key) {
                int bits = 0;
                for (int bit = 0; bit < 8; ++bit) {
                    double value = result.get(TensorAddress.of(new long[]{token, key * 8L + bit}));
                    if (value > 0.0) {
                        bits |= 1 << (7 - bit);
                    }
                }
                // Narrowing to byte yields the signed int8 value for patterns with the high bit set.
                builder.cell(TensorAddress.of(new long[]{token, key}), (float) (byte) bits);
            }
        }
        return builder.build();
    }

    /** True if {@code target} has two dimensions: exactly one mapped and one indexed. */
    protected boolean verifyTensorType(TensorType target) {
        return target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1;
    }

    /** Builds a 1-d float tensor over {@code dimension} holding {@code input}'s values in order. */
    private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
        int size = input.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, (long) size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (int i = 0; i < size; ++i) {
            builder.cell((float) input.get(i).longValue(), new long[]{i});
        }
        return builder.build();
    }
}

