/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.agentic.supervisor;

import dev.langchain4j.agentic.internal.Context;
import dev.langchain4j.agentic.planner.Action;
import dev.langchain4j.agentic.planner.AgentInstance;
import dev.langchain4j.agentic.planner.ChatMemoryAccessProvider;
import dev.langchain4j.agentic.planner.InitPlanningContext;
import dev.langchain4j.agentic.planner.Planner;
import dev.langchain4j.agentic.planner.PlanningContext;
import dev.langchain4j.agentic.scope.AgenticScope;
import dev.langchain4j.agentic.scope.DefaultAgenticScope;
import dev.langchain4j.agentic.supervisor.AgentInvocation;
import dev.langchain4j.agentic.supervisor.PlannerAgent;
import dev.langchain4j.agentic.supervisor.ResponseAgent;
import dev.langchain4j.agentic.supervisor.ResponseScore;
import dev.langchain4j.agentic.supervisor.SupervisorContextStrategy;
import dev.langchain4j.agentic.supervisor.SupervisorResponseStrategy;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.memory.ChatMemoryAccess;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link Planner} that lets an LLM-backed {@link PlannerAgent} decide which sub-agent to call next.
 * <p>
 * On every step the planner agent is given the textual list of available sub-agents, the original
 * request, the last sub-agent response and an optional supervisor context (read from the
 * {@value #SUPERVISOR_CONTEXT_KEY} state of the agentic scope). It answers with an
 * {@link AgentInvocation} that either names a sub-agent to call (its arguments are written into the
 * agentic scope before the call) or is named {@code "done"} (case-insensitive). The loop is also
 * stopped after {@code maxAgentsInvocations} calls to {@link #nextAction(PlanningContext)}.
 * <p>
 * The final result is produced by the {@code output} function when one is configured, otherwise
 * according to the configured {@link SupervisorResponseStrategy}; when {@code outputKey} is set the
 * result is also written into the agentic scope under that key.
 * <p>
 * NOTE(review): instances hold mutable per-run state (loop counter, sub-agents, request) without any
 * synchronization; presumably one instance is driven by a single supervisor execution at a time — confirm.
 */
public class SupervisorPlanner
implements Planner,
ChatMemoryAccessProvider {
    private static final Logger LOG = LoggerFactory.getLogger(SupervisorPlanner.class);
    /** Agentic-scope state key holding optional constraints/policies/preferences for the planner. */
    public static final String SUPERVISOR_CONTEXT_KEY = "supervisorContext";
    /** Text prepended to the supervisor context before it is passed to the planner agent. */
    public static final String SUPERVISOR_CONTEXT_PREFIX = "Use the following supervisor context to better understand constraints, policies or preferences when creating the plan ";
    private final ChatModel chatModel;
    private final ChatMemoryProvider chatMemoryProvider;
    private final int maxAgentsInvocations;
    /** Number of {@link #nextAction} calls in the current run; reset by {@link #init}. */
    private int loopCount = 0;
    /** Only created when the response strategy is {@link SupervisorResponseStrategy#SCORED}. */
    private ResponseAgent responseAgent;
    private final SupervisorContextStrategy contextStrategy;
    private final SupervisorResponseStrategy responseStrategy;
    private final Function<AgenticScope, String> requestGenerator;
    private final String outputKey;
    private final Function<AgenticScope, Object> output;
    private Map<String, AgentInstance> agents;
    private String agentsList;
    private String request;

    /**
     * Creates a supervisor planner.
     *
     * @param chatModel            model backing the planner agent, the summarizer and the scoring agent
     * @param chatMemoryProvider   optional memory provider for the planner agent; when {@code null} a
     *                             window memory is chosen according to {@code contextStrategy}
     * @param maxAgentsInvocations maximum number of {@link #nextAction} calls before the run is closed
     * @param contextStrategy      how the planner agent keeps track of previous interactions
     * @param responseStrategy     how the final response is chosen when the planner reports "done"
     * @param requestGenerator     optional function building the request from the scope; when
     *                             {@code null} the {@code "request"} state is used
     * @param outputKey            optional scope key under which the final result is stored
     * @param output               optional function computing the final result from the scope
     */
    public SupervisorPlanner(ChatModel chatModel, ChatMemoryProvider chatMemoryProvider, int maxAgentsInvocations, SupervisorContextStrategy contextStrategy, SupervisorResponseStrategy responseStrategy, Function<AgenticScope, String> requestGenerator, String outputKey, Function<AgenticScope, Object> output) {
        this.chatModel = chatModel;
        this.chatMemoryProvider = chatMemoryProvider;
        this.maxAgentsInvocations = maxAgentsInvocations;
        this.contextStrategy = contextStrategy;
        this.responseStrategy = responseStrategy;
        this.requestGenerator = requestGenerator;
        this.outputKey = outputKey;
        this.output = output;
    }

    /**
     * Prepares a run: indexes the sub-agents by id, renders their description cards, resolves the
     * request and, for the {@code SCORED} strategy, builds the response-scoring agent.
     */
    @Override
    public void init(InitPlanningContext initPlanningContext) {
        // Reset the invocation budget together with the rest of the per-run state.
        this.loopCount = 0;
        var subagents = initPlanningContext.subagents();
        this.agents = subagents.stream().collect(Collectors.toMap(AgentInstance::agentId, Function.identity()));
        this.agentsList = subagents.stream().map(SupervisorPlanner::toCard).collect(Collectors.joining(", "));
        var scope = initPlanningContext.agenticScope();
        this.request = this.requestGenerator != null
                ? this.requestGenerator.apply(scope)
                : scope.readState("request", "");
        if (this.responseStrategy == SupervisorResponseStrategy.SCORED) {
            this.responseAgent = AiServices.builder(ResponseAgent.class).chatModel(this.chatModel).build();
        }
    }

    /**
     * Returns the next action: a sub-agent call chosen by the planner agent, or a "done" action once
     * the planner says so or the invocation budget is exhausted.
     */
    @Override
    public Action nextAction(PlanningContext planningContext) {
        String lastResponse = lastResponse(planningContext);
        if (this.loopCount++ >= this.maxAgentsInvocations) {
            return this.doneAction(planningContext.agenticScope(), lastResponse, null);
        }
        return this.nextSubagent(planningContext.agenticScope(), lastResponse);
    }

    /** Text of the previous sub-agent output, or {@code ""} when there is none or it is {@code null}. */
    private static String lastResponse(PlanningContext planningContext) {
        var previous = planningContext.previousAgentInvocation();
        if (previous == null) {
            return "";
        }
        Object previousOutput = previous.output();
        return previousOutput == null ? "" : previousOutput.toString();
    }

    /** Renders a sub-agent as {@code {'id', 'description', [arg: Type, ...]}}, hiding {@code @MemoryId}. */
    private static String toCard(AgentInstance agent) {
        List<String> agentArguments = agent.arguments().stream()
                .filter(a -> !a.name().equals("@MemoryId"))
                .map(a -> a.name() + ": " + a.rawType().getSimpleName())
                .toList();
        return "{'" + agent.agentId() + "', '" + agent.description() + "', " + agentArguments + "}";
    }

    /**
     * Asks the planner agent for the next invocation, writes its arguments into the scope and returns
     * the corresponding call, or closes the run when the planner answers "done".
     *
     * @throws IllegalStateException if the planner returns nothing or names an unknown sub-agent
     */
    private Action nextSubagent(AgenticScope agenticScope, String lastResponse) {
        String supervisorContext = agenticScope.hasState(SUPERVISOR_CONTEXT_KEY)
                ? SUPERVISOR_CONTEXT_PREFIX + "'" + agenticScope.readState(SUPERVISOR_CONTEXT_KEY, "") + "'."
                : "";
        AgentInvocation agentInvocation = this.planner(agenticScope)
                .plan(agenticScope.memoryId(), this.agentsList, this.request, lastResponse, supervisorContext);
        LOG.info("Agent Invocation: {}", agentInvocation);
        if (agentInvocation == null) {
            throw new IllegalStateException("The planner agent did not return any agent invocation");
        }
        String agentName = agentInvocation.getAgentName();
        // Constant-first comparison: a null name falls through to the "No agent found" error below.
        if ("done".equalsIgnoreCase(agentName)) {
            return this.doneAction(agenticScope, lastResponse, agentInvocation);
        }
        AgentInstance agent = this.agents.get(agentName);
        if (agent == null) {
            throw new IllegalStateException("No agent found with name: " + agentName);
        }
        var arguments = agentInvocation.getArguments();
        // result() already tolerates missing arguments; do the same here instead of failing with an NPE.
        if (arguments != null) {
            arguments.forEach(agenticScope::writeState);
        }
        return this.call(agent);
    }

    /** Computes the final result, stores it under {@code outputKey} when set, and closes the run. */
    private Action doneAction(AgenticScope agenticScope, String lastResponse, AgentInvocation done) {
        Object result = this.result(agenticScope, this.request, lastResponse, done);
        if (this.outputKey != null) {
            agenticScope.writeState(this.outputKey, result);
        }
        return this.done(result);
    }

    /**
     * Planner agent for the given scope, obtained through {@code DefaultAgenticScope#getOrCreateAgent}
     * (presumably built once per scope and then reused — confirm).
     * NOTE(review): assumes every scope passed here is a {@link DefaultAgenticScope}.
     */
    private PlannerAgent planner(AgenticScope agenticScope) {
        return ((DefaultAgenticScope)agenticScope).getOrCreateAgent(this.agentId(), this::buildPlannerAgent);
    }

    /**
     * Chooses the final result: the {@code output} function when configured; otherwise the last
     * response when the "done" invocation carries no {@code response} argument; otherwise the
     * configured response strategy decides between the last response and the planner's response.
     */
    private Object result(AgenticScope agenticScope, String request, String lastResponse, AgentInvocation done) {
        if (this.output != null) {
            return this.output.apply(agenticScope);
        }
        if (done == null || done.getArguments() == null || done.getArguments().get("response") == null) {
            return lastResponse;
        }
        String doneResponse = done.getArguments().get("response").toString();
        return switch (this.responseStrategy) {
            case LAST -> lastResponse;
            case SUMMARY -> doneResponse;
            case SCORED -> {
                ResponseScore score = this.responseAgent.scoreResponses(request, lastResponse, doneResponse);
                LOG.info("Response scores: {}", score);
                // The planner's response wins only when it scores strictly higher; ties keep the last response.
                yield score.getScore2() > score.getScore1() ? doneResponse : lastResponse;
            }
        };
    }

    /** Builds the planner agent on top of {@code chatModel} with memory/context handling configured. */
    private PlannerAgent buildPlannerAgent(AgenticScope agenticScope) {
        AiServices<PlannerAgent> builder = AiServices.builder(PlannerAgent.class).chatModel(this.chatModel);
        this.configureMemoryAndContext(agenticScope, builder);
        return builder.build();
    }

    /**
     * Configures chat memory and request transformation for the planner agent.
     * <p>
     * With a user-supplied memory provider, a {@code Context.Summarizer} is added unless the strategy
     * is {@code CHAT_MEMORY}. Without one, a message-window memory is used: 20 messages for
     * {@code CHAT_MEMORY}, 2 messages plus summarizer for {@code SUMMARIZATION}, 20 messages plus
     * summarizer for {@code CHAT_MEMORY_AND_SUMMARIZATION}.
     */
    private void configureMemoryAndContext(AgenticScope agenticScope, AiServices<PlannerAgent> builder) {
        if (this.chatMemoryProvider != null) {
            builder.chatMemoryProvider(this.chatMemoryProvider);
            if (this.contextStrategy != SupervisorContextStrategy.CHAT_MEMORY) {
                builder.chatRequestTransformer(new Context.Summarizer(agenticScope, this.chatModel));
            }
            return;
        }
        switch (this.contextStrategy) {
            case CHAT_MEMORY -> builder.chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(20));
            case SUMMARIZATION -> builder
                    .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(2))
                    .chatRequestTransformer(new Context.Summarizer(agenticScope, this.chatModel));
            case CHAT_MEMORY_AND_SUMMARIZATION -> builder
                    .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(20))
                    .chatRequestTransformer(new Context.Summarizer(agenticScope, this.chatModel));
        }
    }

    /** Id under which the planner agent is registered in the scope (note: "null@Supervisor" if no outputKey). */
    private String agentId() {
        return this.outputKey + "@Supervisor";
    }

    /** Exposes the planner agent's chat memory for the given scope. */
    @Override
    public ChatMemoryAccess chatMemoryAccess(AgenticScope agenticScope) {
        return this.planner(agenticScope);
    }
}

