/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.mistralai;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationConvention;
import io.micrometer.observation.ObservationRegistry;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.content.Media;
import org.springframework.ai.mistralai.MistralAiChatOptions;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.util.context.ContextView;

/**
 * {@link ChatModel} implementation backed by the Mistral AI chat completion API
 * ({@link MistralAiApi}).
 * <p>
 * Supports blocking ({@link #call(Prompt)}) and streaming ({@link #stream(Prompt)})
 * completions. Each request is wrapped in a Micrometer observation, remote calls are
 * executed through the configured {@link RetryTemplate}, and tool calls returned by the
 * model are executed through the configured {@link ToolCallingManager} whenever the
 * {@link ToolExecutionEligibilityPredicate} reports that tool execution is required.
 */
public class MistralAiChatModel implements ChatModel {

    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

    private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build();

    /** Reactor context key under which the current Micrometer observation is propagated. */
    private static final String OBSERVATION_CONTEXT_KEY = "micrometer.observation";

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    /** Options used as the base; per-prompt runtime options are merged on top of these. */
    private final MistralAiChatOptions defaultOptions;

    private final MistralAiApi mistralAiApi;

    private final RetryTemplate retryTemplate;

    private final ObservationRegistry observationRegistry;

    private final ToolCallingManager toolCallingManager;

    private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;

