/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.flow.strategy;

import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.agent.Agent;
import com.alibaba.cloud.ai.graph.agent.Prioritized;
import com.alibaba.cloud.ai.graph.agent.flow.builder.FlowGraphBuilder;
import com.alibaba.cloud.ai.graph.agent.flow.node.TransparentNode;
import com.alibaba.cloud.ai.graph.agent.flow.strategy.FlowGraphBuildingStrategy;
import com.alibaba.cloud.ai.graph.agent.hook.AgentHook;
import com.alibaba.cloud.ai.graph.agent.hook.Hook;
import com.alibaba.cloud.ai.graph.agent.hook.HookPosition;
import com.alibaba.cloud.ai.graph.agent.hook.InterruptionHook;
import com.alibaba.cloud.ai.graph.agent.hook.ModelHook;
import com.alibaba.cloud.ai.graph.agent.hook.messages.MessagesAgentHook;
import com.alibaba.cloud.ai.graph.agent.hook.messages.MessagesModelHook;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Template-method base class for {@link FlowGraphBuildingStrategy} implementations.
 *
 * <p>{@link #buildGraph(FlowGraphBuilder.FlowGraphConfig)} fixes the overall assembly order:
 * create the {@link StateGraph}, partition the configured hooks by {@link HookPosition},
 * register one node per hook, wire the start edge, let the subclass add its core nodes via
 * {@link #buildCoreGraph(FlowGraphBuilder.FlowGraphConfig)}, and finally chain the hook nodes.
 * Subclasses customise the result by overriding the protected steps.
 *
 * <p>Hook node ids are derived as follows:
 * <ul>
 *   <li>agent hooks: {@code hook.getName() + ".before"} / {@code ".after"}</li>
 *   <li>model hooks: {@code Hook.getFullHookName(hook) + ".beforeModel"} / {@code ".afterModel"}</li>
 * </ul>
 *
 * <p>All intermediate state is kept in instance fields that are overwritten on every
 * {@code buildGraph} call.
 */
public abstract class AbstractFlowGraphBuildingStrategy
implements FlowGraphBuildingStrategy {
    /** Node id used as the source of the edge into {@link #entryNode}. */
    private static final String START_NODE = "__START__";
    /** Node id used as the terminal target of the after-agent hook chain. */
    private static final String END_NODE = "__END__";
    private static final String BEFORE_AGENT_SUFFIX = ".before";
    private static final String AFTER_AGENT_SUFFIX = ".after";
    private static final String BEFORE_MODEL_SUFFIX = ".beforeModel";
    private static final String AFTER_MODEL_SUFFIX = ".afterModel";

    /** Graph under construction; replaced at the start of every {@code buildGraph} call. */
    protected StateGraph graph;
    /** Root agent taken from the configuration. */
    protected Agent rootAgent;
    /** Hooks declaring {@link HookPosition#BEFORE_AGENT}, in execution order. */
    protected List<Hook> beforeAgentHooks;
    /** Hooks declaring {@link HookPosition#AFTER_AGENT}, in filtered order. */
    protected List<Hook> afterAgentHooks;
    /** Hooks declaring {@link HookPosition#BEFORE_MODEL}, in execution order. */
    protected List<Hook> beforeModelHooks;
    /** Hooks declaring {@link HookPosition#AFTER_MODEL}, in execution order. */
    protected List<Hook> afterModelHooks;
    /** Node id that {@code __START__} is connected to. */
    protected String entryNode;
    /** Node id computed by {@link #determineExitNode(List)}; not consumed inside this class. */
    protected String exitNode;

    /**
     * Builds the complete graph for the given configuration.
     *
     * <p>The order of the steps matters: hook nodes are registered and the start edge is added
     * before {@link #buildCoreGraph} runs, and the hook chains are connected only afterwards.
     *
     * @param config flow configuration supplying the graph name, key strategy factory, optional
     *               state serializer, root agent and hooks
     * @return the assembled graph (also stored in {@link #graph})
     * @throws GraphStateException if adding a node or an edge to the graph fails
     */
    @Override
    public final StateGraph buildGraph(FlowGraphBuilder.FlowGraphConfig config) throws GraphStateException {
        // The serializer is optional; pick the constructor that matches what was configured.
        if (config.getStateSerializer() != null) {
            this.graph = new StateGraph(config.getName(), config.getKeyStrategyFactory(), config.getStateSerializer());
        } else {
            this.graph = new StateGraph(config.getName(), config.getKeyStrategyFactory());
        }
        this.rootAgent = config.getRootAgent();

        // Read the hook list once and split it into the four positions.
        List<? extends Hook> configuredHooks = config.getHooks();
        this.beforeAgentHooks = AbstractFlowGraphBuildingStrategy.filterHooksByPosition(configuredHooks, HookPosition.BEFORE_AGENT);
        this.afterAgentHooks = AbstractFlowGraphBuildingStrategy.filterHooksByPosition(configuredHooks, HookPosition.AFTER_AGENT);
        this.beforeModelHooks = AbstractFlowGraphBuildingStrategy.filterHooksByPosition(configuredHooks, HookPosition.BEFORE_MODEL);
        this.afterModelHooks = AbstractFlowGraphBuildingStrategy.filterHooksByPosition(configuredHooks, HookPosition.AFTER_MODEL);

        // Register one node per hook before any edge refers to it.
        this.addBeforeAgentHookNodesToGraph(this.graph, this.beforeAgentHooks);
        this.addAfterAgentHookNodesToGraph(this.graph, this.afterAgentHooks);
        this.addBeforeModelHookNodes(this.graph, this.beforeModelHooks);
        this.addAfterModelHookNodes(this.graph, this.afterModelHooks);

        this.entryNode = this.determineEntryNode(this.getRootAgent(), this.beforeAgentHooks, this.beforeModelHooks);
        this.exitNode = this.determineExitNode(this.afterAgentHooks);
        this.graph.addEdge(START_NODE, this.entryNode);

        this.buildCoreGraph(config);

        this.connectBeforeModelHooks();
        this.connectAfterModelHooks();
        this.connectBeforeAgentHooks();
        this.connectAfterAgentHooks();
        return this.graph;
    }

    /**
     * Returns the root agent captured by the current {@code buildGraph} call.
     *
     * @return the root agent, or {@code null} before the first build
     */
    protected Agent getRootAgent() {
        return this.rootAgent;
    }

    /**
     * Adds the strategy-specific nodes and edges. The default implementation registers a single
     * {@link TransparentNode} under the root agent's name.
     *
     * @param config flow configuration passed to {@code buildGraph}
     * @throws GraphStateException if the node cannot be added
     */
    protected void buildCoreGraph(FlowGraphBuilder.FlowGraphConfig config) throws GraphStateException {
        this.graph.addNode(this.getRootAgent().name(), AsyncNodeAction.node_async((NodeAction)new TransparentNode()));
    }

    /**
     * Chains the before-model hook nodes and connects the last one to the root agent node.
     * Does nothing when there are no before-model hooks.
     *
     * @throws GraphStateException if an edge cannot be added
     */
    protected void connectBeforeModelHooks() throws GraphStateException {
        if (this.beforeModelHooks.isEmpty()) {
            return;
        }
        this.connectBeforeModelHookEdges(this.graph, this.getRootAgent().name(), this.beforeModelHooks);
    }

    /**
     * Connects the root agent node to the chain of after-model hook nodes.
     * Does nothing when there are no after-model hooks.
     *
     * @throws GraphStateException if an edge cannot be added
     */
    protected void connectAfterModelHooks() throws GraphStateException {
        if (this.afterModelHooks.isEmpty()) {
            return;
        }
        this.connectAfterModelHookEdges(this.graph, this.getRootAgent().name(), this.afterModelHooks);
    }

    /**
     * Chains the before-agent hook nodes. The last one leads to the first before-model hook node
     * when one exists, otherwise to the root agent node.
     *
     * @throws GraphStateException if an edge cannot be added
     */
    protected void connectBeforeAgentHooks() throws GraphStateException {
        if (this.beforeAgentHooks.isEmpty()) {
            return;
        }
        String nextNode = this.beforeModelHooks.isEmpty()
                ? this.getRootAgent().name()
                : AbstractFlowGraphBuildingStrategy.beforeModelNodeName(this.beforeModelHooks.get(0));
        this.chainBeforeAgentHooks(this.graph, this.beforeAgentHooks, nextNode);
    }

    /**
     * Chains the after-agent hook nodes towards {@code __END__}.
     * Does nothing when there are no after-agent hooks.
     *
     * @throws GraphStateException if an edge cannot be added
     */
    protected void connectAfterAgentHooks() throws GraphStateException {
        if (this.afterAgentHooks.isEmpty()) {
            return;
        }
        this.chainAfterAgentHooks(this.graph, this.afterAgentHooks);
    }

    /**
     * Selects the hooks that declare the given position and orders them.
     *
     * <p>Hooks implementing {@link Prioritized} come first, sorted ascending by
     * {@code getOrder()}; {@code List.sort} is stable, so equal orders keep their input order.
     * All remaining hooks follow in input order. {@code null} entries and hooks whose
     * {@code getHookPositions()} returns {@code null} are skipped.
     *
     * @param hooks    candidate hooks; may be {@code null} or empty
     * @param position position to filter by
     * @return a new mutable list, never {@code null}
     */
    protected static List<Hook> filterHooksByPosition(List<? extends Hook> hooks, HookPosition position) {
        if (hooks == null || hooks.isEmpty()) {
            return new ArrayList<Hook>();
        }
        List<Hook> prioritizedHooks = new ArrayList<Hook>();
        List<Hook> nonPrioritizedHooks = new ArrayList<Hook>();
        for (Hook hook : hooks) {
            if (!AbstractFlowGraphBuildingStrategy.declaresPosition(hook, position)) {
                continue;
            }
            if (hook instanceof Prioritized) {
                prioritizedHooks.add(hook);
            } else {
                nonPrioritizedHooks.add(hook);
            }
        }
        // Every element here passed the instanceof check above, so the cast cannot fail.
        prioritizedHooks.sort(Comparator.comparingInt((Hook hook) -> ((Prioritized)hook).getOrder()));
        List<Hook> result = new ArrayList<Hook>(prioritizedHooks);
        result.addAll(nonPrioritizedHooks);
        return result;
    }

    /** Returns whether {@code hook} is non-null and lists {@code position} among its positions. */
    private static boolean declaresPosition(Hook hook, HookPosition position) {
        if (hook == null) {
            return false;
        }
        HookPosition[] positions = hook.getHookPositions();
        return positions != null && Arrays.asList(positions).contains(position);
    }

    /** Node id of a before-agent hook. */
    private static String beforeAgentNodeName(Hook hook) {
        return hook.getName() + BEFORE_AGENT_SUFFIX;
    }

    /** Node id of an after-agent hook. */
    private static String afterAgentNodeName(Hook hook) {
        return hook.getName() + AFTER_AGENT_SUFFIX;
    }

    /** Node id of a before-model hook. */
    private static String beforeModelNodeName(Hook hook) {
        return Hook.getFullHookName(hook) + BEFORE_MODEL_SUFFIX;
    }

    /** Node id of an after-model hook. */
    private static String afterModelNodeName(Hook hook) {
        return Hook.getFullHookName(hook) + AFTER_MODEL_SUFFIX;
    }

    /**
     * Registers one node per before-model hook.
     *
     * <p>A {@link ModelHook} that is also an {@link InterruptionHook} is registered as the node
     * action itself; any other {@link ModelHook} contributes its {@code beforeModel} method; a
     * {@link MessagesModelHook} is wrapped via {@code MessagesModelHook.beforeModelAction}.
     *
     * <p>NOTE(review): hooks of any other type get no node here, yet
     * {@link #connectBeforeModelHookEdges} still adds edges for them - confirm this is intended.
     */
    private void addBeforeModelHookNodes(StateGraph graph, List<Hook> beforeModelHooks) throws GraphStateException {
        for (Hook hook : beforeModelHooks) {
            String nodeName = AbstractFlowGraphBuildingStrategy.beforeModelNodeName(hook);
            if (hook instanceof ModelHook) {
                if (hook instanceof InterruptionHook) {
                    InterruptionHook interruptionHook = (InterruptionHook)hook;
                    graph.addNode(nodeName, (AsyncNodeActionWithConfig)interruptionHook);
                } else {
                    ModelHook modelHook = (ModelHook)hook;
                    graph.addNode(nodeName, modelHook::beforeModel);
                }
            } else if (hook instanceof MessagesModelHook) {
                MessagesModelHook messagesModelHook = (MessagesModelHook)hook;
                graph.addNode(nodeName, (AsyncNodeActionWithConfig)MessagesModelHook.beforeModelAction(messagesModelHook));
            }
        }
    }

    /**
     * Registers one node per after-model hook, mirroring {@link #addBeforeModelHookNodes} but
     * using {@code afterModel} / {@code MessagesModelHook.afterModelAction}.
     *
     * <p>NOTE(review): hooks of any other type get no node here, yet
     * {@link #connectAfterModelHookEdges} still adds edges for them - confirm this is intended.
     */
    private void addAfterModelHookNodes(StateGraph graph, List<Hook> afterModelHooks) throws GraphStateException {
        for (Hook hook : afterModelHooks) {
            String nodeName = AbstractFlowGraphBuildingStrategy.afterModelNodeName(hook);
            if (hook instanceof ModelHook) {
                if (hook instanceof InterruptionHook) {
                    InterruptionHook interruptionHook = (InterruptionHook)hook;
                    graph.addNode(nodeName, (AsyncNodeActionWithConfig)interruptionHook);
                } else {
                    ModelHook modelHook = (ModelHook)hook;
                    graph.addNode(nodeName, modelHook::afterModel);
                }
            } else if (hook instanceof MessagesModelHook) {
                MessagesModelHook messagesModelHook = (MessagesModelHook)hook;
                graph.addNode(nodeName, (AsyncNodeActionWithConfig)MessagesModelHook.afterModelAction(messagesModelHook));
            }
        }
    }

    /**
     * Registers one node per before-agent hook under {@code hook.getName() + ".before"}.
     * {@link AgentHook}s contribute {@code beforeAgent}; {@link MessagesAgentHook}s are wrapped
     * via {@code MessagesAgentHook.beforeAgentAction}; other hook types are skipped.
     *
     * @param graph            graph to add the nodes to
     * @param beforeAgentHooks hooks to register
     * @throws GraphStateException if a node cannot be added
     */
    protected void addBeforeAgentHookNodesToGraph(StateGraph graph, List<Hook> beforeAgentHooks) throws GraphStateException {
        for (Hook hook : beforeAgentHooks) {
            String hookNodeName = AbstractFlowGraphBuildingStrategy.beforeAgentNodeName(hook);
            if (hook instanceof AgentHook) {
                AgentHook agentHook = (AgentHook)hook;
                graph.addNode(hookNodeName, agentHook::beforeAgent);
            } else if (hook instanceof MessagesAgentHook) {
                MessagesAgentHook messagesAgentHook = (MessagesAgentHook)hook;
                graph.addNode(hookNodeName, (AsyncNodeActionWithConfig)MessagesAgentHook.beforeAgentAction(messagesAgentHook));
            }
        }
    }

    /**
     * Registers one node per after-agent hook under {@code hook.getName() + ".after"}.
     * {@link AgentHook}s contribute {@code afterAgent}; {@link MessagesAgentHook}s are wrapped
     * via {@code MessagesAgentHook.afterAgentAction}; other hook types are skipped.
     *
     * @param graph           graph to add the nodes to
     * @param afterAgentHooks hooks to register
     * @throws GraphStateException if a node cannot be added
     */
    protected void addAfterAgentHookNodesToGraph(StateGraph graph, List<Hook> afterAgentHooks) throws GraphStateException {
        for (Hook hook : afterAgentHooks) {
            String hookNodeName = AbstractFlowGraphBuildingStrategy.afterAgentNodeName(hook);
            if (hook instanceof AgentHook) {
                AgentHook agentHook = (AgentHook)hook;
                graph.addNode(hookNodeName, agentHook::afterAgent);
            } else if (hook instanceof MessagesAgentHook) {
                MessagesAgentHook messagesAgentHook = (MessagesAgentHook)hook;
                graph.addNode(hookNodeName, (AsyncNodeActionWithConfig)MessagesAgentHook.afterAgentAction(messagesAgentHook));
            }
        }
    }

    /**
     * Links the before-model hook nodes in list order and connects the last one to
     * {@code defaultFirstNode}.
     *
     * @param graph            graph to add the edges to
     * @param defaultFirstNode node that follows the hook chain
     * @param beforeModelHooks hooks to chain
     * @return the first hook's node id, or {@code defaultFirstNode} when the list is empty
     * @throws GraphStateException if an edge cannot be added
     */
    protected String connectBeforeModelHookEdges(StateGraph graph, String defaultFirstNode, List<Hook> beforeModelHooks) throws GraphStateException {
        if (beforeModelHooks.isEmpty()) {
            return defaultFirstNode;
        }
        String firstNodeName = AbstractFlowGraphBuildingStrategy.beforeModelNodeName(beforeModelHooks.get(0));
        String previousNodeName = firstNodeName;
        for (int i = 1; i < beforeModelHooks.size(); ++i) {
            String currentNodeName = AbstractFlowGraphBuildingStrategy.beforeModelNodeName(beforeModelHooks.get(i));
            graph.addEdge(previousNodeName, currentNodeName);
            previousNodeName = currentNodeName;
        }
        graph.addEdge(previousNodeName, defaultFirstNode);
        return firstNodeName;
    }

    /**
     * Connects {@code sourceNode} to the first after-model hook node and links the remaining
     * hook nodes in list order. No edge is added out of the last hook node.
     *
     * @param graph           graph to add the edges to
     * @param sourceNode      node that precedes the hook chain
     * @param afterModelHooks hooks to chain
     * @return the last hook's node id, or {@code sourceNode} when the list is empty
     * @throws GraphStateException if an edge cannot be added
     */
    protected String connectAfterModelHookEdges(StateGraph graph, String sourceNode, List<Hook> afterModelHooks) throws GraphStateException {
        String previousNodeName = sourceNode;
        for (Hook hook : afterModelHooks) {
            String currentNodeName = AbstractFlowGraphBuildingStrategy.afterModelNodeName(hook);
            graph.addEdge(previousNodeName, currentNodeName);
            previousNodeName = currentNodeName;
        }
        // With an empty list the loop never runs and this is still sourceNode.
        return previousNodeName;
    }

    /**
     * Links the before-agent hook nodes in list order and connects the last one to
     * {@code nextNode}. Adds nothing when the list is empty.
     *
     * @param graph            graph to add the edges to
     * @param beforeAgentHooks hooks to chain
     * @param nextNode         node that follows the hook chain
     * @throws GraphStateException if an edge cannot be added
     */
    protected void chainBeforeAgentHooks(StateGraph graph, List<Hook> beforeAgentHooks, String nextNode) throws GraphStateException {
        int lastIndex = beforeAgentHooks.size() - 1;
        for (int i = 0; i <= lastIndex; ++i) {
            String hookNodeName = AbstractFlowGraphBuildingStrategy.beforeAgentNodeName(beforeAgentHooks.get(i));
            String targetNodeName = i < lastIndex
                    ? AbstractFlowGraphBuildingStrategy.beforeAgentNodeName(beforeAgentHooks.get(i + 1))
                    : nextNode;
            graph.addEdge(hookNodeName, targetNodeName);
        }
    }

    /**
     * Links the after-agent hook nodes in reverse list order: the last hook leads to the one
     * before it, and so on, while the first hook leads to {@code __END__}.
     *
     * @param graph           graph to add the edges to
     * @param afterAgentHooks hooks to chain
     * @throws GraphStateException if an edge cannot be added
     */
    protected void chainAfterAgentHooks(StateGraph graph, List<Hook> afterAgentHooks) throws GraphStateException {
        if (afterAgentHooks.isEmpty()) {
            return;
        }
        graph.addEdge(AbstractFlowGraphBuildingStrategy.afterAgentNodeName(afterAgentHooks.get(0)), END_NODE);
        for (int i = afterAgentHooks.size() - 1; i > 0; --i) {
            String currentHookName = AbstractFlowGraphBuildingStrategy.afterAgentNodeName(afterAgentHooks.get(i));
            String prevHookName = AbstractFlowGraphBuildingStrategy.afterAgentNodeName(afterAgentHooks.get(i - 1));
            graph.addEdge(currentHookName, prevHookName);
        }
    }

    /**
     * Determines the node the core graph should lead to when it finishes.
     *
     * @param afterAgentHooks after-agent hooks in filtered order
     * @return the node id of the last after-agent hook (the head of the reverse chain built by
     *         {@link #chainAfterAgentHooks}), or {@code "__END__"} when there are none
     */
    protected String determineExitNode(List<Hook> afterAgentHooks) {
        if (afterAgentHooks.isEmpty()) {
            return END_NODE;
        }
        return AbstractFlowGraphBuildingStrategy.afterAgentNodeName(afterAgentHooks.get(afterAgentHooks.size() - 1));
    }

    /**
     * Determines the node that {@code __START__} connects to.
     *
     * @param rootAgent        root agent of the flow
     * @param beforeAgentHooks before-agent hooks in execution order
     * @param beforeModelHooks before-model hooks in execution order
     * @return the first before-agent hook node if any, else the first before-model hook node if
     *         any, else the root agent's name
     */
    protected String determineEntryNode(Agent rootAgent, List<Hook> beforeAgentHooks, List<Hook> beforeModelHooks) {
        if (!beforeAgentHooks.isEmpty()) {
            return AbstractFlowGraphBuildingStrategy.beforeAgentNodeName(beforeAgentHooks.get(0));
        }
        if (!beforeModelHooks.isEmpty()) {
            return AbstractFlowGraphBuildingStrategy.beforeModelNodeName(beforeModelHooks.get(0));
        }
        return rootAgent.name();
    }
}

