/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.hook.returndirect;

import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.agent.hook.HookPosition;
import com.alibaba.cloud.ai.graph.agent.hook.HookPositions;
import com.alibaba.cloud.ai.graph.agent.hook.JumpTo;
import com.alibaba.cloud.ai.graph.agent.hook.messages.AgentCommand;
import com.alibaba.cloud.ai.graph.agent.hook.messages.MessagesModelHook;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;

/**
 * Before-model hook that short-circuits the agent when a tool asked for its result to be
 * returned directly.
 *
 * <p>On each invocation it inspects the most recent message. If that message is a
 * {@link ToolResponseMessage} whose metadata maps {@code "finishReason"} to
 * {@code "returnDirect"}, then:
 * <ul>
 *   <li>the tool output is copied into a new {@link AssistantMessage};</li>
 *   <li>that message is appended to a copy of the conversation;</li>
 *   <li>the hook returns a command targeting {@link JumpTo#end} instead of letting the model run.</li>
 * </ul>
 * In every other case the incoming messages are passed through unchanged.
 */
@HookPositions(value = {HookPosition.BEFORE_MODEL})
public class ReturnDirectModelHook extends MessagesModelHook {

    /** Metadata key read from the trailing tool response message. */
    private static final String FINISH_REASON_KEY = "finishReason";

    /** Metadata value that marks a tool result as "return directly". */
    private static final String RETURN_DIRECT = "returnDirect";

    @Override
    public String getName() {
        return "finish_reason_check_messages_model_hook";
    }

    /**
     * Returns {@link Integer#MIN_VALUE}, the smallest possible order value.
     *
     * <p>NOTE(review): presumably lower values run earlier, making this the first
     * before-model hook — confirm against the hook framework's ordering rules.
     */
    @Override
    public int getOrder() {
        return Integer.MIN_VALUE;
    }

    /** Declares {@link JumpTo#end} as the only jump target this hook may request. */
    @Override
    public List<JumpTo> canJumpTo() {
        return List.of(JumpTo.end);
    }

    /**
     * Checks whether the last message is a "return directly" tool response and, if so,
     * converts it into a final assistant message.
     *
     * @param previousMessages the conversation so far; it is never mutated (a copy is
     *                         made when a message has to be appended)
     * @param config           the run configuration (not used by this hook)
     * @return {@code new AgentCommand(previousMessages)} when no action is needed. Otherwise,
     *         a command carrying {@link JumpTo#end} and the conversation with the generated
     *         {@link AssistantMessage} appended.
     */
    @Override
    public AgentCommand beforeModel(List<Message> previousMessages, RunnableConfig config) {
        if (previousMessages.isEmpty()) {
            return new AgentCommand(previousMessages);
        }
        Message lastMessage = previousMessages.get(previousMessages.size() - 1);
        if (!(lastMessage instanceof ToolResponseMessage)) {
            return new AgentCommand(previousMessages);
        }
        ToolResponseMessage toolResponseMessage = (ToolResponseMessage) lastMessage;
        Map<String, Object> metadata = toolResponseMessage.getMetadata();
        // Calling equals() on the constant is null-safe, so an absent key simply fails the test;
        // no separate containsKey() lookup is required.
        if (!RETURN_DIRECT.equals(metadata.get(FINISH_REASON_KEY))) {
            return new AgentCommand(previousMessages);
        }
        String generatedText = generateAssistantMessageText(toolResponseMessage);
        AssistantMessage newAssistantMessage = AssistantMessage.builder().content(generatedText).build();
        List<Message> newMessages = new ArrayList<>(previousMessages);
        newMessages.add(newAssistantMessage);
        return new AgentCommand(JumpTo.end, newMessages);
    }

    /**
     * Renders the tool response(s) as the text of the final assistant message.
     *
     * <ul>
     *   <li>No responses: {@code ""}.</li>
     *   <li>One response: its raw {@code responseData}, or {@code ""} if that is {@code null}.</li>
     *   <li>Several responses: a JSON array with one element per response.</li>
     * </ul>
     * In the several-responses case, each element is produced as follows:
     * <ul>
     *   <li>{@code null} data becomes JSON {@code null};</li>
     *   <li>data that starts with {@code {} or {@code [} (ignoring leading whitespace) is embedded
     *       verbatim;</li>
     *   <li>anything else becomes an escaped JSON string.</li>
     * </ul>
     *
     * @param toolResponseMessage the tool response message to render
     * @return the generated text, never {@code null}
     */
    private static String generateAssistantMessageText(ToolResponseMessage toolResponseMessage) {
        List<ToolResponseMessage.ToolResponse> responses = toolResponseMessage.getResponses();
        if (responses.isEmpty()) {
            return "";
        }
        if (responses.size() == 1) {
            String responseData = responses.get(0).responseData();
            // Keep the "no output" case consistent with the empty-list branch above instead of
            // producing an assistant message with null text.
            return responseData != null ? responseData : "";
        }
        StringBuilder jsonArray = new StringBuilder("[");
        for (int i = 0; i < responses.size(); i++) {
            if (i > 0) {
                jsonArray.append(',');
            }
            String responseData = responses.get(i).responseData();
            if (responseData == null) {
                jsonArray.append("null");
                continue;
            }
            String trimmed = responseData.trim();
            if (trimmed.startsWith("{") || trimmed.startsWith("[")) {
                // Heuristic: looks like a JSON object/array, so embed it as-is.
                // NOTE(review): the content is not validated; malformed input such as "{abc"
                // will make the resulting array invalid JSON.
                jsonArray.append(responseData);
                continue;
            }
            jsonArray.append('"').append(escapeJsonString(responseData)).append('"');
        }
        jsonArray.append(']');
        return jsonArray.toString();
    }

    /**
     * Escapes a string for use inside a JSON string literal (without surrounding quotes).
     *
     * <p>Quote, backslash and the short-form control characters use their two-character
     * escapes. Any other character below U+0020 is written as a {@code \}{@code uXXXX}
     * escape. All remaining characters are copied unchanged.
     *
     * @param str the text to escape; {@code null} yields {@code ""}
     * @return the escaped text
     */
    private static String escapeJsonString(String str) {
        if (str == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder(str.length() + 16);
        for (int i = 0; i < str.length(); i++) {
            char c = str.charAt(i);
            switch (c) {
                case '"':
                    sb.append("\\\"");
                    break;
                case '\\':
                    sb.append("\\\\");
                    break;
                case '\b':
                    sb.append("\\b");
                    break;
                case '\f':
                    sb.append("\\f");
                    break;
                case '\n':
                    sb.append("\\n");
                    break;
                case '\r':
                    sb.append("\\r");
                    break;
                case '\t':
                    sb.append("\\t");
                    break;
                default:
                    if (c < ' ') {
                        // %x accepts only integral arguments; passing the char directly would box it
                        // to Character and throw IllegalFormatConversionException.
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
                    break;
            }
        }
        return sb.toString();
    }
}

