/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.action.NodeActionWithConfig;
import com.alibaba.cloud.ai.graph.agent.interceptor.InterceptorChain;
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelCallHandler;
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelInterceptor;
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelRequest;
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelResponse;
import com.alibaba.cloud.ai.graph.serializer.AgentInstructionMessage;
import com.alibaba.cloud.ai.graph.utils.TypeRef;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.DefaultChatClient;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.template.TemplateRenderer;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Flux;

/**
 * Graph node that runs one model (LLM) reasoning round for an agent.
 *
 * <p>Each {@link #apply} call builds a {@link ModelRequest} from the graph state: it uses
 * {@code messages}, or {@code input} when no messages exist. It then appends the configured
 * output schema, renders the latest {@link AgentInstructionMessage} template, passes the
 * request through the {@link ModelInterceptor} chain and calls the {@link ChatClient}. The
 * call is streaming or blocking depending on the {@code _stream_} metadata entry (default:
 * streaming). Internal tool execution is always disabled on the chat options handed to the
 * client, so tool calls come back in the model response instead of being executed by the
 * ChatClient.
 */
public class AgentLlmNode
implements NodeActionWithConfig {
    private static final Logger logger = LoggerFactory.getLogger(AgentLlmNode.class);
    /** Key in the run context under which the reasoning-round counter is stored. */
    public static final String MODEL_ITERATION_KEY = "_MODEL_ITERATION_";
    /** Thread id shown in log lines when the run config carries none. */
    private static final String DEFAULT_THREAD_ID = "$default";
    private final String agentName;
    private List<Advisor> advisors = new ArrayList<>();
    private List<ToolCallback> toolCallbacks = new ArrayList<>();
    private List<ModelInterceptor> modelInterceptors = new ArrayList<>();
    private final String outputKey;
    private final String outputSchema;
    private final ChatClient chatClient;
    private String systemPrompt;
    private final TemplateRenderer templateRenderer;
    private String instruction;
    /** Chat options exactly as given to the builder. Never mutated; used as the base whenever the tool set changes. */
    @Nullable
    private final ChatOptions baseChatOptions;
    /** Effective options: {@link #baseChatOptions} merged with {@link #toolCallbacks}, or null when both are absent. */
    @Nullable
    private ToolCallingChatOptions chatOptions;
    private final boolean enableReasoningLog;

    public AgentLlmNode(Builder builder) {
        this.agentName = builder.agentName;
        this.outputKey = builder.outputKey;
        this.outputSchema = builder.outputSchema;
        this.systemPrompt = builder.systemPrompt;
        this.instruction = builder.instruction;
        this.templateRenderer = builder.templateRenderer;
        if (builder.advisors != null) {
            this.advisors = builder.advisors;
        }
        if (builder.toolCallbacks != null) {
            this.toolCallbacks = builder.toolCallbacks;
        }
        if (builder.modelInterceptors != null) {
            this.modelInterceptors = builder.modelInterceptors;
        }
        this.chatClient = builder.chatClient;
        this.baseChatOptions = builder.chatOptions;
        this.chatOptions = this.buildChatOptions(this.baseChatOptions, this.toolCallbacks);
        this.enableReasoningLog = builder.enableReasoningLog;
    }

    public static Builder builder() {
        return new Builder();
    }

    /**
     * Replaces the agent's tool callbacks and rebuilds the effective chat options from the
     * builder-supplied options. Callbacks set here are therefore actually sent to the model,
     * not only advertised by name.
     *
     * @param toolCallbacks new tool callbacks; may be null (treated as "no agent tools")
     */
    public void setToolCallbacks(List<ToolCallback> toolCallbacks) {
        this.toolCallbacks = toolCallbacks;
        this.chatOptions = this.buildChatOptions(this.baseChatOptions, toolCallbacks);
    }

    public void setModelInterceptors(List<ModelInterceptor> modelInterceptors) {
        this.modelInterceptors = modelInterceptors;
    }

    public void setInstruction(String instruction) {
        this.instruction = instruction;
    }

    public void setSystemPrompt(String systemPrompt) {
        this.systemPrompt = systemPrompt;
    }

    /**
     * Runs one reasoning round.
     *
     * @param state current graph state; must contain {@code messages} or {@code input}
     * @param config run configuration. Its context stores the round counter under
     *     {@link #MODEL_ITERATION_KEY}. Its {@code _stream_} metadata selects streaming
     *     (default) or blocking mode.
     * @return in streaming mode, a single entry holding the chain's response message, keyed by
     *     {@code outputKey} (or {@code messages} when no output key is set). In blocking mode,
     *     the entries {@code messages} and {@code _TOKEN_USAGE_}, plus {@code outputKey} if
     *     configured.
     * @throws IllegalArgumentException if the state has neither {@code messages} nor {@code input}
     * @throws Exception if the interceptor chain throws
     */
    public Map<String, Object> apply(OverAllState state, RunnableConfig config) throws Exception {
        if (this.enableReasoningLog && logger.isDebugEnabled()) {
            logger.debug("[ThreadId {}] Agent {} start reasoning.", threadIdOf(config), this.agentName);
        }
        AtomicInteger iterations = this.nextIteration(config);
        List<Message> messages = this.resolveMessages(state);
        this.augmentUserMessage(messages, this.outputSchema);
        this.renderTemplatedUserMessage(messages, state.data(), config.metadata());
        ModelRequest modelRequest = this.buildModelRequest(messages, config);

        boolean stream = config.metadata("_stream_", new TypeRef<Boolean>() {}).orElse(true);
        if (stream) {
            ModelCallHandler streamingChain = InterceptorChain.chainModelInterceptors(this.modelInterceptors, this.streamingHandler(config, iterations));
            ModelResponse streamingResponse = streamingChain.call(modelRequest);
            // NOTE(review): unlike blocking mode, streaming returns only one key and no token usage.
            return Map.of(StringUtils.hasLength(this.outputKey) ? this.outputKey : "messages", streamingResponse.getMessage());
        }

        ModelCallHandler blockingChain = InterceptorChain.chainModelInterceptors(this.modelInterceptors, this.blockingHandler(config, iterations));
        if (this.enableReasoningLog) {
            logger.info("[ThreadId {}] Agent {} reasoning round {} model chain has started.", threadIdOf(config), this.agentName, iterations.get());
        }
        ModelResponse modelResponse = blockingChain.call(modelRequest);
        ChatResponse chatResponse = modelResponse.getChatResponse();
        Usage tokenUsage = chatResponse != null ? chatResponse.getMetadata().getUsage() : new EmptyUsage();
        Map<String, Object> updatedState = new HashMap<>();
        updatedState.put("_TOKEN_USAGE_", tokenUsage);
        updatedState.put("messages", modelResponse.getMessage());
        if (StringUtils.hasLength(this.outputKey)) {
            updatedState.put(this.outputKey, modelResponse.getMessage());
        }
        return updatedState;
    }

    /**
     * Returns this run's reasoning-round counter. The counter is created at 0 on the first
     * round and incremented on every later round.
     */
    private AtomicInteger nextIteration(RunnableConfig config) {
        if (!config.context().containsKey(MODEL_ITERATION_KEY)) {
            AtomicInteger iterations = new AtomicInteger(0);
            config.context().put(MODEL_ITERATION_KEY, iterations);
            return iterations;
        }
        AtomicInteger iterations = (AtomicInteger) config.context().get(MODEL_ITERATION_KEY);
        iterations.incrementAndGet();
        return iterations;
    }

    /**
     * Returns the conversation to send: the state's {@code messages} list itself (not a copy),
     * or a new single-element list built from {@code input}.
     *
     * <p>NOTE(review): because the state's own list is returned, the schema augmentation and
     * template rendering in {@link #apply} mutate it in place. This is presumably intentional,
     * so that the rendered flag persists across rounds; confirm.
     */
    @SuppressWarnings("unchecked")
    private List<Message> resolveMessages(OverAllState state) {
        if (state.value("messages").isPresent()) {
            return (List<Message>) state.value("messages").get();
        }
        if (state.value("input").isEmpty()) {
            throw new IllegalArgumentException("Either 'instruction' or 'includeContents' must be set for Agent.");
        }
        List<Message> messages = new ArrayList<>();
        messages.add(new UserMessage(state.value("input").get().toString()));
        return messages;
    }

    /**
     * Builds the model request: messages, a per-request copy of the chat options, run metadata
     * as context, the agent's tool names and descriptions, the system prompt, and, when set, the
     * instruction as a leading user message.
     */
    private ModelRequest buildModelRequest(List<Message> messages, RunnableConfig config) {
        ModelRequest.Builder requestBuilder = ModelRequest.builder()
                .messages(messages)
                .options(this.chatOptions != null ? (ToolCallingChatOptions) this.chatOptions.copy() : null)
                .context(config.metadata().orElseGet(HashMap::new));
        if (!CollectionUtils.isEmpty(this.toolCallbacks)) {
            List<String> toolNames = new ArrayList<>(this.toolCallbacks.size());
            Map<String, String> toolDescriptions = new HashMap<>();
            for (ToolCallback callback : this.toolCallbacks) {
                String name = callback.getToolDefinition().name();
                String description = callback.getToolDefinition().description();
                toolNames.add(name);
                if (StringUtils.hasLength(description)) {
                    toolDescriptions.put(name, description);
                }
            }
            requestBuilder.tools(toolNames);
            requestBuilder.toolDescriptions(toolDescriptions);
        }
        if (StringUtils.hasLength(this.systemPrompt)) {
            requestBuilder.systemMessage(new SystemMessage(this.systemPrompt));
        }
        if (StringUtils.hasLength(this.instruction)) {
            List<Message> messagesWithInstruction = new ArrayList<>(messages.size() + 1);
            messagesWithInstruction.add(new UserMessage(this.instruction));
            messagesWithInstruction.addAll(messages);
            requestBuilder.messages(messagesWithInstruction);
        }
        return requestBuilder.build();
    }

    /**
     * Innermost handler for streaming mode. It returns the ChatClient's response stream and
     * logs each chunk when reasoning logs are enabled. Exceptions are logged and turned into an
     * "Exception: ..." assistant message.
     */
    private ModelCallHandler streamingHandler(RunnableConfig config, AtomicInteger iterations) {
        return request -> {
            try {
                if (this.enableReasoningLog && logger.isDebugEnabled()) {
                    logger.debug("[ThreadId {}] Agent {} reasoning with system prompt: {}", threadIdOf(config), this.agentName, systemPromptOf(request));
                }
                Flux<ChatResponse> chatResponseFlux = this.buildChatClientRequestSpec(request, config).stream().chatResponse();
                if (this.enableReasoningLog) {
                    chatResponseFlux = chatResponseFlux.doOnNext(chunk -> this.logStreamingChunk(chunk, config, iterations));
                }
                return ModelResponse.of(chatResponseFlux);
            }
            catch (Exception e) {
                logger.error("Exception during streaming model call: ", e);
                return ModelResponse.of(new AssistantMessage("Exception: " + e.getMessage()));
            }
        };
    }

    /** Logs one streamed chunk, showing its tool calls if it has any and its text otherwise. */
    private void logStreamingChunk(ChatResponse chunk, RunnableConfig config, AtomicInteger iterations) {
        if (chunk == null || chunk.getResult() == null || chunk.getResult().getOutput() == null) {
            return;
        }
        AssistantMessage output = chunk.getResult().getOutput();
        Object content = output.hasToolCalls() ? output.getToolCalls() : output.getText();
        logger.info("[ThreadId {}] Agent {} reasoning round {} streaming output: {}", threadIdOf(config), this.agentName, iterations.get(), content);
    }

    /**
     * Innermost handler for blocking mode. It calls the ChatClient once and wraps the first
     * result's output, or a placeholder message when the response or its result is null.
     * Exceptions are logged and turned into an "Exception: ..." assistant message.
     */
    private ModelCallHandler blockingHandler(RunnableConfig config, AtomicInteger iterations) {
        return request -> {
            try {
                if (this.enableReasoningLog) {
                    logger.info("[ThreadId {}] Agent {} reasoning round {} with system prompt: {}.", threadIdOf(config), this.agentName, iterations.get(), systemPromptOf(request));
                }
                ChatResponse response = this.buildChatClientRequestSpec(request, config).call().chatResponse();
                AssistantMessage responseMessage = response != null && response.getResult() != null
                        ? response.getResult().getOutput()
                        : new AssistantMessage("Empty response from model for unknown reason");
                if (this.enableReasoningLog) {
                    logger.info("[ThreadId {}] Agent {} reasoning round {} returned: {}.", threadIdOf(config), this.agentName, iterations.get(), responseMessage);
                }
                return ModelResponse.of(responseMessage, response);
            }
            catch (Exception e) {
                logger.error("Exception during invoking model call: ", e);
                return ModelResponse.of(new AssistantMessage("Exception: " + e.getMessage()));
            }
        };
    }

    /** Text of the request's system message, or "" when there is none. */
    private static String systemPromptOf(ModelRequest request) {
        return request.getSystemMessage() != null ? request.getSystemMessage().getText() : "";
    }

    /** Thread id for log lines, falling back to {@link #DEFAULT_THREAD_ID}. */
    private static String threadIdOf(RunnableConfig config) {
        return config.threadId().orElse(DEFAULT_THREAD_ID);
    }

    public void setAdvisors(List<Advisor> advisors) {
        this.advisors = advisors;
    }

    /**
     * Returns a copy of the request messages with the request's system message, if any,
     * prepended. Logs a warning when the result holds more than one {@link SystemMessage}.
     */
    private List<Message> appendSystemPromptIfNeeded(ModelRequest modelRequest) {
        List<Message> messages = new ArrayList<>(modelRequest.getMessages());
        if (modelRequest.getSystemMessage() != null) {
            messages.add(0, modelRequest.getSystemMessage());
        }
        long systemMessageCount = messages.stream().filter(SystemMessage.class::isInstance).count();
        // The warning says "only one", so warn as soon as a second SystemMessage appears.
        if (systemMessageCount > 1L) {
            logger.warn("Detected {} SystemMessages in the message list. There should typically be only one SystemMessage. Multiple SystemMessages may cause unexpected behavior or model confusion.", systemMessageCount);
        }
        return messages;
    }

    /**
     * Builds the effective tool-calling options, with internal tool execution disabled.
     *
     * <p>When {@code chatOptions} is a {@link ToolCallingChatOptions}, a copy of it is returned.
     * The copy's callbacks are {@code toolCallbacks} followed by the options' own callbacks whose
     * names are not already present. The caller's object is copied rather than mutated so that
     * options shared between agents do not accumulate each other's tools. Any other options type
     * is ignored with a warning.
     *
     * @return the options, or null when there are neither options nor tool callbacks
     */
    @Nullable
    private ToolCallingChatOptions buildChatOptions(@Nullable ChatOptions chatOptions, @Nullable List<ToolCallback> toolCallbacks) {
        List<ToolCallback> agentCallbacks = toolCallbacks != null ? toolCallbacks : List.of();
        if (chatOptions == null && agentCallbacks.isEmpty()) {
            return null;
        }
        if (chatOptions instanceof ToolCallingChatOptions toolCallingOptions) {
            List<ToolCallback> mergedToolCallbacks = new ArrayList<>(agentCallbacks);
            Set<String> knownNames = new HashSet<>();
            for (ToolCallback callback : agentCallbacks) {
                knownNames.add(callback.getToolDefinition().name());
            }
            List<ToolCallback> optionCallbacks = toolCallingOptions.getToolCallbacks();
            if (optionCallbacks != null) {
                for (ToolCallback callback : optionCallbacks) {
                    if (knownNames.add(callback.getToolDefinition().name())) {
                        mergedToolCallbacks.add(callback);
                    }
                }
            }
            ToolCallingChatOptions merged = (ToolCallingChatOptions) toolCallingOptions.copy();
            merged.setToolCallbacks(mergedToolCallbacks);
            merged.setInternalToolExecutionEnabled(Boolean.FALSE);
            return merged;
        }
        if (chatOptions != null) {
            logger.warn("The provided chatOptions is not of type ToolCallingChatOptions (actual type: {}). It will not take effect. Creating a new ToolCallingChatOptions with toolCallbacks instead.", chatOptions.getClass().getName());
        }
        return ToolCallingChatOptions.builder().toolCallbacks(agentCallbacks).internalToolExecutionEnabled(Boolean.FALSE).build();
    }

    /** Renders {@code prompt} as a {@link PromptTemplate}, using the configured renderer when set. */
    private String renderPromptTemplate(String prompt, Map<String, Object> params) {
        PromptTemplate.Builder builder = PromptTemplate.builder().template(prompt);
        if (this.templateRenderer != null) {
            builder.renderer(this.templateRenderer);
        }
        return builder.build().render(params);
    }

    /**
     * Appends {@code outputSchema} to the most recent user-authored message, unless that message
     * already contains it.
     *
     * <p>The target is the last {@link AgentInstructionMessage} or {@link UserMessage}. An
     * instruction message is checked first because its template may still be rendered. An
     * unrendered one receives the schema with {@code {}} escaped. An already-rendered one (it
     * will not be rendered again) receives the raw schema. When no user-authored message exists,
     * including an empty list, a new {@link UserMessage} carrying the schema is appended.
     *
     * @param messages mutable conversation, modified in place
     * @param outputSchema schema text; nothing happens when it is null or blank
     */
    public void augmentUserMessage(List<Message> messages, String outputSchema) {
        if (!StringUtils.hasText(outputSchema)) {
            return;
        }
        for (int i = messages.size() - 1; i >= 0; --i) {
            Message message = messages.get(i);
            if (message instanceof AgentInstructionMessage instructionMessage) {
                String escapedSchema = escapeTemplateDelimiters(outputSchema);
                String text = instructionMessage.getText();
                if (!text.contains(outputSchema) && !text.contains(escapedSchema)) {
                    String schemaToAppend = instructionMessage.isRendered() ? outputSchema : escapedSchema;
                    messages.set(i, instructionMessage.mutate().text(text + System.lineSeparator() + schemaToAppend).build());
                }
                return;
            }
            if (message instanceof UserMessage userMessage) {
                if (!userMessage.getText().contains(outputSchema)) {
                    messages.set(i, userMessage.mutate().text(userMessage.getText() + System.lineSeparator() + outputSchema).build());
                }
                return;
            }
        }
        messages.add(new UserMessage(outputSchema));
    }

    /** Escapes template delimiters so the text survives prompt-template rendering literally. */
    private static String escapeTemplateDelimiters(String text) {
        return text.replace("{", "\\{").replace("}", "\\}");
    }

    /**
     * Renders the last not-yet-rendered {@link AgentInstructionMessage} in {@code messages} as a
     * prompt template and marks it rendered.
     *
     * <p>Template parameters come from {@code params}. The {@code messages} entry and any list
     * values are skipped, and {@link Message} values are replaced by their text.
     *
     * @param messages mutable conversation, modified in place
     * @param params template parameters; may be null
     * @param metadata currently unused
     */
    public void renderTemplatedUserMessage(List<Message> messages, Map<String, Object> params, Optional<Map<String, Object>> metadata) {
        Map<String, Object> templateParams = new HashMap<>();
        if (params != null) {
            for (Map.Entry<String, Object> entry : params.entrySet()) {
                String key = entry.getKey();
                Object value = entry.getValue();
                if ("messages".equals(key) || value instanceof List) {
                    continue;
                }
                templateParams.put(key, value instanceof Message message ? message.getText() : value);
            }
        }
        for (int i = messages.size() - 1; i >= 0; --i) {
            if (messages.get(i) instanceof AgentInstructionMessage instructionMessage && !instructionMessage.isRendered()) {
                String renderedText = this.renderPromptTemplate(instructionMessage.getText(), templateParams);
                messages.set(i, instructionMessage.mutate().text(renderedText).rendered(true).build());
                return;
            }
        }
    }

    /**
     * Selects the tool callbacks for a request. It takes the request options' callbacks,
     * restricted to the request's tool names when any are given. For a null request, it returns
     * a copy of the agent's callbacks.
     *
     * @return a new mutable list
     */
    private List<ToolCallback> filterToolCallbacks(ModelRequest modelRequest) {
        List<ToolCallback> callbacks = new ArrayList<>();
        if (modelRequest == null) {
            if (this.toolCallbacks != null) {
                callbacks.addAll(this.toolCallbacks);
            }
            return callbacks;
        }
        if (modelRequest.getOptions() != null && modelRequest.getOptions().getToolCallbacks() != null) {
            callbacks.addAll(modelRequest.getOptions().getToolCallbacks());
        }
        List<String> requestedTools = modelRequest.getTools();
        if (requestedTools == null || requestedTools.isEmpty()) {
            return callbacks;
        }
        return new ArrayList<>(callbacks.stream().filter(callback -> requestedTools.contains(callback.getToolDefinition().name())).toList());
    }

    /**
     * Builds the ChatClient prompt for a model request.
     *
     * <p>The prompt carries the messages (with the system message prepended), the advisors, and
     * the filtered plus dynamic tool callbacks, with internal tool execution disabled. The dynamic
     * callbacks are also stored in the run context under {@code _DYNAMIC_TOOL_CALLBACKS_}.
     */
    private ChatClient.ChatClientRequestSpec buildChatClientRequestSpec(ModelRequest modelRequest, RunnableConfig config) {
        List<Message> messages = this.appendSystemPromptIfNeeded(modelRequest);
        List<ToolCallback> filteredToolCallbacks = this.filterToolCallbacks(modelRequest);
        if (!CollectionUtils.isEmpty(modelRequest.getDynamicToolCallbacks())) {
            filteredToolCallbacks.addAll(modelRequest.getDynamicToolCallbacks());
            // NOTE(review): presumably read by a downstream tool-execution node; confirm.
            config.context().put("_DYNAMIC_TOOL_CALLBACKS_", modelRequest.getDynamicToolCallbacks());
        }
        ChatClient.ChatClientRequestSpec promptSpec = this.chatClient.prompt().messages(messages).advisors(this.advisors);
        ToolCallingChatOptions requestOptions = modelRequest.getOptions();
        if (requestOptions != null) {
            requestOptions.setToolCallbacks(filteredToolCallbacks);
            requestOptions.setInternalToolExecutionEnabled(Boolean.FALSE);
            promptSpec.options(requestOptions);
        } else if (promptSpec instanceof DefaultChatClient.DefaultChatClientRequestSpec defaultSpec) {
            ChatOptions defaultOptions = defaultSpec.getChatOptions();
            if (defaultOptions == null) {
                defaultSpec.options(ToolCallingChatOptions.builder().toolCallbacks(filteredToolCallbacks).internalToolExecutionEnabled(Boolean.FALSE).build());
            } else if (defaultOptions instanceof ToolCallingChatOptions toolCallingOptions) {
                toolCallingOptions.setToolCallbacks(filteredToolCallbacks);
                toolCallingOptions.setInternalToolExecutionEnabled(Boolean.FALSE);
            }
            // NOTE(review): if the client's options are some other ChatOptions type, the filtered callbacks are not attached.
        } else if (!filteredToolCallbacks.isEmpty()) {
            // tools(Object...) expects @Tool-annotated objects; ToolCallback instances must go through toolCallbacks(...).
            promptSpec.toolCallbacks(filteredToolCallbacks);
        }
        return promptSpec;
    }

    public String getName() {
        return "_AGENT_MODEL_";
    }

    /** Fluent builder for {@link AgentLlmNode}; unset collections default to empty lists. */
    public static class Builder {
        private String agentName;
        private String outputKey;
        private String outputSchema;
        private String systemPrompt;
        private TemplateRenderer templateRenderer;
        private ChatClient chatClient;
        private List<Advisor> advisors;
        private List<ToolCallback> toolCallbacks;
        private List<ModelInterceptor> modelInterceptors;
        private String instruction;
        private boolean enableReasoningLog;
        private ChatOptions chatOptions;

        public Builder agentName(String agentName) {
            this.agentName = agentName;
            return this;
        }

        public Builder outputKey(String outputKey) {
            this.outputKey = outputKey;
            return this;
        }

        public Builder outputSchema(String outputSchema) {
            this.outputSchema = outputSchema;
            return this;
        }

        public Builder systemPrompt(String systemPrompt) {
            this.systemPrompt = systemPrompt;
            return this;
        }

        public Builder templateRenderer(TemplateRenderer templateRenderer) {
            this.templateRenderer = templateRenderer;
            return this;
        }

        public Builder advisors(List<Advisor> advisors) {
            this.advisors = advisors;
            return this;
        }

        public Builder toolCallbacks(List<ToolCallback> toolCallbacks) {
            this.toolCallbacks = toolCallbacks;
            return this;
        }

        public Builder modelInterceptors(List<ModelInterceptor> modelInterceptors) {
            this.modelInterceptors = modelInterceptors;
            return this;
        }

        public Builder chatClient(ChatClient chatClient) {
            this.chatClient = chatClient;
            return this;
        }

        public Builder instruction(String instruction) {
            this.instruction = instruction;
            return this;
        }

        public Builder enableReasoningLog(boolean enableReasoningLog) {
            this.enableReasoningLog = enableReasoningLog;
            return this;
        }

        /** Base chat options. The node works on a copy, so the given object is not modified. */
        public Builder chatOptions(ChatOptions chatOptions) {
            this.chatOptions = chatOptions;
            return this;
        }

        public AgentLlmNode build() {
            return new AgentLlmNode(this);
        }
    }
}

