/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.flow.node;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.GraphResponse;
import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.action.NodeActionWithConfig;
import com.alibaba.cloud.ai.graph.agent.Agent;
import com.alibaba.cloud.ai.graph.agent.ReactAgent;
import com.alibaba.cloud.ai.graph.internal.node.ResumableSubGraphAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;

public class MainAgentNodeAction
implements NodeActionWithConfig {
    public static final Logger logger = LoggerFactory.getLogger(MainAgentNodeAction.class);
    private final ReactAgent mainAgent;
    private final List<Agent> subAgents;

    public MainAgentNodeAction(ReactAgent mainAgent, List<Agent> subAgents) {
        this.mainAgent = mainAgent;
        this.subAgents = subAgents != null ? subAgents : List.of();
    }

    private static List<String> parseJsonArrayOfStrings(String text) {
        if (!StringUtils.hasText((String)text)) {
            return List.of();
        }
        try {
            Object parsed = JsonParser.fromJson((String)text, List.class);
            if (parsed == null || !(parsed instanceof List)) {
                return null;
            }
            List list = (List)parsed;
            ArrayList<String> result = new ArrayList<String>();
            for (Object e : list) {
                if (e == null) continue;
                String s = String.valueOf(e).trim();
                if ("FINISH".equalsIgnoreCase(s)) {
                    result.add(s);
                    continue;
                }
                result.add(s);
            }
            return result;
        }
        catch (Exception e) {
            logger.warn("Failed to parse JSON array strings from returned sub agent list of Main Agent output text: {}", (Object)text, (Object)e);
            return null;
        }
    }

    public Map<String, Object> apply(OverAllState state, RunnableConfig config) throws Exception {
        RunnableConfig subGraphRunnableConfig = RunnableConfig.builder((RunnableConfig)config).threadId(config.threadId().map(threadId -> String.format("%s_%s", threadId, ResumableSubGraphAction.subGraphId((String)this.mainAgent.name()))).orElseGet(() -> ResumableSubGraphAction.subGraphId((String)this.mainAgent.name()))).nextNode(null).checkPointId(null).build();
        subGraphRunnableConfig.clearContext();
        logger.info("Invoking mainAgent '{}' compiled graph with threadId: {}", (Object)this.mainAgent.name(), (Object)subGraphRunnableConfig.threadId());
        CompiledGraph graph = this.mainAgent.getAndCompileGraph();
        Flux subGraphResult = graph.graphResponseStream(state, subGraphRunnableConfig);
        Flux<GraphResponse<NodeOutput>> graphResponseFlux = this.getGraphResponseFlux((Flux<GraphResponse<NodeOutput>>)subGraphResult);
        HashMap<String, Object> result = new HashMap<String, Object>();
        result.put("messages", graphResponseFlux);
        return result;
    }

    private Flux<GraphResponse<NodeOutput>> getGraphResponseFlux(Flux<GraphResponse<NodeOutput>> subGraphResult) {
        return subGraphResult.buffer(2, 1).flatMap(window -> {
            if (window.size() == 1) {
                return Flux.just(this.processLastResponse((GraphResponse<NodeOutput>)((GraphResponse)window.get(0))));
            }
            return Flux.just((Object)((GraphResponse)window.get(0)));
        }, 1);
    }

    private GraphResponse<NodeOutput> processLastResponse(GraphResponse<NodeOutput> lastResponse) {
        if (lastResponse == null || lastResponse.resultValue().isEmpty()) {
            return lastResponse;
        }
        Object resultValue = lastResponse.resultValue().get();
        if (!(resultValue instanceof Map)) {
            return lastResponse;
        }
        Map resultMap = (Map)resultValue;
        HashMap<String, Object> mainStateData = new HashMap<String, Object>(resultMap);
        RoutingExtract routing = this.extractRoutingFromMessages(mainStateData);
        if (routing == null) {
            return lastResponse;
        }
        HashMap<String, Object> newResultMap = new HashMap<String, Object>(mainStateData);
        newResultMap.put("supervisor_next", routing.routingValue());
        if (routing.routingMessage() != null && !"FINISH".equals(routing.routingValue())) {
            newResultMap.put("messages", routing.routingMessage());
        }
        return GraphResponse.done(newResultMap);
    }

    private RoutingExtract extractRoutingFromMessages(Map<String, Object> mainStateData) {
        boolean emptyOrFinish;
        boolean allValid;
        List messagesList;
        Object messagesObj = mainStateData.get("messages");
        if (messagesObj == null || !(messagesObj instanceof List) || (messagesList = (List)messagesObj).isEmpty()) {
            return null;
        }
        Object lastObj = messagesList.get(messagesList.size() - 1);
        if (!(lastObj instanceof AssistantMessage)) {
            return null;
        }
        AssistantMessage assistantMessage = (AssistantMessage)lastObj;
        String text = assistantMessage.getText();
        if (!StringUtils.hasText((String)text)) {
            logger.info("Empty text in last AssistantMessage, routing to FINISH");
            return new RoutingExtract(new ArrayList<String>(List.of("FINISH")), assistantMessage);
        }
        List<String> agentNames = MainAgentNodeAction.parseJsonArrayOfStrings(text.trim());
        if (agentNames == null) {
            logger.info("Failed to parse sub-agent names from last AssistantMessage text, routing to FINISH");
            return new RoutingExtract(new ArrayList<String>(List.of("FINISH")), assistantMessage);
        }
        List<String> validNames = agentNames.stream().filter(name -> this.subAgents.stream().anyMatch(a -> a.name().equals(name))).toList();
        boolean bl = allValid = validNames.size() == agentNames.size() && agentNames.stream().noneMatch("FINISH"::equalsIgnoreCase);
        if (allValid && !validNames.isEmpty()) {
            logger.info("MainAgentNodeAction: {} from last AssistantMessage = {}", (Object)"supervisor_next", validNames);
            return new RoutingExtract(new ArrayList<String>(validNames), assistantMessage);
        }
        boolean bl2 = emptyOrFinish = agentNames.isEmpty() || agentNames.stream().allMatch(s -> "FINISH".equalsIgnoreCase(s.trim()));
        if (emptyOrFinish) {
            logger.info("MainAgentNodeAction: {} = FINISH from last AssistantMessage", (Object)"supervisor_next");
            return new RoutingExtract(new ArrayList<String>(List.of("FINISH")), assistantMessage);
        }
        logger.info("No valid sub-agent names found in last AssistantMessage, routing to FINISH");
        return new RoutingExtract(new ArrayList<String>(List.of("FINISH")), assistantMessage);
    }

    private record RoutingExtract(Object routingValue, AssistantMessage routingMessage) {
    }
}

