/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.embedding;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.language.huggingface.Encoding;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.huggingface.ModelInfo;
import com.yahoo.language.process.Embedder;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsNode;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Embedder producing ColBERT-style multi-vector embeddings: one vector per input token.
 *
 * <p>Text is tokenized with a HuggingFace tokenizer, framed by a start token, a query or document
 * marker token and an end token, and evaluated by an ONNX transformer model. The per-token output
 * vectors are written into a 2-dimensional destination tensor whose first address component is the
 * token position and whose single indexed dimension holds the vector.
 *
 * <ul>
 *   <li>Queries (destination starting with {@code "query"}) are padded with the mask token up to the
 *       configured maximum query length, and must not target an int8 tensor.</li>
 *   <li>Documents have punctuation tokens removed and are not padded. They may target an int8 tensor,
 *       in which case each vector is binarized (value &gt; 0 becomes a set bit) and packed 8 dimensions
 *       per int8 cell, first dimension in the most significant bit.</li>
 * </ul>
 */
@Beta
public class ColBertEmbedder extends AbstractComponent implements Embedder {

    /** ASCII punctuation; the token ids of each of these characters are dropped from document input. */
    private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~";

    /** Positions reserved in every input sequence for the start, query/document marker and end tokens. */
    private static final int RESERVED_TOKENS = 3;

    /** Number of bits packed into one int8 cell by {@link #toBitTensor}. */
    private static final int BITS_PER_CELL = 8;

    private final Embedder.Runtime runtime;
    private final String inputIdsName;
    private final String attentionMaskName;
    private final String outputName;
    private final HuggingFaceTokenizer tokenizer;
    private final OnnxEvaluator evaluator;
    private final int maxTransformerTokens;
    private final int maxQueryTokens;
    private final int maxDocumentTokens;
    private final long startSequenceToken;
    private final long endSequenceToken;
    private final long maskSequenceToken;
    private final long padSequenceToken;
    private final long querySequenceToken;
    private final long documentSequenceToken;
    private final Set<Long> skipTokens;

