/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.llm.clients;

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModelException;
import ai.vespa.llm.clients.ConfigurableLanguageModel;
import ai.vespa.llm.clients.LlmClientConfig;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.secret.Secrets;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.openai.client.OpenAIClient;
import com.openai.client.OpenAIClientAsync;
import com.openai.client.okhttp.OpenAIOkHttpClient;
import com.openai.client.okhttp.OpenAIOkHttpClientAsync;
import com.openai.core.JsonValue;
import com.openai.models.ChatModel;
import com.openai.models.ResponseFormatJsonSchema;
import com.openai.models.chat.completions.ChatCompletionCreateParams;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;

/**
 * Language model client that talks to the OpenAI chat completions API, or to any endpoint exposing the
 * same API, via the openai-java SDK.
 *
 * <p>Model, temperature and max-token values from {@code LlmClientConfig} are registered as default options.
 * Per-request {@link InferenceParameters} may override them. The API key and endpoint are resolved through
 * {@code ConfigurableLanguageModel}; when they are absent, {@code DEFAULT_API_KEY} and {@code DEFAULT_ENDPOINT}
 * are used.
 *
 * <p>One synchronous and one asynchronous SDK client are cached. Each is reused as long as consecutive requests
 * use the same API key and endpoint, and rebuilt otherwise.
 */
@Beta
public class OpenAI extends ConfigurableLanguageModel {

    private static final String DEFAULT_MODEL = "gpt-4o-mini";
    private static final String DEFAULT_ENDPOINT = "https://api.openai.com/v1/";
    private static final String DEFAULT_API_KEY = "<YOUR_API_KEY>";

    // ObjectMapper is thread-safe once configured; sharing one avoids re-creating it for every request.
    private static final ObjectMapper JSON_MAPPER = new ObjectMapper();
    private static final TypeReference<Map<String, Object>> JSON_OBJECT_TYPE =
            new TypeReference<Map<String, Object>>() {};

    /** Default options taken from config; consulted only when a request does not set the option itself. */
    private final Map<String, String> configOptions = new HashMap<>();

    // Cached SDK clients plus the API key / endpoint each was built for. Package-private (presumably so tests
    // can inspect them). All access goes through the synchronized getSyncClient / getAsyncClient below.
    OpenAIClient defaultSyncClient;
    String cachedSyncApiKey;
    String cachedSyncEndpoint;
    OpenAIClientAsync defaultAsyncClient;
    String cachedAsyncApiKey;
    String cachedAsyncEndpoint;

    /**
     * Creates the client.
     *
     * <p>Non-blank model, non-negative temperature and non-negative max tokens from {@code config} are
     * registered as default options.
     *
     * @param config      client configuration
     * @param secretStore secrets passed on to {@code ConfigurableLanguageModel}
     */
    @Inject
    public OpenAI(LlmClientConfig config, Secrets secretStore) {
        super(config, secretStore);
        if (!config.model().isBlank()) {
            configOptions.put("model", config.model());
        }
        // Negative values mean "not configured".
        if (config.temperature() >= 0.0) {
            configOptions.put("temperature", String.valueOf(config.temperature()));
        }
        if (config.maxTokens() >= 0) {
            configOptions.put("maxTokens", String.valueOf(config.maxTokens()));
        }
    }

    /**
     * Resolves API key and endpoint into {@code parameters}, then layers the config defaults underneath.
     */
    private InferenceParameters prepareParameters(InferenceParameters parameters) {
        setApiKey(parameters);
        setEndpoint(parameters);
        return parameters.withDefaultOptions(configOptions::get);
    }

    /**
     * Returns the cached synchronous client when it was built for the same non-null API key and endpoint.
     * Otherwise builds, caches and returns a new one.
     *
     * <p>Synchronized because the cache fields are shared between concurrent requests. The client returned is
     * always the one matching the arguments, never one installed by another thread.
     * NOTE(review): the replaced client is not closed; closing it here could abort requests other threads
     * still have in flight on it.
     */
    synchronized OpenAIClient getSyncClient(String apiKey, String endpoint) {
        if (defaultSyncClient != null
                && apiKey != null && apiKey.equals(cachedSyncApiKey)
                && endpoint != null && endpoint.equals(cachedSyncEndpoint)) {
            return defaultSyncClient;
        }
        OpenAIClient client = OpenAIOkHttpClient.builder()
                .apiKey(apiKey)
                .baseUrl(endpoint)
                .responseValidation(false)
                .build();
        defaultSyncClient = client;
        cachedSyncApiKey = apiKey;
        cachedSyncEndpoint = endpoint;
        return client;
    }

