/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.hook;

import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.action.InterruptableAction;
import com.alibaba.cloud.ai.graph.action.InterruptionMetadata;
import com.alibaba.cloud.ai.graph.agent.hook.HookPosition;
import com.alibaba.cloud.ai.graph.agent.hook.HookPositions;
import com.alibaba.cloud.ai.graph.agent.hook.JumpTo;
import com.alibaba.cloud.ai.graph.agent.hook.ModelHook;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;

/**
 * Model hook, registered at {@link HookPosition#BEFORE_MODEL}, that turns externally supplied
 * "interruption feedback" into chat messages.
 *
 * <p>Feedback is read from the per-thread state map returned by
 * {@code getAgent().getThreadState(threadId)} under {@link #INTERRUPTION_FEEDBACK_KEY}:
 * <ul>
 *   <li>{@link #interrupt(String, OverAllState, RunnableConfig)} requests an interruption when the
 *       stored value is an <em>empty</em> {@link List}.</li>
 *   <li>{@link #apply(OverAllState, RunnableConfig)} removes the stored value and, when it can be
 *       converted to at least one {@link Message}, returns it as an update for the
 *       {@code "messages"} key.</li>
 * </ul>
 */
@HookPositions(value={HookPosition.BEFORE_MODEL})
public class InterruptionHook extends ModelHook implements AsyncNodeActionWithConfig, InterruptableAction {
    private static final Logger log = LoggerFactory.getLogger(InterruptionHook.class);

    /** Key in the agent thread state under which pending feedback is stored. */
    public static final String INTERRUPTION_FEEDBACK_KEY = "INTERRUPTION_FEEDBACK";

    /** Name reported by {@link #getName()}. */
    public static final String INTERRUPTION_NODE_NAME = "INTERRUPTION";

    /** Thread id used when the {@link RunnableConfig} does not carry one. */
    private static final String DEFAULT_THREAD_ID = "$default";

    /** State key holding the conversation message list. */
    private static final String MESSAGES_KEY = "messages";

    /**
     * Creates the hook. The builder currently carries no configuration, so the argument is unused.
     *
     * @param builder the builder that requested construction
     */
    private InterruptionHook(Builder builder) {
    }

    /**
     * Returns a new {@link Builder} for this hook.
     *
     * @return a fresh builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Consumes pending interruption feedback for the current thread and converts it into a
     * {@code "messages"} update.
     *
     * <p>An empty update map is returned when there is no thread state, no feedback, no usable
     * feedback message, or when the last message in {@code state} is an {@link AssistantMessage}
     * with tool calls.
     *
     * @param state  the current graph state; its {@code "messages"} value is inspected
     * @param config the run configuration providing the thread id
     * @return a completed future holding either an empty map or a map with a single
     *         {@code "messages"} entry containing the feedback messages
     */
    public CompletableFuture<Map<String, Object>> apply(OverAllState state, RunnableConfig config) {
        String threadId = config.threadId().orElse(DEFAULT_THREAD_ID);
        Map<String, Object> agentThreadState = this.getAgent().getThreadState(threadId);
        if (agentThreadState == null) {
            log.debug("No agent thread state found for threadId {}, continuing normal execution.", threadId);
            return CompletableFuture.completedFuture(Map.of());
        }

        // The feedback is removed (not just read) so that it is consumed at most once.
        Object feedback = agentThreadState.remove(INTERRUPTION_FEEDBACK_KEY);
        if (feedback == null) {
            log.debug("No interruption feedback found in state, continue reasoning with no updates.");
            return CompletableFuture.completedFuture(Map.of());
        }

        List<Message> feedbackMessages = toFeedbackMessages(feedback);
        if (feedbackMessages.isEmpty()) {
            // The reason has already been logged by toFeedbackMessages.
            return CompletableFuture.completedFuture(Map.of());
        }

        // NOTE(review): at this point the feedback has already been removed from the thread state,
        // so on this path it is discarded rather than deferred - confirm this is intended.
        if (lastMessageHasToolCalls(state)) {
            log.info("Last message is an AssistantMessage with tool calls, not adding interruption feedback to messages list.");
            return CompletableFuture.completedFuture(Map.of());
        }

        List<Message> newMessages = new ArrayList<>(feedbackMessages);
        Map<String, Object> updates = new HashMap<>();
        updates.put(MESSAGES_KEY, newMessages);
        log.debug("Added {} interruption feedback message(s) to messages list and removed INTERRUPTION_FEEDBACK_KEY from state.", feedbackMessages.size());
        return CompletableFuture.completedFuture(updates);
    }