    /**
     * Creates the embedder: builds the tokenizer, computes the punctuation skip set, loads the ONNX
     * model and validates that the model exposes the configured input and output names.
     *
     * @throws IllegalArgumentException if the model lacks a configured input or output
     */
    @Inject
    public ColBertEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, ColBertEmbedderConfig config) {
        this.runtime = runtime;
        this.inputIdsName = config.transformerInputIds();
        this.attentionMaskName = config.transformerAttentionMask();
        this.outputName = config.transformerOutput();
        this.maxTransformerTokens = config.transformerMaxTokens();
        // Neither sequence kind may be longer than the transformer itself accepts.
        this.maxDocumentTokens = Math.min(config.maxDocumentTokens(), maxTransformerTokens);
        this.maxQueryTokens = Math.min(config.maxQueryTokens(), maxTransformerTokens);
        this.startSequenceToken = config.transformerStartSequenceToken();
        this.endSequenceToken = config.transformerEndSequenceToken();
        this.maskSequenceToken = config.transformerMaskToken();
        this.padSequenceToken = config.transformerPadToken();
        this.querySequenceToken = config.queryTokenId();
        this.documentSequenceToken = config.documentTokenId();

        Path tokenizerPath = Paths.get(config.tokenizerPath().toString());
        // Special tokens are framed manually in buildTransformerInput, so the tokenizer must not add any.
        HuggingFaceTokenizer.Builder builder = new HuggingFaceTokenizer.Builder()
                .addSpecialTokens(false)
                .addDefaultModel(tokenizerPath)
                .setPadding(false);
        ModelInfo info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
        // Unless the model already truncates longest-first to a known length, enable truncation,
        // capped at the transformer's maximum sequence length.
        if (info.maxLength() == -1 || info.truncation() != ModelInfo.TruncationStrategy.LONGEST_FIRST) {
            int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens()
                    ? info.maxLength()
                    : config.transformerMaxTokens();
            builder.setTruncation(true).setMaxLength(maxLength);
        }
        this.tokenizer = builder.build();

        Set<Long> punctuationTokens = new HashSet<>();
        for (char c : PUNCTUATION.toCharArray()) {
            punctuationTokens.addAll(tokenizer.encode(Character.toString(c), null).ids());
        }
        this.skipTokens = punctuationTokens;

        OnnxEvaluatorOptions onnxOpts = new OnnxEvaluatorOptions();
        if (config.transformerGpuDevice() >= 0) {
            onnxOpts.setGpuDevice(config.transformerGpuDevice());
        }
        onnxOpts.setExecutionMode(config.transformerExecutionMode().toString());
        onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
        this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
        validateModel();
    }

    /**
     * Checks that the loaded model declares the configured input-ids and attention-mask inputs and the
     * configured output.
     *
     * @throws IllegalArgumentException naming the missing input/output and listing what the model has
     */
    public void validateModel() {
        Map<String, TensorType> inputs = evaluator.getInputInfo();
        validateName(inputs, inputIdsName, "input");
        validateName(inputs, attentionMaskName, "input");
        Map<String, TensorType> outputs = evaluator.getOutputInfo();
        validateName(outputs, outputName, "output");
    }

    /** Throws if {@code name} is not a key of {@code types}; {@code type} is "input" or "output" for the message. */
    private void validateName(Map<String, TensorType> types, String name, String type) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name
                    + "'. Model contains: " + String.join(",", types.keySet()));
        }
    }

    /**
     * Not supported: this embedder only produces tensors.
     *
     * @throws UnsupportedOperationException always
     */
    public List<Integer> embed(String text, Embedder.Context context) {
        throw new UnsupportedOperationException("This embedder only supports embed with tensor type");
    }

    /**
     * Embeds text into a 2-d tensor of per-token vectors. Destinations starting with {@code "query"}
     * are embedded as queries, everything else as documents.
     *
     * @throws IllegalArgumentException if {@code tensorType} does not have exactly two dimensions
     *         with a single indexed one, or if the model output does not fit it
     */
    public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
        if (!validTensorType(tensorType)) {
            throw new IllegalArgumentException("Invalid colbert embedder tensor target destination. Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType);
        }
        if (context.getDestination().startsWith("query")) {
            return embedQuery(text, context, tensorType);
        }
        return embedDocument(text, context, tensorType);
    }

    /** Releases the ONNX evaluator and the tokenizer. */
    public void deconstruct() {
        evaluator.close();
        tokenizer.close();
    }

    /**
     * Frames raw token ids as transformer input: {@code [start, query|document marker, tokens..., end]},
     * truncating {@code tokens} so the sequence fits in {@code maxTokens}.
     *
     * <p>For documents, punctuation tokens ({@link #getSkipTokens()}) are removed first and no padding
     * is added. For queries, the sequence is padded with the mask token up to exactly
     * {@code maxTokens}, and the padded positions get attention mask 0.
     *
     * @param tokens    token ids from the tokenizer, without special tokens
     * @param maxTokens maximum sequence length including the three framing tokens
     * @param isQuery   whether to build a query (padded) or document (filtered) sequence
     * @return the input ids and the matching attention mask (1 for real tokens, 0 for padding)
     * @throws IllegalArgumentException if {@code maxTokens} cannot hold the three framing tokens
     */
    protected TransformerInput buildTransformerInput(List<Long> tokens, int maxTokens, boolean isQuery) {
        if (maxTokens < RESERVED_TOKENS) {
            throw new IllegalArgumentException("maxTokens must be at least " + RESERVED_TOKENS
                    + " to hold the start, marker and end tokens, got " + maxTokens);
        }
        if (!isQuery) {
            tokens = tokens.stream().filter(token -> !skipTokens.contains(token)).toList();
        }
        int maxContentTokens = maxTokens - RESERVED_TOKENS;
        if (tokens.size() > maxContentTokens) {
            tokens = tokens.subList(0, maxContentTokens);
        }

        List<Long> inputIds = new ArrayList<>(maxTokens);
        inputIds.add(startSequenceToken);
        inputIds.add(isQuery ? querySequenceToken : documentSequenceToken);
        inputIds.addAll(tokens);
        inputIds.add(endSequenceToken);

        int inputLength = inputIds.size();
        int padding = isQuery ? maxTokens - inputLength : 0;
        long padTokenId = isQuery ? maskSequenceToken : padSequenceToken;
        for (int i = 0; i < padding; i++) {
            inputIds.add(padTokenId);
        }

        List<Long> attentionMask = new ArrayList<>(maxTokens);
        for (int i = 0; i < inputLength; i++) {
            attentionMask.add(1L);
        }
        for (int i = 0; i < padding; i++) {
            attentionMask.add(0L);
        }
        return new TransformerInput(inputIds, attentionMask);
    }

    /**
     * Embeds a query as float per-token vectors, one per position of the padded query sequence.
     *
     * @throws IllegalArgumentException if {@code tensorType} has int8 cells or its vector size differs
     *         from the model output
     */
    protected Tensor embedQuery(String text, Embedder.Context context, TensorType tensorType) {
        if (tensorType.valueType() == TensorType.Value.INT8) {
            throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type");
        }
        long start = System.nanoTime();
        TransformerInput input = encode(text, context, maxQueryTokens, true);
        IndexedTensor result = evaluateTransformer(input);
        // toFloatTensor validates the output rank before comparing vector sizes with the target type.
        Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds().size());
        runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
        return resultTensor;
    }

    /**
     * Embeds a document as per-token vectors; int8 destinations get bit-packed vectors, other cell
     * types get the raw float values.
     *
     * @throws IllegalArgumentException if the model output cannot be mapped into {@code tensorType}
     */
    protected Tensor embedDocument(String text, Embedder.Context context, TensorType tensorType) {
        long start = System.nanoTime();
        TransformerInput input = encode(text, context, maxDocumentTokens, false);
        IndexedTensor result = evaluateTransformer(input);
        int nTokens = input.inputIds().size();
        Tensor contextualEmbeddings = tensorType.valueType() == TensorType.Value.INT8
                ? toBitTensor(result, tensorType, nTokens)
                : toFloatTensor(result, tensorType, nTokens);
        runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
        return contextualEmbeddings;
    }

    /** Tokenizes text, records the raw token count with the runtime, and frames it as transformer input. */
    private TransformerInput encode(String text, Embedder.Context context, int maxTokens, boolean isQuery) {
        Encoding encoding = tokenizer.encode(text, context.getLanguage());
        runtime.sampleSequenceLength(encoding.ids().size(), context);
        return buildTransformerInput(encoding.ids(), maxTokens, isQuery);
    }

    /**
     * Runs the model on a single-sequence batch and returns the configured output.
     * Input ids and mask are given dimensions {@code d0} (batch) and {@code d1} (sequence).
     */
    private IndexedTensor evaluateTransformer(TransformerInput input) {
        Tensor inputIds = createTensorRepresentation(input.inputIds(), "d1").expand("d0");
        Tensor attentionMask = createTensorRepresentation(input.attentionMask(), "d1").expand("d0");
        Map<String, Tensor> outputs = evaluator.evaluate(Map.of(inputIdsName, inputIds,
                                                                attentionMaskName, attentionMask));
        return (IndexedTensor) outputs.get(outputName);
    }

    /**
     * Copies the first {@code nTokens} token vectors of batch entry 0 of a
     * {@code [batch, sequence, dim]} model output into a tensor of {@code type}, addressed as
     * {@code (token, d)}.
     *
     * @throws IllegalArgumentException if {@code result} is not 3-dimensional, {@code type} does not
     *         have exactly one bound indexed dimension, or the vector sizes differ
     */
    public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
        if (result.shape().length != 3) {
            throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]");
        }
        if (type.indexedSubtype().dimensions().size() != 1) {
            throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
        }
        int wantedDimensionality = indexedDimensionSize(type);
        int resultDimensionality = (int) result.shape()[2];
        if (resultDimensionality != wantedDimensionality) {
            throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality + " dimensions into tensor with " + wantedDimensionality);
        }
        Tensor.Builder builder = Tensor.Builder.of(type);
        for (int token = 0; token < nTokens; token++) {
            for (int d = 0; d < resultDimensionality; d++) {
                double value = result.get(new long[] {0, token, d});
                builder.cell(TensorAddress.of(new int[] {token, d}), value);
            }
        }
        return builder.build();
    }

    /**
     * Binarizes and packs the first {@code nTokens} token vectors of batch entry 0 of a
     * {@code [batch, sequence, dim]} model output into an int8 tensor of {@code type}.
     *
     * <p>Each group of 8 consecutive dimensions becomes one int8 cell: a dimension whose value is
     * {@code > 0} sets its bit, with the first dimension of the group in the most significant bit.
     *
     * @throws IllegalArgumentException if {@code type} is not int8, {@code result} is not
     *         3-dimensional, {@code type} does not have exactly one bound indexed dimension, or the
     *         output size is not 8 times the target vector size
     */
    public static Tensor toBitTensor(IndexedTensor result, TensorType type, int nTokens) {
        if (type.valueType() != TensorType.Value.INT8) {
            throw new IllegalArgumentException("Only a int8 tensor type can be the destination of bit packing");
        }
        if (result.shape().length != 3) {
            throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]");
        }
        if (type.indexedSubtype().dimensions().size() != 1) {
            throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
        }
        int wantedDimensionality = indexedDimensionSize(type);
        int resultDimensionality = (int) result.shape()[2];
        if (resultDimensionality != BITS_PER_CELL * wantedDimensionality) {
            throw new IllegalArgumentException("Not possible to pack " + resultDimensionality + " dimensions into " + wantedDimensionality + " dimensions");
        }
        Tensor.Builder builder = Tensor.Builder.of(type);
        for (int token = 0; token < nTokens; token++) {
            for (int cell = 0; cell < wantedDimensionality; cell++) {
                int bits = 0;
                for (int bit = 0; bit < BITS_PER_CELL; bit++) {
                    double value = result.get(new long[] {0, token, (long) cell * BITS_PER_CELL + bit});
                    if (value > 0.0) {
                        bits |= 1 << (BITS_PER_CELL - 1 - bit); // first dimension -> most significant bit
                    }
                }
                // Narrow to a signed byte so e.g. 0x80 is stored as -128, as an int8 cell value.
                builder.cell(TensorAddress.of(new int[] {token, cell}), (float) (byte) bits);
            }
        }
        return builder.build();
    }

    /**
     * Returns the token ids dropped from document input: the tokenization of each punctuation
     * character. This is the embedder's own (mutable) set, not a copy.
     */
    public Set<Long> getSkipTokens() {
        return skipTokens;
    }

    /**
     * Reverses the bit packing of {@link #toBitTensor}. This is done by evaluating an
     * {@code UnpackBitsNode} with float cells and {@code "big"} bit order over {@code packed}.
     */
    public static Tensor expandBitTensor(Tensor packed) {
        UnpackBitsNode unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.FLOAT, "big");
        MapContext context = new MapContext();
        context.put("input", new TensorValue(packed));
        return unpacker.evaluate(context).asTensor();
    }

    /** Returns whether {@code target} has exactly two dimensions, exactly one of which is indexed. */
    protected boolean validTensorType(TensorType target) {
        return target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1;
    }

    /** Returns the bound size of the single indexed dimension of {@code type}. */
    private static int indexedDimensionSize(TensorType type) {
        TensorType.Dimension dimension = type.indexedSubtype().dimensions().get(0);
        return dimension.size()
                .orElseThrow(() -> new IllegalArgumentException(
                        "Target indexed dimension must have a bound size, got " + type))
                .intValue();
    }

    /**
     * Builds a 1-d indexed float tensor over {@code dimension} holding {@code input} in order.
     * Token ids are stored as float, which is exact for ids below 2^24.
     */
    private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
        int size = input.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (int i = 0; i < size; i++) {
            builder.cell((float) input.get(i).longValue(), new long[] {i});
        }
        return builder.build();
    }

    /** Model input for one sequence: token ids and the parallel attention mask (1 = attend, 0 = padding). */
    public record TransformerInput(List<Long> inputIds, List<Long> attentionMask) {
    }
}

