/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.llm.client.openai;

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.LanguageModelException;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import com.yahoo.api.annotations.Beta;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.Inspector;
import com.yahoo.slime.Slime;
import com.yahoo.slime.SlimeUtils;
import java.io.EOFException;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.http.HttpClient;
import java.net.http.HttpConnectTimeoutException;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.net.http.HttpTimeoutException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * A {@link LanguageModel} client that sends chat completion requests to an
 * OpenAI-style HTTP endpoint.
 *
 * <p>{@link #complete} performs one blocking request and returns all completions.
 * {@link #completeAsync} requests a streamed response, passes each received
 * chunk to a consumer, and retries a small set of transient network failures.
 */
@Beta
public class OpenAiClient implements LanguageModel {

    private static final String DEFAULT_MODEL = "gpt-4o-mini";
    private static final String DEFAULT_ENDPOINT = "https://api.openai.com/v1/chat/completions";

    /** Prefix of the streamed lines that carry a JSON payload. */
    private static final String DATA_FIELD = "data: ";

    /** Payload which marks the end of a streamed response; it is not JSON. */
    private static final String STREAM_DONE = "[DONE]";

    /** Number of retries attempted after the initial asynchronous request. */
    private static final int MAX_RETRIES = 3;
    private static final long RETRY_DELAY_MS = 250L;

    private static final String OPTION_MODEL = "model";
    private static final String OPTION_TEMPERATURE = "temperature";
    private static final String OPTION_MAX_TOKENS = "maxTokens";

    private final HttpClient httpClient = HttpClient.newBuilder()
            .connectTimeout(Duration.ofMillis(500L))
            .build();

    /**
     * Sends the prompt in a single blocking request and returns the completions
     * found in the response.
     *
     * @param prompt the prompt, sent as a single "user" message
     * @param options model, temperature, max tokens, endpoint and API key settings
     * @return one completion per entry in the response's "choices" array
     * @throws RuntimeException wrapping any failure, including a non-200 response
     *         (whose JSON body becomes the message of the wrapped
     *         {@link IllegalArgumentException})
     */
    @Override
    public List<Completion> complete(Prompt prompt, InferenceParameters options) {
        try {
            HttpResponse<byte[]> httpResponse =
                    this.httpClient.send(this.toRequest(prompt, options, false), HttpResponse.BodyHandlers.ofByteArray());
            Cursor response = SlimeUtils.jsonToSlime(httpResponse.body()).get();
            if (httpResponse.statusCode() != 200) {
                throw new IllegalArgumentException(SlimeUtils.toJson(response));
            }
            return this.toCompletions(response);
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Sends the prompt as a streaming request and passes each received chunk to
     * the consumer.
     *
     * @param prompt the prompt, sent as a single "user" message
     * @param options model, temperature, max tokens, endpoint and API key settings
     * @param consumer receives one {@link Completion} per streamed chunk
     * @return a future completed with the first finish reason other than
     *         {@code none}, or completed exceptionally on failure
     */
    @Override
    public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) {
        CompletionContext completionContext = new CompletionContext(prompt, options, consumer);
        this.completeAsyncAttempt(completionContext, 0);
        return completionContext.completionFuture();
    }

    /**
     * Starts one asynchronous streaming request for the given context.
     *
     * @param context the prompt, options, consumer and result future of the call
     * @param attempt zero for the initial request, incremented for each retry
     */
    private void completeAsyncAttempt(CompletionContext context, int attempt) {
        try {
            HttpRequest request = this.toRequest(context.prompt(), context.options(), true);
            CompletableFuture<HttpResponse<Stream<String>>> futureResponse =
                    this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines())
                                   .orTimeout(10L, TimeUnit.SECONDS);
            futureResponse
                    .thenAccept(response -> this.handleHttpResponse(response, context))
                    .exceptionally(exception -> {
                        this.handleHttpException(exception, context, attempt);
                        return null;
                    });
        }
        catch (Exception e) {
            // On a retry this method runs inside an exceptionally() callback, where a
            // thrown exception is discarded. Fail the future as well so the caller
            // is never left waiting.
            context.completionFuture().completeExceptionally(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Consumes a streamed response line by line. Any failure, including a
     * non-200 status, completes the context's future exceptionally.
     */
    private void handleHttpResponse(HttpResponse<Stream<String>> response, CompletionContext context) {
        try {
            // The body stream is closed on both the error path and the normal path.
            try (Stream<String> lines = response.body()) {
                int responseCode = response.statusCode();
                if (responseCode != 200) {
                    throw new LanguageModelException(responseCode, lines.collect(Collectors.joining()));
                }
                lines.forEach(line -> this.processLine(context, line));
            }
            if (!context.completionFuture().isDone()) {
                // The stream ended without reporting a finish reason: fail rather than hang.
                context.completionFuture().completeExceptionally(
                        new IllegalStateException("OpenAI: response stream ended without a finish reason"));
            }
        }
        catch (Exception e) {
            context.completionFuture().completeExceptionally(e);
        }
    }

    /**
     * Handles one line of a streamed response. Lines without the data prefix,
     * the end-of-stream marker, and payloads without any choices are ignored.
     */
    private void processLine(CompletionContext context, String line) {
        if (!line.startsWith(DATA_FIELD)) {
            return;
        }
        String payload = line.substring(DATA_FIELD.length()).strip();
        if (STREAM_DONE.equals(payload)) {
            return;
        }
        Cursor root = SlimeUtils.jsonToSlime(payload).get();
        List<Completion> completions = this.toCompletions(root, "delta");
        if (completions.isEmpty()) {
            return;
        }
        Completion completion = completions.get(0);
        context.consumer().accept(completion);
        if (completion.finishReason() != Completion.FinishReason.none) {
            context.completionFuture().complete(completion.finishReason());
        }
    }

    /** Sleeps for the retry delay; if interrupted, restores the interrupt flag and returns. */
    private void waitBeforeRetry() {
        try {
            TimeUnit.MILLISECONDS.sleep(RETRY_DELAY_MS);
        }
        catch (InterruptedException ie) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Decides whether a failed request should be retried, based on the cause of
     * the given exception.
     *
     * @return true for HTTP (connect) timeouts, EOF, and I/O errors whose message
     *         mentions "Connection reset"
     */
    private boolean shouldRetry(Throwable exception) {
        Throwable cause = exception.getCause();
        // HttpConnectTimeoutException is a subclass of HttpTimeoutException.
        if (cause instanceof HttpTimeoutException || cause instanceof EOFException) {
            return true;
        }
        if (cause instanceof IOException) {
            // The message may be null, so guard before inspecting it.
            String message = cause.getMessage();
            return message != null && message.contains("Connection reset");
        }
        return false;
    }

    /**
     * Retries the request after a short delay when the failure is retryable and
     * retries remain; otherwise completes the context's future exceptionally.
     */
    private void handleHttpException(Throwable exception, CompletionContext context, int attempt) {
        if (!this.shouldRetry(exception)) {
            context.completionFuture().completeExceptionally(exception);
            return;
        }
        if (attempt < MAX_RETRIES) {
            this.waitBeforeRetry();
            this.completeAsyncAttempt(context, attempt + 1);
        } else {
            context.completionFuture().completeExceptionally(new RuntimeException("OpenAI: max retries reached", exception));
        }
    }

    /**
     * Builds the HTTP POST request for the given prompt.
     *
     * @param prompt the prompt, sent as a single "user" message
     * @param options supplies model, temperature, max tokens, endpoint and API key
     * @param stream whether to ask for a streamed response
     */
    private HttpRequest toRequest(Prompt prompt, InferenceParameters options, boolean stream) throws IOException, URISyntaxException {
        Slime slime = new Slime();
        Cursor root = slime.setObject();
        root.setString(OPTION_MODEL, options.get(OPTION_MODEL).orElse(DEFAULT_MODEL));
        root.setBool("stream", stream);
        root.setLong("n", 1L);
        options.getDouble(OPTION_TEMPERATURE)
               .ifPresent(temperature -> root.setDouble(OPTION_TEMPERATURE, temperature.doubleValue()));
        options.getInt(OPTION_MAX_TOKENS)
               .ifPresent(maxTokens -> root.setLong("max_tokens", maxTokens.longValue()));

        Cursor message = root.setArray("messages").addObject();
        message.setString("role", "user");
        message.setString("content", prompt.asString());

        String endpoint = options.getEndpoint().orElse(DEFAULT_ENDPOINT);
        return HttpRequest.newBuilder(new URI(endpoint))
                .header("Content-Type", "application/json")
                .header("Authorization", "Bearer " + options.getApiKey().orElse(""))
                .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime)))
                .build();
    }

    /** Converts a non-streamed response, reading each choice's "message" field. */
    private List<Completion> toCompletions(Inspector response) {
        return this.toCompletions(response, "message");
    }

    /**
     * Converts every entry of the response's "choices" array to a completion.
     *
     * @param response the parsed response or response chunk
     * @param field the name of the choice field holding the content object
     * @return the completions in array order; empty when there are no choices
     */
    private List<Completion> toCompletions(Inspector response, String field) {
        Inspector choices = response.field("choices");
        List<Completion> completions = new ArrayList<>();
        // An index loop is used instead of traverse(...) with a lambda, which is
        // ambiguous between the traverser overloads.
        for (int i = 0; i < choices.entries(); i++) {
            completions.add(this.toCompletion(choices.entry(i), field));
        }
        return completions;
    }

    /** Reads the content and finish reason of a single choice. */
    private Completion toCompletion(Inspector choice, String field) {
        String content = choice.field(field).field("content").asString();
        Completion.FinishReason finishReason = this.toFinishReason(choice.field("finish_reason").asString());
        return new Completion(content, finishReason);
    }

    /**
     * Maps a finish reason string from the response to a {@link Completion.FinishReason}.
     *
     * @throws IllegalStateException if the string is not a recognized finish reason
     */
    private Completion.FinishReason toFinishReason(String finishReasonString) {
        return switch (finishReasonString) {
            case "length" -> Completion.FinishReason.length;
            case "stop" -> Completion.FinishReason.stop;
            case "", "null" -> Completion.FinishReason.none;
            default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'");
        };
    }

    /** The state of one asynchronous completion call, shared across its retries. */
    private record CompletionContext(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer, CompletableFuture<Completion.FinishReason> completionFuture) {
        CompletionContext(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) {
            this(prompt, options, consumer, new CompletableFuture<>());
        }
    }
}