    /**
     * Asynchronous counterpart of {@link #getSyncClient}. It follows the same caching rules and is
     * synchronized for the same reason.
     */
    synchronized OpenAIClientAsync getAsyncClient(String apiKey, String endpoint) {
        if (defaultAsyncClient != null
                && apiKey != null && apiKey.equals(cachedAsyncApiKey)
                && endpoint != null && endpoint.equals(cachedAsyncEndpoint)) {
            return defaultAsyncClient;
        }
        OpenAIClientAsync client = OpenAIOkHttpClientAsync.builder()
                .apiKey(apiKey)
                .baseUrl(endpoint)
                .responseValidation(false)
                .build();
        defaultAsyncClient = client;
        cachedAsyncApiKey = apiKey;
        cachedAsyncEndpoint = endpoint;
        return client;
    }

    /**
     * Runs a blocking chat completion.
     *
     * @return one {@link Completion} per returned choice that carries message content, each tagged with that
     *         choice's mapped finish reason
     */
    public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
        InferenceParameters preparedParameters = prepareParameters(parameters);
        String apiKey = preparedParameters.getApiKey().orElse(DEFAULT_API_KEY);
        String endpoint = preparedParameters.getEndpoint().orElse(DEFAULT_ENDPOINT);
        OpenAIClient client = getSyncClient(apiKey, endpoint);
        ChatCompletionCreateParams createParams = getChatCompletionCreateParams(preparedParameters, prompt);
        return client.chat().completions().create(createParams).choices().stream()
                .flatMap(choice -> choice.message().content().stream()
                        .map(content -> new Completion(content, mapFinishReason(choice.finishReason().toString()))))
                .toList();
    }

    /**
     * Runs a streaming chat completion, passing each content delta to {@code consumer} as it arrives.
     *
     * <p>A delta is tagged with its chunk's mapped finish reason, or {@code none} if the chunk has none.
     *
     * @return a future completed with the last finish reason reported by the stream. It is {@code stop} if the
     *         stream reported none, and it completes exceptionally if the stream fails.
     */
    public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
                                                                    InferenceParameters parameters,
                                                                    Consumer<Completion> consumer) {
        InferenceParameters preparedParameters = prepareParameters(parameters);
        String apiKey = preparedParameters.getApiKey().orElse(DEFAULT_API_KEY);
        String endpoint = preparedParameters.getEndpoint().orElse(DEFAULT_ENDPOINT);
        OpenAIClientAsync client = getAsyncClient(apiKey, endpoint);
        ChatCompletionCreateParams createParams = getChatCompletionCreateParams(preparedParameters, prompt);

        // Single-element holder so the streaming callback can record the most recent finish reason.
        Completion.FinishReason[] lastFinishReason = { Completion.FinishReason.stop };
        CompletableFuture<Completion.FinishReason> future = new CompletableFuture<>();
        client.chat().completions().createStreaming(createParams)
                .subscribe(chunk -> chunk.choices().stream()
                        .flatMap(choice -> {
                            Completion.FinishReason reason = choice.finishReason()
                                    .map(fr -> mapFinishReason(fr.toString()))
                                    .orElse(Completion.FinishReason.none);
                            if (choice.finishReason().isPresent()) {
                                lastFinishReason[0] = reason;
                            }
                            return choice.delta().content().stream()
                                    .map(content -> new Completion(content, reason));
                        })
                        .forEach(consumer))
                .onCompleteFuture()
                .thenAccept(unused -> future.complete(lastFinishReason[0]))
                .exceptionally(e -> {
                    future.completeExceptionally(e);
                    return null;
                });
        return future;
    }

    /**
     * Translates inference options into SDK request parameters. The prompt is sent as a single user message.
     *
     * <p>Recognized options: {@code model}, {@code maxTokens}, {@code temperature}, {@code topp}, {@code seed},
     * {@code npredict}, {@code frequencypenalty}, {@code presencepenalty} and {@code json_schema}. The legacy
     * misspelling {@code precencepenalty} is still accepted as a fallback for presence penalty.
     * NOTE(review): {@code npredict} is forwarded as {@code n}, i.e. the number of choices to generate.
     */
    private ChatCompletionCreateParams getChatCompletionCreateParams(InferenceParameters parameters, Prompt prompt) {
        String model = parameters.get("model").map(Object::toString).orElse(DEFAULT_MODEL);
        ChatCompletionCreateParams.Builder builder = ChatCompletionCreateParams.builder()
                .model(ChatModel.of(model))
                .addUserMessage(prompt.toString());
        parameters.getInt("maxTokens").ifPresent(v -> builder.maxCompletionTokens(v));
        parameters.getDouble("temperature").ifPresent(v -> builder.temperature(v));
        parameters.getDouble("topp").ifPresent(v -> builder.topP(v));
        parameters.getLong("seed").ifPresent(v -> builder.seed(v));
        parameters.getInt("npredict").ifPresent(v -> builder.n(v));
        parameters.getDouble("frequencypenalty").ifPresent(v -> builder.frequencyPenalty(v));
        // Prefer the correctly spelled key; keep the historical "precencepenalty" working for existing callers.
        parameters.getDouble("presencepenalty").ifPresentOrElse(
                v -> builder.presencePenalty(v),
                () -> parameters.getDouble("precencepenalty").ifPresent(v -> builder.presencePenalty(v)));
        addResponseFormat(parameters, builder);
        return builder.build();
    }

    /**
     * If a {@code json_schema} option is present, parses it as a JSON object and sets it as the request's
     * structured-output response format.
     *
     * @throws LanguageModelException with code 400 if the schema cannot be parsed or applied
     */
    private void addResponseFormat(InferenceParameters parameters, ChatCompletionCreateParams.Builder builder) {
        parameters.get("json_schema").ifPresent(jsonSchema -> {
            String schemaText = jsonSchema.toString();
            try {
                Map<String, Object> rawSchema = JSON_MAPPER.readValue(schemaText, JSON_OBJECT_TYPE);
                Map<String, JsonValue> schemaProperties = new HashMap<>();
                rawSchema.forEach((key, value) -> schemaProperties.put(key, JsonValue.from(value)));
                ResponseFormatJsonSchema.JsonSchema.Schema schema = ResponseFormatJsonSchema.JsonSchema.Schema.builder()
                        .putAllAdditionalProperties(schemaProperties)
                        .build();
                ResponseFormatJsonSchema jsonFormat = ResponseFormatJsonSchema.builder()
                        .jsonSchema(ResponseFormatJsonSchema.JsonSchema.builder()
                                .name("structured-output")
                                .schema(schema)
                                .build())
                        .build();
                builder.responseFormat(jsonFormat);
            } catch (Exception e) {
                throw new LanguageModelException(400, "Failed to parse JSON schema:\n" + schemaText + "\n" + e.getMessage(), e);
            }
        });
    }

    /**
     * Maps an OpenAI {@code finish_reason} string to {@link Completion.FinishReason}.
     *
     * <p>{@code null} maps to {@code none} and unknown values map to {@code other}.
     *
     * @throws IllegalStateException if the API reported {@code finish_reason=error}
     */
    private Completion.FinishReason mapFinishReason(String openAiFinishReason) {
        if (openAiFinishReason == null) {
            return Completion.FinishReason.none;
        }
        return switch (openAiFinishReason) {
            case "stop" -> Completion.FinishReason.stop;
            case "length" -> Completion.FinishReason.length;
            case "content_filter" -> Completion.FinishReason.content_filter;
            case "tool_calls" -> Completion.FinishReason.tool_calls;
            case "function_call" -> Completion.FinishReason.function_call;
            case "none" -> Completion.FinishReason.none;
            case "error" -> throw new IllegalStateException("OpenAI-client returned finish_reason=error");
            default -> Completion.FinishReason.other;
        };
    }
}