    private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates a chat model that decides tool execution with a
     * {@link DefaultToolExecutionEligibilityPredicate}.
     * @param mistralAiApi low-level API client; must not be null
     * @param defaultOptions default request options; must not be null
     * @param toolCallingManager resolves and executes tool calls; must not be null
     * @param retryTemplate retry policy for remote calls; must not be null
     * @param observationRegistry registry for Micrometer observations; must not be null
     * @throws IllegalArgumentException if any argument is null
     */
    public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions defaultOptions,
            ToolCallingManager toolCallingManager, RetryTemplate retryTemplate,
            ObservationRegistry observationRegistry) {
        this(mistralAiApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry,
                new DefaultToolExecutionEligibilityPredicate());
    }

    /**
     * Creates a chat model.
     * @param mistralAiApi low-level API client; must not be null
     * @param defaultOptions default request options; must not be null
     * @param toolCallingManager resolves and executes tool calls; must not be null
     * @param retryTemplate retry policy for remote calls; must not be null
     * @param observationRegistry registry for Micrometer observations; must not be null
     * @param toolExecutionEligibilityPredicate decides whether a response triggers tool
     * execution; must not be null
     * @throws IllegalArgumentException if any argument is null
     */
    public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions defaultOptions,
            ToolCallingManager toolCallingManager, RetryTemplate retryTemplate,
            ObservationRegistry observationRegistry,
            ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
        Assert.notNull(mistralAiApi, "mistralAiApi cannot be null");
        Assert.notNull(defaultOptions, "defaultOptions cannot be null");
        Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
        Assert.notNull(retryTemplate, "retryTemplate cannot be null");
        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null");
        this.mistralAiApi = mistralAiApi;
        this.defaultOptions = defaultOptions;
        this.toolCallingManager = toolCallingManager;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
        this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
    }

    /**
     * Builds response metadata (id, model, usage, "created") from a completion, using the
     * completion's own token usage.
     * @param result the completion; must not be null
     * @return the response metadata
     * @throws IllegalArgumentException if {@code result} is null
     */
    public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) {
        Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
        return from(result, getDefaultUsage(result.usage()));
    }

    /**
     * Builds response metadata (id, model, usage, "created") from a completion, using the
     * supplied usage (for example a cumulative usage across tool-calling rounds).
     * @param result the completion; must not be null
     * @param usage the usage to report
     * @return the response metadata
     * @throws IllegalArgumentException if {@code result} is null
     */
    public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result, Usage usage) {
        Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
        return ChatResponseMetadata.builder()
            .id(result.id())
            .model(result.model())
            .usage(usage)
            .keyValue("created", result.created())
            .build();
    }

    /**
     * Converts the API usage record into a {@link DefaultUsage}, keeping the raw record as
     * the native usage.
     * <p>
     * The usage block can be absent (the streaming path already checks for this), so a
     * null input yields zero token counts instead of a {@link NullPointerException}.
     */
    private static DefaultUsage getDefaultUsage(MistralAiApi.Usage usage) {
        if (usage == null) {
            return new DefaultUsage(0, 0, 0, null);
        }
        return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
    }

    /**
     * Sends the prompt as a blocking chat completion request.
     * <p>
     * Runtime options are merged with the default options first. Tool calls are then
     * executed as the tool-execution predicate requires.
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        Prompt requestPrompt = this.buildRequestPrompt(prompt);
        return this.internalCall(requestPrompt, null);
    }

    /**
     * Executes one blocking completion round and recurses while tool execution is required.
     * @param prompt a prompt whose options were already merged by {@code buildRequestPrompt}
     * @param previousChatResponse response of the previous tool-calling round, or null for the
     * first round; its usage is accumulated into the new response
     * @return the final response, or an empty response if the API returned no body
     */
    public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
        MistralAiApi.ChatCompletionRequest request = this.createRequest(prompt, false);
        ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
            .prompt(prompt)
            .provider(MistralAiApi.PROVIDER_NAME)
            .build();

        ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry)
            .observe(() -> {
                ResponseEntity<MistralAiApi.ChatCompletion> completionEntity = this.retryTemplate
                    .execute(retryContext -> this.mistralAiApi.chatCompletionEntity(request));
                MistralAiApi.ChatCompletion chatCompletion = completionEntity.getBody();
                if (chatCompletion == null) {
                    this.logger.warn("No chat completion returned for prompt: {}", prompt);
                    return new ChatResponse(List.of());
                }

                List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
                    // Map.of rejects null values, hence the "" fallbacks.
                    Map<String, Object> metadata = Map.of(
                            "id", chatCompletion.id() != null ? chatCompletion.id() : "",
                            "index", choice.index(),
                            "role", choice.message().role() != null ? choice.message().role().name() : "",
                            "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
                    return this.buildGeneration(choice, metadata);
                }).toList();

                Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(getDefaultUsage(chatCompletion.usage()),
                        previousChatResponse);
                ChatResponse chatResponse = new ChatResponse(generations, from(chatCompletion, cumulativeUsage));
                observationContext.setResponse(chatResponse);
                return chatResponse;
            });

        if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
            ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
            if (toolExecutionResult.returnDirect()) {
                // Return the tool results to the caller without another model round.
                return ChatResponse.builder()
                    .from(response)
                    .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                    .build();
            }
            // Send the tool results back to the model, carrying the usage accumulated so far.
            return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
                    response);
        }
        return response;
    }

    /**
     * Sends the prompt as a streaming chat completion request.
     * <p>
     * Runtime options are merged with the default options first. Tool calls are then
     * executed as the tool-execution predicate requires.
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        Prompt requestPrompt = this.buildRequestPrompt(prompt);
        return this.internalStream(requestPrompt, null);
    }

    /**
     * Executes one streaming completion round and recurses while tool execution is required.
     * <p>
     * The observation is started when the flux is subscribed and stopped in
     * {@code doFinally}. Any parent observation found in the subscriber's Reactor context is
     * used as its parent.
     * @param prompt a prompt whose options were already merged by {@code buildRequestPrompt}
     * @param previousChatResponse response of the previous tool-calling round, or null for the
     * first round; its usage is accumulated into chunks that carry usage
     * @return the stream of chat responses
     */
    public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
        return Flux.deferContextual(contextView -> {
            MistralAiApi.ChatCompletionRequest request = this.createRequest(prompt, true);
            ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
                .prompt(prompt)
                .provider(MistralAiApi.PROVIDER_NAME)
                .build();
            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry);
            observation.parentObservation(contextView.getOrDefault(OBSERVATION_CONTEXT_KEY, null)).start();

            Flux<MistralAiApi.ChatCompletionChunk> completionChunks = this.retryTemplate
                .execute(retryContext -> this.mistralAiApi.chatCompletionStream(request));

            // Records the first non-null role seen per completion id. Later chunks without a
            // role then still report it.
            Map<String, String> roleMap = new ConcurrentHashMap<>();

            Flux<ChatResponse> chatResponse = completionChunks.map(this::toChatCompletion)
                .map(chatCompletion -> this.toStreamingChatResponse(chatCompletion, roleMap, previousChatResponse));

            Flux<ChatResponse> chatResponseFlux = chatResponse.flatMap(response -> {
                if (!this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
                    return Flux.just(response);
                }
                // Tool execution may block, so run it on the bounded-elastic scheduler. The
                // Reactor context is exposed to tools through ToolCallReactiveContextHolder.
                return Flux.deferContextual(reactorContext -> {
                    ToolExecutionResult toolExecutionResult;
                    try {
                        ToolCallReactiveContextHolder.setContext(reactorContext);
                        toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
                    }
                    finally {
                        ToolCallReactiveContextHolder.clearContext();
                    }
                    if (toolExecutionResult.returnDirect()) {
                        return Flux.just(ChatResponse.builder()
                            .from(response)
                            .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                            .build());
                    }
                    return this.internalStream(
                            new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
                }).subscribeOn(Schedulers.boundedElastic());
            })
                .doOnError(observation::error)
                .doFinally(signalType -> observation.stop())
                .contextWrite(context -> context.put(OBSERVATION_CONTEXT_KEY, observation));

            // The aggregated response is recorded on the observation context once the stream ends.
            return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
        });
    }

    /**
     * Converts one streaming completion (already adapted from a chunk) into a {@link ChatResponse}.
     * <p>
     * Usage metadata is attached only when the chunk carries usage. Any failure is logged,
     * and an empty response is returned so the stream keeps going.
     */
    private ChatResponse toStreamingChatResponse(MistralAiApi.ChatCompletion chatCompletion,
            Map<String, String> roleMap, ChatResponse previousChatResponse) {
        try {
            // ConcurrentHashMap rejects null keys and Map.of rejects null values, so normalize a
            // missing id the same way the blocking path does.
            String id = chatCompletion.id() != null ? chatCompletion.id() : "";
            List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
                if (choice.message().role() != null) {
                    roleMap.putIfAbsent(id, choice.message().role().name());
                }
                Map<String, Object> metadata = Map.of(
                        "id", id,
                        "role", roleMap.getOrDefault(id, ""),
                        "index", choice.index(),
                        "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
                return this.buildGeneration(choice, metadata);
            }).toList();

            if (chatCompletion.usage() != null) {
                Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(getDefaultUsage(chatCompletion.usage()),
                        previousChatResponse);
                return new ChatResponse(generations, from(chatCompletion, cumulativeUsage));
            }
            return new ChatResponse(generations);
        }
        catch (Exception e) {
            this.logger.error("Error processing chat completion", e);
            return new ChatResponse(List.of());
        }
    }

    /**
     * Builds a {@link Generation} from a completion choice.
     * <p>
     * The choice's tool calls are mapped to {@link AssistantMessage.ToolCall}s of type
     * "function". The finish reason is also recorded in the generation metadata.
     */
    private Generation buildGeneration(MistralAiApi.ChatCompletion.Choice choice, Map<String, Object> metadata) {
        List<AssistantMessage.ToolCall> toolCalls = choice.message().toolCalls() == null ? List.of()
                : choice.message()
                    .toolCalls()
                    .stream()
                    .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function",
                            toolCall.function().name(), toolCall.function().arguments()))
                    .toList();

        AssistantMessage assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
        String finishReason = choice.finishReason() != null ? choice.finishReason().name() : "";
        ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
        return new Generation(assistantMessage, generationMetadata);
    }

    /**
     * Adapts a streaming chunk to the {@link MistralAiApi.ChatCompletion} shape. Each choice's
     * {@code delta} becomes its message, and the object type is set to "chat.completion".
     */
    private MistralAiApi.ChatCompletion toChatCompletion(MistralAiApi.ChatCompletionChunk chunk) {
        List<MistralAiApi.ChatCompletion.Choice> choices = chunk.choices()
            .stream()
            .map(cc -> new MistralAiApi.ChatCompletion.Choice(cc.index(), cc.delta(), cc.finishReason(), cc.logprobs()))
            .toList();
        return new MistralAiApi.ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices,
                chunk.usage());
    }

    /**
     * Merges the prompt's runtime options over this model's default options.
     * <p>
     * Tool-related settings are merged explicitly: internal-execution flag, tool names,
     * tool callbacks, and tool context. The resulting tool callbacks are validated.
     * @param prompt the user's prompt
     * @return a prompt with the same instructions and fully merged {@link MistralAiChatOptions}
     */
    Prompt buildRequestPrompt(Prompt prompt) {
        MistralAiChatOptions runtimeOptions = null;
        if (prompt.getOptions() != null) {
            if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
                runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
                        MistralAiChatOptions.class);
            }
            else {
                runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
                        MistralAiChatOptions.class);
            }
        }

        MistralAiChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
                MistralAiChatOptions.class);

        // Tool settings are merged explicitly rather than relying on the generic merge above.
        if (runtimeOptions != null) {
            requestOptions.setInternalToolExecutionEnabled(ModelOptionsUtils.mergeOption(
                    runtimeOptions.getInternalToolExecutionEnabled(),
                    this.defaultOptions.getInternalToolExecutionEnabled()));
            requestOptions.setToolNames(
                    ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames()));
            requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
                    this.defaultOptions.getToolCallbacks()));
            requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
                    this.defaultOptions.getToolContext()));
        }
        else {
            requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
            requestOptions.setToolNames(this.defaultOptions.getToolNames());
            requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
            requestOptions.setToolContext(this.defaultOptions.getToolContext());
        }

        ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
        return new Prompt(prompt.getInstructions(), requestOptions);
    }

    /**
     * Translates a prompt into a Mistral AI chat completion request.
     * <p>
     * Message mapping:
     * <ul>
     * <li>User messages may carry media, which is sent as image-URL content after the text.</li>
     * <li>Assistant messages carry their tool calls.</li>
     * <li>Each tool response becomes its own TOOL message.</li>
     * </ul>
     * Resolved tool definitions are added as function tools.
     * <p>
     * NOTE(review): the prompt's options are cast to {@link MistralAiChatOptions}, so this
     * expects a prompt produced by {@code buildRequestPrompt}.
     * @param prompt the prompt with merged options
     * @param stream whether to request a streaming response
     * @return the request to send
     * @throws IllegalStateException for an unsupported message type
     * @throws IllegalArgumentException if a tool response has no id
     */
    MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
        List<MistralAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
            .stream()
            .map(message -> {
                if (message instanceof UserMessage userMessage) {
                    Object content = message.getText();
                    if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
                        // Multimodal content: the text part first, followed by one entry per media item.
                        List<MistralAiApi.ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
                                List.of(new MistralAiApi.ChatCompletionMessage.MediaContent(message.getText())));
                        contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
                        content = contentList;
                    }
                    return List.of(new MistralAiApi.ChatCompletionMessage(content,
                            MistralAiApi.ChatCompletionMessage.Role.USER));
                }
                if (message instanceof SystemMessage systemMessage) {
                    return List.of(new MistralAiApi.ChatCompletionMessage(systemMessage.getText(),
                            MistralAiApi.ChatCompletionMessage.Role.SYSTEM));
                }
                if (message instanceof AssistantMessage assistantMessage) {
                    List<MistralAiApi.ChatCompletionMessage.ToolCall> toolCalls = null;
                    if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
                        toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
                            MistralAiApi.ChatCompletionMessage.ChatCompletionFunction function = new MistralAiApi.ChatCompletionMessage.ChatCompletionFunction(
                                    toolCall.name(), toolCall.arguments());
                            return new MistralAiApi.ChatCompletionMessage.ToolCall(toolCall.id(), toolCall.type(),
                                    function, null);
                        }).toList();
                    }
                    return List.of(new MistralAiApi.ChatCompletionMessage(assistantMessage.getText(),
                            MistralAiApi.ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null));
                }
                if (message instanceof ToolResponseMessage toolResponseMessage) {
                    toolResponseMessage.getResponses()
                        .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"));
                    return toolResponseMessage.getResponses()
                        .stream()
                        .map(toolResponse -> new MistralAiApi.ChatCompletionMessage(toolResponse.responseData(),
                                MistralAiApi.ChatCompletionMessage.Role.TOOL, toolResponse.name(), null,
                                toolResponse.id()))
                        .toList();
                }
                throw new IllegalStateException("Unexpected message type: " + message);
            })
            .flatMap(Collection::stream)
            .toList();

        MistralAiApi.ChatCompletionRequest request = new MistralAiApi.ChatCompletionRequest(chatCompletionMessages,
                stream);

        MistralAiChatOptions requestOptions = (MistralAiChatOptions) prompt.getOptions();
        request = ModelOptionsUtils.merge(requestOptions, request, MistralAiApi.ChatCompletionRequest.class);

        List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
        if (!CollectionUtils.isEmpty(toolDefinitions)) {
            request = ModelOptionsUtils.merge(
                    MistralAiChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request,
                    MistralAiApi.ChatCompletionRequest.class);
        }
        return request;
    }

    /** Wraps a media item as image-URL content for a user message. */
    private MistralAiApi.ChatCompletionMessage.MediaContent mapToMediaContent(Media media) {
        return new MistralAiApi.ChatCompletionMessage.MediaContent(
                new MistralAiApi.ChatCompletionMessage.MediaContent.ImageUrl(
                        this.fromMediaData(media.getMimeType(), media.getData())));
    }

    /**
     * Produces the URL string for media data.
     * <p>
     * Byte arrays become a base64 {@code data:} URL with the given MIME type. Strings are
     * passed through unchanged.
     * @throws IllegalArgumentException for null data or any other data type
     */
    private String fromMediaData(MimeType mimeType, Object mediaContentData) {
        if (mediaContentData instanceof byte[] bytes) {
            return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes));
        }
        if (mediaContentData instanceof String text) {
            return text;
        }
        // Guard against null so callers get the intended IllegalArgumentException, not an NPE.
        String typeName = mediaContentData == null ? "null" : mediaContentData.getClass().getSimpleName();
        throw new IllegalArgumentException("Unsupported media data type: " + typeName);
    }

    /** Maps tool definitions to Mistral function tools (description, name, input schema). */
    private List<MistralAiApi.FunctionTool> getFunctionTools(List<ToolDefinition> toolDefinitions) {
        return toolDefinitions.stream().map(toolDefinition -> {
            MistralAiApi.FunctionTool.Function function = new MistralAiApi.FunctionTool.Function(
                    toolDefinition.description(), toolDefinition.name(), toolDefinition.inputSchema());
            return new MistralAiApi.FunctionTool(function);
        }).toList();
    }

    /** Returns a copy of this model's default options. */
    @Override
    public ChatOptions getDefaultOptions() {
        return MistralAiChatOptions.fromOptions(this.defaultOptions);
    }

    /**
     * Replaces the observation convention used for subsequent requests.
     * @param observationConvention the convention; must not be null
     * @throws IllegalArgumentException if {@code observationConvention} is null
     */
    public void setObservationConvention(ChatModelObservationConvention observationConvention) {
        Assert.notNull(observationConvention, "observationConvention cannot be null");
        this.observationConvention = observationConvention;
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builder for {@link MistralAiChatModel}.
     * <p>
     * Builder defaults:
     * <ul>
     * <li>options: temperature 0.7, topP 1.0, safePrompt false, SMALL model</li>
     * <li>{@code RetryUtils.DEFAULT_RETRY_TEMPLATE}</li>
     * <li>{@code ObservationRegistry.NOOP}</li>
     * <li>{@link DefaultToolExecutionEligibilityPredicate}</li>
     * <li>a shared default {@link ToolCallingManager} when none is set</li>
     * </ul>
     */
    public static final class Builder {

        private MistralAiApi mistralAiApi;

        private MistralAiChatOptions defaultOptions = MistralAiChatOptions.builder()
            .temperature(0.7)
            .topP(1.0)
            .safePrompt(false)
            .model(MistralAiApi.ChatModel.SMALL.getValue())
            .build();

        private ToolCallingManager toolCallingManager;

        private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate();

        private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;

        private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

        private Builder() {
        }

        public Builder mistralAiApi(MistralAiApi mistralAiApi) {
            this.mistralAiApi = mistralAiApi;
            return this;
        }

        public Builder defaultOptions(MistralAiChatOptions defaultOptions) {
            this.defaultOptions = defaultOptions;
            return this;
        }

        public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
            this.toolCallingManager = toolCallingManager;
            return this;
        }

        public Builder toolExecutionEligibilityPredicate(
                ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
            this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
            return this;
        }

        public Builder retryTemplate(RetryTemplate retryTemplate) {
            this.retryTemplate = retryTemplate;
            return this;
        }

        public Builder observationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        /**
         * Builds the chat model, falling back to the shared default {@link ToolCallingManager}
         * when none was configured.
         * @throws IllegalArgumentException if a required component is null
         */
        public MistralAiChatModel build() {
            ToolCallingManager manager = this.toolCallingManager != null ? this.toolCallingManager
                    : DEFAULT_TOOL_CALLING_MANAGER;
            return new MistralAiChatModel(this.mistralAiApi, this.defaultOptions, manager, this.retryTemplate,
                    this.observationRegistry, this.toolExecutionEligibilityPredicate);
        }

    }

}