    /**
     * Normalizes a raw feedback value into a list of messages, logging a warning for anything
     * that cannot be used.
     *
     * @param feedback non-null value taken from the thread state; supported shapes are a
     *                 {@link List} (only {@link Message} elements are kept), a {@link UserMessage},
     *                 or a {@link String} (wrapped in a {@link UserMessage})
     * @return the usable messages; empty when the value is unsupported or contains none
     */
    private static List<Message> toFeedbackMessages(Object feedback) {
        List<Message> collected = new ArrayList<>();
        if (feedback instanceof List) {
            for (Object item : (List<?>) feedback) {
                if (item instanceof Message) {
                    collected.add((Message) item);
                } else {
                    log.warn("Feedback list contains non-Message item, ignoring. Type: {}", item != null ? item.getClass().getName() : "null");
                }
            }
            if (collected.isEmpty()) {
                log.warn("Feedback list is empty or contains no valid Message instances, stop and wait for more input.");
            }
        } else if (feedback instanceof UserMessage) {
            collected.add((UserMessage) feedback);
        } else if (feedback instanceof String) {
            collected.add(new UserMessage((String) feedback));
        } else {
            log.warn("Interruption feedback is neither List<Message>, UserMessage nor String, stop and wait for more input. Type: {}", feedback.getClass().getName());
        }
        return collected;
    }

    /**
     * Tells whether the last entry of the {@code "messages"} list in {@code state} is an
     * {@link AssistantMessage} that has tool calls.
     *
     * @param state the graph state to inspect
     * @return {@code true} only when the list is non-empty and its last element is an assistant
     *         message with tool calls
     */
    private static boolean lastMessageHasToolCalls(OverAllState state) {
        List<?> messages = (List<?>) state.value(MESSAGES_KEY).orElse(List.of());
        if (messages.isEmpty()) {
            return false;
        }
        Object last = messages.get(messages.size() - 1);
        return last instanceof AssistantMessage && ((AssistantMessage) last).hasToolCalls();
    }

    /**
     * Decides whether execution should be interrupted before this node runs.
     *
     * <p>An interruption is requested only when the thread state holds an <em>empty</em>
     * {@link List} under {@link #INTERRUPTION_FEEDBACK_KEY}; every other situation (no thread
     * state, no value, non-list value, non-empty list) lets execution continue.
     *
     * @param nodeId the id of the node being evaluated, recorded in the returned metadata
     * @param state  the current graph state, recorded in the returned metadata
     * @param config the run configuration providing the thread id
     * @return metadata flagged with {@code interruption_requested=true} when an interruption is
     *         requested, otherwise {@link Optional#empty()}
     */
    public Optional<InterruptionMetadata> interrupt(String nodeId, OverAllState state, RunnableConfig config) {
        String threadId = config.threadId().orElse(DEFAULT_THREAD_ID);
        Map<String, Object> agentThreadState = this.getAgent().getThreadState(threadId);
        if (agentThreadState == null) {
            log.debug("No agent thread state found for threadId {}, continuing normal execution.", threadId);
            return Optional.empty();
        }

        // Read-only lookup: the value is left in place for apply() to consume.
        Object feedbackValue = agentThreadState.get(INTERRUPTION_FEEDBACK_KEY);
        if (feedbackValue == null) {
            log.debug("No INTERRUPTION_FEEDBACK_KEY in state, continuing normal execution.");
            return Optional.empty();
        }
        if (!(feedbackValue instanceof List)) {
            log.debug("INTERRUPTION_FEEDBACK_KEY is not a list, continuing normal execution.");
            return Optional.empty();
        }
        if (!((List<?>) feedbackValue).isEmpty()) {
            log.debug("INTERRUPTION_FEEDBACK_KEY has non-empty list, continuing normal execution.");
            return Optional.empty();
        }

        InterruptionMetadata interruptionMetadata = ((InterruptionMetadata.Builder) InterruptionMetadata.builder(nodeId, state).addMetadata("interruption_requested", Boolean.TRUE)).build();
        log.debug("INTERRUPTION_FEEDBACK_KEY is empty list, returning InterruptionMetadata.");
        return Optional.of(interruptionMetadata);
    }

    /**
     * Returns the fixed name of this hook.
     *
     * @return {@link #INTERRUPTION_NODE_NAME}
     */
    @Override
    public String getName() {
        return INTERRUPTION_NODE_NAME;
    }

    /**
     * Returns the jump targets of this hook.
     *
     * @return an empty list; this hook declares no jump targets
     */
    @Override
    public List<JumpTo> canJumpTo() {
        return List.of();
    }

    /**
     * Returns the key strategies contributed by this hook.
     *
     * @return a single-entry map assigning a {@link ReplaceStrategy} to
     *         {@link #INTERRUPTION_FEEDBACK_KEY}
     */
    @Override
    public Map<String, KeyStrategy> getKeyStrategys() {
        return Map.of(INTERRUPTION_FEEDBACK_KEY, new ReplaceStrategy());
    }

    /** Builder for {@link InterruptionHook}; it currently exposes no options. */
    public static class Builder {
        /**
         * Creates the hook.
         *
         * @return a new {@link InterruptionHook}
         */
        public InterruptionHook build() {
            return new InterruptionHook(this);
        }
    }
}

