/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.search.llm;

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.LanguageModelException;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.completion.StringPrompt;
import ai.vespa.search.llm.LlmSearcherConfig;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.ComponentId;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.result.EventStream;
import com.yahoo.search.result.HitGroup;
import com.yahoo.search.searchchain.Execution;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * A searcher that sends a prompt derived from the query to a configured {@link LanguageModel}
 * and returns the generated text as an {@link EventStream} wrapped in a {@link HitGroup}.
 *
 * <p>The completion is either streamed (asynchronous) or produced in one blocking call
 * (synchronous). The mode comes from the {@code <propertyPrefix>.stream} query property and
 * falls back to the configured default.</p>
 */
@Beta
public class LLMSearcher extends Searcher {
    private static final Logger log = Logger.getLogger(LLMSearcher.class.getName());
    private static final String API_KEY_HEADER = "X-LLM-API-KEY";
    private static final String STREAM_PROPERTY = "stream";
    private static final String PROMPT_PROPERTY = "prompt";

    private final String propertyPrefix;
    private final boolean stream;
    private final LanguageModel languageModel;
    private final String languageModelId;

    /**
     * Creates a searcher from its configuration and the registry of available language models.
     *
     * @param config         supplies the default stream mode, the provider id and the property prefix
     * @param languageModels registry used to resolve the language model to use
     * @throws IllegalArgumentException if no suitable language model can be resolved
     */
    @Inject
    public LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
        this.stream = config.stream();
        this.languageModelId = config.providerId();
        this.languageModel = findLanguageModel(this.languageModelId, languageModels);
        this.propertyPrefix = config.propertyPrefix();
    }

    /**
     * Resolves the prompt for the query and runs a completion for it.
     *
     * @param query     the incoming query
     * @param execution the search chain execution (not used here; the chain is not continued)
     * @return the result holding the completion
     */
    @Override
    public Result search(Query query, Execution execution) {
        return complete(query, StringPrompt.from(getPrompt(query)));
    }

    /**
     * Picks the language model to use.
     *
     * @param providerId     id of the wanted model; when null or empty the first registered model is used
     * @param languageModels registry of available models
     * @return the resolved language model
     * @throws IllegalArgumentException if the registry is empty or holds no component with the given id
     */
    private static LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels) {
        if (languageModels.allComponents().isEmpty()) {
            throw new IllegalArgumentException("No language models were found");
        }
        if (providerId == null || providerId.isEmpty()) {
            Map.Entry<ComponentId, LanguageModel> entry = languageModels.allComponentsById().entrySet().stream()
                    .findFirst()
                    .orElseThrow(() -> new IllegalArgumentException("No language models were found"));
            log.info("Language model provider was not found in config. Fallback to using first available language model: " + entry.getKey());
            return entry.getValue();
        }
        LanguageModel languageModel = languageModels.getComponent(providerId);
        if (languageModel == null) {
            throw new IllegalArgumentException("No component with id '" + providerId + "' was found. Available LLM components are: "
                    + languageModels.allComponentsById().keySet().stream().map(ComponentId::toString).collect(Collectors.joining(",")));
        }
        return languageModel;
    }

    /**
     * Runs a completion of the given prompt, streamed or blocking depending on the
     * {@code <propertyPrefix>.stream} query property (default: the configured value).
     *
     * @param query  the query supplying inference options and the API key header
     * @param prompt the prompt to complete
     * @return the completion result, or a result carrying error code 429 if a
     *         {@link RejectedExecutionException} is thrown while starting the completion
     */
    protected Result complete(Query query, Prompt prompt) {
        InferenceParameters options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
        boolean stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream);
        try {
            return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
        } catch (RejectedExecutionException e) {
            return new Result(query, new ErrorMessage(429, e.getMessage()));
        }
    }

    /** Returns whether the prompt should be echoed in the result (trace level 1 or higher). */
    private boolean shouldAddPrompt(Query query) {
        return query.getTrace().getLevel() >= 1;
    }

    /** Returns whether token statistics should be added to the result (trace level 1 or higher). */
    private boolean shouldAddTokenStats(Query query) {
        return query.getTrace().getLevel() >= 1;
    }

    /**
     * Starts an asynchronous completion and immediately returns a result whose event stream
     * is filled in as completions arrive.
     */
    private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
        EventStream eventStream = new EventStream();
        if (shouldAddPrompt(query)) {
            eventStream.add(prompt.asString(), PROMPT_PROPERTY);
        }

        TokenStats tokenStats = new TokenStats();
        languageModel.completeAsync(prompt, options, completion -> {
            tokenStats.onToken();
            handleCompletion(eventStream, completion);
        }).whenComplete((finishReason, exception) -> {
            // A single terminal stage: the stream is marked complete exactly once, and
            // nothing is added to it after it has been marked complete.
            try {
                if (exception != null) {
                    handleException(eventStream, exception);
                } else {
                    tokenStats.onCompletion();
                    if (shouldAddTokenStats(query)) {
                        eventStream.add(tokenStats.report(), "stats");
                    }
                }
            } finally {
                eventStream.markComplete();
            }
        });

        HitGroup hitGroup = new HitGroup("token_stream");
        hitGroup.add(eventStream);
        return new Result(query, hitGroup);
    }

    /** Adds the completion text to the stream, tagged "error" if the completion finished with an error. */
    private void handleCompletion(EventStream eventStream, Completion completion) {
        if (completion.finishReason() == Completion.FinishReason.error) {
            eventStream.add(completion.text(), "error");
        } else {
            eventStream.add(completion.text());
        }
    }

    /**
     * Reports a failure on the stream. The error code is taken from a
     * {@link LanguageModelException} when available, and is 400 otherwise.
     */
    private void handleException(EventStream eventStream, Throwable exception) {
        int errorCode = 400;
        if (exception instanceof LanguageModelException) {
            errorCode = ((LanguageModelException) exception).code();
        }
        eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
    }

    /**
     * Runs a blocking completion and returns the text of the first completion.
     *
     * @return the completion result, or a result carrying error code 500 if the language
     *         model returned no completions
     */
    private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
        EventStream eventStream = new EventStream();
        if (shouldAddPrompt(query)) {
            eventStream.add(prompt.asString(), PROMPT_PROPERTY);
        }

        List<Completion> completions = languageModel.complete(prompt, options);
        if (completions == null || completions.isEmpty()) {
            // Report an explicit error instead of failing with an index-out-of-bounds exception.
            return new Result(query, new ErrorMessage(500, "Language model returned no completions"));
        }
        eventStream.add(completions.get(0).text(), "completion");
        eventStream.markComplete();

        HitGroup hitGroup = new HitGroup("completion");
        hitGroup.add(eventStream);
        return new Result(query, hitGroup);
    }

    /**
     * Resolves the prompt for a query. Looks, in order, at the query property
     * {@code <propertyPrefix>.prompt}, the query property {@code prompt}, and the query string
     * of the query model.
     *
     * @param query the query to read the prompt from
     * @return the first non-null prompt found
     * @throws IllegalArgumentException if none of the sources yields a prompt
     */
    public String getPrompt(Query query) {
        String prompt = lookupPropertyWithOrWithoutPrefix(PROMPT_PROPERTY, p -> query.properties().getString(p));
        if (prompt != null) {
            return prompt;
        }
        prompt = query.getModel().getQueryString();
        if (prompt != null) {
            return prompt;
        }
        throw new IllegalArgumentException("Could not find prompt for query. Tried looking for '"
                + propertyPrefix + ".prompt', 'prompt' or '@query'.");
    }

    /** Returns the configured prefix used when looking up query properties. */
    public String getPropertyPrefix() {
        return this.propertyPrefix;
    }

    /**
     * Looks up the string query property {@code <propertyPrefix>.<property>}.
     *
     * @return the property value, or null if it is not set
     */
    public String lookupProperty(String property, Query query) {
        String propertyWithPrefix = this.propertyPrefix + "." + property;
        return query.properties().getString(propertyWithPrefix, null);
    }

    /**
     * Looks up the boolean query property {@code <propertyPrefix>.<property>}.
     *
     * @return the property value, or {@code defaultValue} if it is not set
     */
    public Boolean lookupPropertyBool(String property, Query query, boolean defaultValue) {
        String propertyWithPrefix = this.propertyPrefix + "." + property;
        return query.properties().getBoolean(propertyWithPrefix, defaultValue);
    }

    /**
     * Applies {@code lookup} to {@code <propertyPrefix>.<property>} and, if that yields null,
     * to the bare {@code property}.
     *
     * @return the first non-null value, or null if both lookups return null
     */
    public String lookupPropertyWithOrWithoutPrefix(String property, Function<String, String> lookup) {
        String value = lookup.apply(getPropertyPrefix() + "." + property);
        if (value != null) {
            return value;
        }
        return lookup.apply(property);
    }

    /**
     * Reads the API key from the HTTP request headers of the query, trying the header
     * {@code <propertyPrefix>.X-LLM-API-KEY} before {@code X-LLM-API-KEY}.
     *
     * @return the header value, or null if neither header is present
     */
    public String getApiKeyHeader(Query query) {
        return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
    }

    /** Collects simple timing and count statistics for one streamed completion. */
    private static class TokenStats {
        private final long start = System.currentTimeMillis();
        private long timeToFirstToken;
        private long timeToLastToken;
        private long tokens = 0L;

        TokenStats() {
        }

        /** Registers a received token, recording the time to the first one. */
        void onToken() {
            if (this.tokens == 0L) {
                this.timeToFirstToken = System.currentTimeMillis() - this.start;
            }
            ++this.tokens;
        }

        /** Records the total time elapsed since this was created. */
        void onCompletion() {
            this.timeToLastToken = System.currentTimeMillis() - this.start;
        }

        /** Returns a human-readable summary; throughput is reported as 0 when no time has elapsed. */
        String report() {
            double seconds = (double) this.timeToLastToken / 1000.0;
            // Avoid printing "Infinity"/"NaN" when the generation took less than a millisecond.
            double tokensPerSecond = seconds > 0.0 ? (double) this.tokens / seconds : 0.0;
            return "Time to first token: " + this.timeToFirstToken + " ms, Generation time: " + this.timeToLastToken
                    + " ms, Generated tokens: " + this.tokens + " " + String.format("(%.2f tokens/sec)", tokensPerSecond);
        }
    }
}

