/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.flow.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.action.MultiCommand;
import com.alibaba.cloud.ai.graph.action.MultiCommandAction;
import com.alibaba.cloud.ai.graph.agent.Agent;
import com.alibaba.cloud.ai.graph.agent.flow.agent.LlmRoutingAgent;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.util.StringUtils;

public class RoutingNode
implements MultiCommandAction {
    private static final Logger logger = LoggerFactory.getLogger(RoutingNode.class);
    private static final int DEFAULT_MAX_RETRIES = 2;
    private final ChatClient chatClient;
    private final BeanOutputConverter<RoutingDecision> outputConverter;
    private final Agent rootAgent;
    private final List<Agent> subAgents;

    public RoutingNode(ChatModel chatModel, Agent rootAgent, List<Agent> subAgents) {
        LlmRoutingAgent llmRoutingAgent;
        this.rootAgent = rootAgent;
        this.subAgents = subAgents;
        StringBuilder sb = new StringBuilder();
        if (rootAgent instanceof LlmRoutingAgent && StringUtils.hasLength((String)(llmRoutingAgent = (LlmRoutingAgent)rootAgent).getSystemPrompt())) {
            sb.append("You are responsible for task routing in a graph-based AI system.\n");
            sb.append("The instruction that you should follow to finish this task is:\n\n ");
            sb.append(llmRoutingAgent.getSystemPrompt());
        } else {
            sb.append("You are responsible for task routing in a graph-based AI system.\n");
            sb.append("\n\n");
            sb.append("You have access to some specialized agents that can handle this task. You can delegate the task to ONE or MULTIPLE agents for parallel execution.\n");
            sb.append("The available agents and their capabilities are listed below:\n");
            for (Agent agent : subAgents) {
                sb.append("- ").append(agent.name()).append(": ").append(agent.description()).append("\n");
            }
            sb.append("\n");
            sb.append("Return a list of agent names from the list above. You can return one or multiple agents.\n");
            sb.append("If multiple agents are returned, they will execute in parallel.\n");
            sb.append("Available names: ");
            sb.append(String.join((CharSequence)", ", subAgents.stream().map(Agent::name).toList()));
            sb.append("\n\n");
            sb.append("Example for single agent: [\"prose_writer_agent\"]\n");
            sb.append("Example for multiple agents: [\"prose_writer_agent\", \"code_reviewer_agent\"]");
        }
        this.outputConverter = new BeanOutputConverter(RoutingDecision.class);
        sb.append("\n\n");
        sb.append(this.outputConverter.getFormat());
        this.chatClient = ChatClient.builder((ChatModel)chatModel).defaultSystem(sb.toString()).build();
    }

    public MultiCommand apply(OverAllState state, RunnableConfig config) throws Exception {
        List<Message> messages = state.value("messages").orElse(List.of());
        List<Message> messagesWithInstruction = this.prepareMessagesWithInstruction(messages);
        List<String> decisionValues = this.getDecisionWithRetry(messagesWithInstruction, 2);
        List invalidAgents = decisionValues.stream().filter(agentName -> this.subAgents.stream().noneMatch(agent -> agent.name().equals(agentName))).collect(Collectors.toList());
        if (invalidAgents.isEmpty() && !decisionValues.isEmpty()) {
            if (decisionValues.size() == 1) {
                logger.info("RoutingAgent {} routed to single sub-agent {}.", (Object)this.rootAgent.name(), (Object)decisionValues.get(0));
            } else {
                logger.info("RoutingAgent {} routed to {} sub-agents in parallel: {}.", new Object[]{this.rootAgent.name(), decisionValues.size(), String.join((CharSequence)", ", decisionValues)});
            }
            return new MultiCommand(decisionValues, Map.of());
        }
        logger.error("RoutingAgent {} failed to get valid decision after {} retries. Invalid agents: {}.", new Object[]{this.rootAgent.name(), 2, invalidAgents});
        throw new IllegalStateException("RoutingAgent " + this.rootAgent.name() + " failed to get valid decision after retries. Invalid agents: " + invalidAgents + ".");
    }

    private List<Message> prepareMessagesWithInstruction(List<Message> messages) {
        ArrayList<Message> messagesWithInstruction = new ArrayList<Message>(messages);
        Agent agent = this.rootAgent;
        if (agent instanceof LlmRoutingAgent) {
            LlmRoutingAgent llmRoutingAgent = (LlmRoutingAgent)agent;
            String instruction = llmRoutingAgent.getInstruction();
            if (StringUtils.hasLength((String)instruction)) {
                messagesWithInstruction.add((Message)new UserMessage(instruction));
            } else {
                messagesWithInstruction.add((Message)new UserMessage("Based on the chat history and current task progress, please decide the next agent to delegate the task to."));
            }
        } else {
            messagesWithInstruction.add((Message)new UserMessage("Based on the chat history and current task progress, please decide the next agent to delegate the task to."));
        }
        return messagesWithInstruction;
    }

    private List<String> getDecisionWithRetry(List<Message> messages, int maxRetries) throws Exception {
        List<Object> lastInvalidDecision = null;
        for (int attempt = 0; attempt <= maxRetries; ++attempt) {
            try {
                RoutingDecision decision;
                if (attempt == 0) {
                    decision = (RoutingDecision)this.chatClient.prompt().messages(messages).call().entity(this.outputConverter);
                } else {
                    String errorFeedback = String.format("Previous attempt returned invalid agent names: %s. Please choose from the available agents: %s.", lastInvalidDecision != null ? String.join((CharSequence)", ", lastInvalidDecision) : "[]", String.join((CharSequence)", ", this.subAgents.stream().map(Agent::name).toList()));
                    logger.warn("RoutingAgent {} retry attempt {}/{}. Previous invalid decision: {}", new Object[]{this.rootAgent.name(), attempt, maxRetries, lastInvalidDecision});
                    ArrayList<Object> messagesWithFeedback = new ArrayList<Object>();
                    boolean systemMessageFound = false;
                    for (Message msg : messages) {
                        if (msg instanceof SystemMessage && !systemMessageFound) {
                            String enhancedContent = msg.getText() + "\n\n" + errorFeedback;
                            messagesWithFeedback.add(new SystemMessage(enhancedContent));
                            systemMessageFound = true;
                            continue;
                        }
                        messagesWithFeedback.add(msg);
                    }
                    if (!systemMessageFound) {
                        messagesWithFeedback.add(new UserMessage(errorFeedback));
                    }
                    decision = (RoutingDecision)this.chatClient.prompt().messages(messagesWithFeedback).call().entity(this.outputConverter);
                }
                List<String> decisionValues = decision.getAgents();
                if (decisionValues != null && !decisionValues.isEmpty()) {
                    List invalidAgents = decisionValues.stream().filter(agentName -> this.subAgents.stream().noneMatch(agent -> agent.name().equals(agentName))).collect(Collectors.toList());
                    if (invalidAgents.isEmpty()) {
                        if (attempt > 0) {
                            logger.info("RoutingAgent {} succeeded on retry attempt {}. Routed to sub-agents: {}", new Object[]{this.rootAgent.name(), attempt, String.join((CharSequence)", ", decisionValues)});
                        }
                        return decisionValues;
                    }
                    lastInvalidDecision = decisionValues;
                    logger.warn("RoutingAgent {} attempt {}/{} returned invalid agent names: {}", new Object[]{this.rootAgent.name(), attempt, maxRetries, invalidAgents});
                    continue;
                }
                lastInvalidDecision = Collections.emptyList();
                logger.warn("RoutingAgent {} attempt {}/{} returned empty agent list", new Object[]{this.rootAgent.name(), attempt, maxRetries});
                continue;
            }
            catch (Exception e) {
                if (attempt == maxRetries) {
                    logger.error("RoutingAgent {} failed on final attempt {}/{}", new Object[]{this.rootAgent.name(), attempt, maxRetries, e});
                    throw e;
                }
                logger.warn("RoutingAgent {} attempt {}/{} encountered an error, will retry", new Object[]{this.rootAgent.name(), attempt, maxRetries, e});
            }
        }
        throw new IllegalStateException(String.format("Failed to get valid decision after %d retries. Last invalid decision: %s", maxRetries, lastInvalidDecision));
    }

    public record RoutingDecision(String agent, List<String> agents) {
        public RoutingDecision(String agent, List<String> agents) {
            this.agents = agents != null && !agents.isEmpty() ? agents : (agent != null && !agent.isEmpty() ? Collections.singletonList(agent) : Collections.emptyList());
            this.agent = agent != null && !agent.isEmpty() ? agent : (!this.agents.isEmpty() ? this.agents.get(0) : null);
        }

        public RoutingDecision(String agent) {
            this(agent, agent != null ? Collections.singletonList(agent) : Collections.emptyList());
        }

        public RoutingDecision(List<String> agents) {
            this(agents != null && !agents.isEmpty() ? agents.get(0) : null, agents != null ? agents : Collections.emptyList());
        }

        public List<String> getAgents() {
            if (this.agents != null && !this.agents.isEmpty()) {
                return this.agents;
            }
            if (this.agent != null && !this.agent.isEmpty()) {
                return Collections.singletonList(this.agent);
            }
            return Collections.emptyList();
        }
    }
}

