/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.llm;

import ai.vespa.llm.GeneratorOptions;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.sentencepiece.SentencePieceEmbedder;
import com.yahoo.llm.GeneratorConfig;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class Generator {
    private static final int TOKEN_EOS = 1;
    private static final String BATCH_DIMENSION = "d0";
    private static final String SEQUENCE_DIMENSION = "d1";
    private final int tokenizerMaxTokens;
    private final String encoderInputIdsName;
    private final String encoderAttentionMaskName;
    private final String encoderOutputName;
    private final String decoderInputIdsName;
    private final String decoderAttentionMaskName;
    private final String decoderEncoderHiddenStateName;
    private final String decoderOutputName;
    private final SentencePieceEmbedder tokenizer;
    private final OnnxEvaluator encoder;
    private final OnnxEvaluator decoder;

    @Inject
    public Generator(GeneratorConfig config) {
        this.tokenizer = new SentencePieceEmbedder.Builder(config.tokenizerModel().toString()).build();
        this.tokenizerMaxTokens = config.tokenizerMaxTokens();
        this.encoderInputIdsName = config.encoderModelInputIdsName();
        this.encoderAttentionMaskName = config.encoderModelAttentionMaskName();
        this.encoderOutputName = config.encoderModelOutputName();
        OnnxEvaluatorOptions encoderOptions = new OnnxEvaluatorOptions();
        encoderOptions.setExecutionMode(config.encoderOnnxExecutionMode().toString());
        encoderOptions.setInterOpThreads(this.modifyThreadCount(config.encoderOnnxInterOpThreads()));
        encoderOptions.setIntraOpThreads(this.modifyThreadCount(config.encoderOnnxIntraOpThreads()));
        this.encoder = new OnnxEvaluator(config.encoderModel().toString(), encoderOptions);
        this.decoderInputIdsName = config.decoderModelInputIdsName();
        this.decoderAttentionMaskName = config.decoderModelAttentionMaskName();
        this.decoderEncoderHiddenStateName = config.decoderModelEncoderHiddenStateName();
        this.decoderOutputName = config.decoderModelOutputName();
        OnnxEvaluatorOptions decoderOptions = new OnnxEvaluatorOptions();
        decoderOptions.setExecutionMode(config.decoderOnnxExecutionMode().toString());
        decoderOptions.setInterOpThreads(this.modifyThreadCount(config.decoderOnnxInterOpThreads()));
        decoderOptions.setIntraOpThreads(this.modifyThreadCount(config.decoderOnnxIntraOpThreads()));
        this.decoder = new OnnxEvaluator(config.decoderModel().toString(), decoderOptions);
        this.validateModels();
    }

    public String generate(String prompt, GeneratorOptions options) {
        return switch (options.getSearchMethod()) {
            case GeneratorOptions.SearchMethod.GREEDY -> this.generateGreedy(prompt, options);
            default -> this.generateNotImplemented(options);
        };
    }

    public String generate(String prompt) {
        return this.generate(prompt, new GeneratorOptions());
    }

    private String generateNotImplemented(GeneratorOptions options) {
        throw new UnsupportedOperationException("Search method '" + options.getSearchMethod() + "' is currently not implemented");
    }

    private String generateGreedy(String prompt, GeneratorOptions options) {
        ArrayList<Integer> generatedTokens = new ArrayList<Integer>();
        generatedTokens.add(0);
        List<Integer> inputTokens = this.tokenize(prompt);
        Tensor encoderInput = Generator.createTensorRepresentation(inputTokens, SEQUENCE_DIMENSION);
        Tensor encoderMask = Generator.createAttentionMask(encoderInput).expand(BATCH_DIMENSION);
        Tensor encoderOutput = this.evaluateEncoder(encoderInput.expand(BATCH_DIMENSION), encoderMask);
        while (generatedTokens.size() < options.getMaxLength()) {
            Tensor decoderInput = Generator.createTensorRepresentation(generatedTokens, SEQUENCE_DIMENSION).expand(BATCH_DIMENSION);
            IndexedTensor logits = this.evaluateDecoder(decoderInput, encoderMask, encoderOutput);
            int nextToken = Generator.findMostProbableToken(logits, generatedTokens.size() - 1, BATCH_DIMENSION, SEQUENCE_DIMENSION);
            generatedTokens.add(nextToken);
        }
        return this.detokenize(generatedTokens);
    }

    private Tensor evaluateEncoder(Tensor input, Tensor mask) {
        Map<String, Tensor> encoderInputs = Map.of(this.encoderInputIdsName, input, this.encoderAttentionMaskName, mask);
        return this.encoder.evaluate(encoderInputs, this.encoderOutputName);
    }

    private IndexedTensor evaluateDecoder(Tensor input, Tensor encoderMask, Tensor encoderOutput) {
        Map<String, Tensor> inputs = Map.of(this.decoderInputIdsName, input, this.decoderAttentionMaskName, encoderMask, this.decoderEncoderHiddenStateName, encoderOutput);
        Tensor output = this.decoder.evaluate(inputs, this.decoderOutputName);
        if (!(output instanceof IndexedTensor)) {
            throw new IllegalArgumentException("Output of decoder model is not an 'IndexedTensor'");
        }
        IndexedTensor indexedTensor = (IndexedTensor)output;
        return indexedTensor;
    }

    private static int findMostProbableToken(IndexedTensor logits, int seqIndex, String batchDim, String seqDim) {
        if (logits.type().rank() != 3) {
            throw new IllegalArgumentException("Expected a tensor with rank 3: batch, sequence, and vocabulary size. Got: " + logits.type());
        }
        IndexedTensor.SubspaceIterator iterator = logits.cellIterator(new PartialAddress.Builder(2).add(batchDim, 0L).add(seqDim, (long)seqIndex).build(), DimensionSizes.of((TensorType)logits.type()));
        Double maxVal = iterator.next().getValue();
        int maxIndex = 0;
        int i = 1;
        while (iterator.hasNext()) {
            Double val = iterator.next().getValue();
            if (val >= maxVal && i != 1) {
                maxVal = val;
                maxIndex = i;
            }
            ++i;
        }
        return maxIndex;
    }

    private List<Integer> tokenize(String text) {
        List<Integer> tokens = this.tokenizer.embed(text, new Embedder.Context("tokenizer"));
        tokens = tokens.size() >= this.tokenizerMaxTokens ? tokens.subList(0, this.tokenizerMaxTokens - 1) : tokens;
        tokens.add(1);
        return tokens;
    }

    private String detokenize(List<Integer> tokens) {
        return this.tokenizer.decode(tokens, new Embedder.Context("tokenizer"), true);
    }

    private static Tensor createTensorRepresentation(List<Integer> tokens, String dimension) {
        int size = tokens.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, (long)size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of((TensorType)type);
        for (int i = 0; i < size; ++i) {
            builder.cell((float)tokens.get(i).intValue(), new long[]{i});
        }
        return builder.build();
    }

    private static Tensor createAttentionMask(Tensor d) {
        return d.map(x -> x > 0.0 ? 1.0 : 0.0);
    }

    private void validateModels() {
        Map<String, TensorType> inputs = this.encoder.getInputInfo();
        this.validateName(inputs, this.encoderInputIdsName, "input");
        this.validateName(inputs, this.encoderAttentionMaskName, "input");
        Map<String, TensorType> outputs = this.encoder.getOutputInfo();
        this.validateName(outputs, this.encoderOutputName, "output");
        inputs = this.decoder.getInputInfo();
        this.validateName(inputs, this.decoderInputIdsName, "input");
        this.validateName(inputs, this.decoderAttentionMaskName, "input");
        this.validateName(inputs, this.decoderEncoderHiddenStateName, "input");
        outputs = this.decoder.getOutputInfo();
        this.validateName(outputs, this.decoderOutputName, "output");
    }

    private void validateName(Map<String, TensorType> types, String name, String type) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. Model contains: " + String.join((CharSequence)",", types.keySet()));
        }
    }

    private int modifyThreadCount(int numThreads) {
        if (numThreads >= 0) {
            return numThreads;
        }
        return Math.max(1, (int)Math.ceil((double)Runtime.getRuntime().availableProcessors() / (double)(-1 * numThreads)));
    }
}

