/*
 * Decompiled with CFR 0.152.
 */
package org.apache.camel.component.langchain4j.tools;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import org.apache.camel.Endpoint;
import org.apache.camel.Exchange;
import org.apache.camel.InvalidPayloadException;
import org.apache.camel.component.langchain4j.tools.LangChain4jToolsEndpoint;
import org.apache.camel.component.langchain4j.tools.TagsHelper;
import org.apache.camel.component.langchain4j.tools.spec.CamelToolExecutorCache;
import org.apache.camel.component.langchain4j.tools.spec.CamelToolSpecification;
import org.apache.camel.support.DefaultProducer;
import org.apache.camel.util.ObjectHelper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Camel producer that runs a LangChain4j tool-calling conversation.
 * <p>
 * The incoming body must be a {@link List} of {@link ChatMessage}s. The producer collects the Camel tools registered
 * under one of the endpoint's tags, then repeatedly calls the configured {@link ChatLanguageModel}. Whenever the model
 * requests tool executions, the processor of each matching tool's consumer is run on the current exchange and its
 * body is fed back to the model as a tool result. When the model stops requesting tools, its text answer becomes the
 * outgoing message body.
 */
public class LangChain4jToolsProducer extends DefaultProducer {
    private static final Logger LOG = LoggerFactory.getLogger(LangChain4jToolsProducer.class);

    /** Header set to {@code true} when the conversation ends without any tool having been invoked. */
    private static final String NO_TOOLS_CALLED_HEADER = "LangChain4jToolsNoToolsCalled";

    private final LangChain4jToolsEndpoint endpoint;
    private final ObjectMapper objectMapper = new ObjectMapper();
    /** Resolved from the endpoint configuration in {@link #doStart()}. */
    private ChatLanguageModel chatLanguageModel;

    public LangChain4jToolsProducer(LangChain4jToolsEndpoint endpoint) {
        super(endpoint);
        this.endpoint = endpoint;
    }

    @Override
    public void process(Exchange exchange) throws Exception {
        processMultipleMessages(exchange);
    }

    /**
     * Runs the conversation on a copy of the body's message list and stores the model's final answer as the body.
     * <p>
     * The conversation appends AI and tool-result messages as it goes. Working on a copy lets callers pass immutable
     * lists (e.g. {@code List.of(...)}), and leaves their list untouched.
     *
     * @throws InvalidPayloadException if the body cannot be obtained as a {@link List}
     */
    private void processMultipleMessages(Exchange exchange) throws InvalidPayloadException {
        @SuppressWarnings("unchecked") // the component contract is a List<ChatMessage> body
        List<ChatMessage> messages = new ArrayList<>(exchange.getIn().getMandatoryBody(List.class));
        String response = toolsChat(messages, exchange);
        populateResponse(response, exchange);
    }

    @Override
    protected void doStart() throws Exception {
        super.doStart();
        this.chatLanguageModel = endpoint.getConfiguration().getChatModel();
        ObjectHelper.notNull(chatLanguageModel, "chatLanguageModel");
    }

    private void populateResponse(String response, Exchange exchange) {
        exchange.getMessage().setBody(response);
    }

    /** Returns {@code true} when the tool-cache entry's tag equals one of the endpoint's tags. */
    private boolean isMatch(String[] tags, Map.Entry<String, Set<CamelToolSpecification>> entry) {
        return Arrays.asList(tags).contains(entry.getKey());
    }

    /**
     * Drives the conversation until the model stops requesting tools.
     *
     * @param  chatMessages conversation so far; AI tool requests and tool results are appended to it
     * @param  exchange     exchange on which tools are executed and headers are set
     * @return              the model's final text answer; {@code null} when no tool matches the endpoint tags or the
     *                      final response has no content
     */
    private String toolsChat(List<ChatMessage> chatMessages, Exchange exchange) {
        ToolPair toolPair = computeCandidates(CamelToolExecutorCache.getInstance(), exchange);
        if (toolPair == null) {
            return null;
        }
        // NOTE(review): the number of rounds is unbounded; a model that keeps requesting tools keeps this loop running.
        for (int iteration = 0;; iteration++) {
            LOG.debug("Starting iteration {}", iteration);
            Response<AiMessage> response = chatWithLLM(chatMessages, toolPair);
            if (iteration == 0 && !hasToolExecutionRequests(response)) {
                // Only the first round can end without any tool having run: every later round follows an
                // invokeTools() call, so flagging it there would wrongly report that no tools were called.
                exchange.getMessage().setHeader(NO_TOOLS_CALLED_HEADER, Boolean.TRUE);
            }
            if (isDoneExecuting(response)) {
                return extractAiResponse(response);
            }
            invokeTools(chatMessages, exchange, response, toolPair);
            LOG.debug("Finished iteration {}", iteration);
        }
    }

    /**
     * Decides whether the conversation is over.
     * <p>
     * It is over when the model returned no content, requested no tools, or finished with {@link FinishReason#STOP}.
     */
    private boolean isDoneExecuting(Response<AiMessage> response) {
        if (!hasToolExecutionRequests(response)) {
            LOG.info("Finished executing tools because of there are no more execution requests");
            return true;
        }
        if (response.finishReason() == FinishReason.STOP) {
            // Only log "finished" when we actually stop; other finish reasons continue the loop.
            LOG.info("Finished executing tools because of {}", response.finishReason());
            return true;
        }
        return false;
    }

    /** Null-safe check for tool execution requests in a model response. */
    private static boolean hasToolExecutionRequests(Response<AiMessage> response) {
        AiMessage message = response.content();
        return message != null && message.hasToolExecutionRequests();
    }

    /**
     * Executes every tool requested in {@code response} on {@code exchange} and appends one tool-result message per
     * request.
     * <p>
     * A tool failure is recorded on the exchange via {@link Exchange#setException(Throwable)}. The current body is
     * still reported back to the model as that tool's result.
     *
     * @throws NoSuchElementException if the model requested a tool that is not offered for this endpoint
     */
    private void invokeTools(
            List<ChatMessage> chatMessages, Exchange exchange, Response<AiMessage> response, ToolPair toolPair) {
        List<ToolExecutionRequest> requests = response.content().toolExecutionRequests();
        int position = 0;
        for (ToolExecutionRequest request : requests) {
            position++;
            String toolName = request.name();
            LOG.info("Invoking tool {} ({}) of {}", position, toolName, requests.size());
            CamelToolSpecification tool = findCallableTool(toolPair, toolName);
            try {
                mapArgumentsToHeaders(request.arguments(), exchange);
                tool.getConsumer().getProcessor().process(exchange);
            } catch (Exception e) {
                LOG.warn("Invoking tool {} failed", toolName, e);
                exchange.setException(e);
            }
            chatMessages.add(
                    new ToolExecutionResultMessage(request.id(), toolName, exchange.getIn().getBody(String.class)));
        }
    }

    /**
     * Copies each top-level field of the tool-call arguments JSON into a message header of the same name.
     * <p>
     * Header values are Jackson {@link JsonNode}s. Null or blank arguments are treated as "no arguments" rather than
     * being handed to Jackson, which would fail to parse them and keep the tool from running.
     *
     * @throws IOException if the arguments are not valid JSON
     */
    private void mapArgumentsToHeaders(String arguments, Exchange exchange) throws IOException {
        if (arguments == null || arguments.isBlank()) {
            return;
        }
        JsonNode json = objectMapper.readValue(arguments, JsonNode.class);
        json.fieldNames().forEachRemaining(name -> exchange.getMessage().setHeader(name, json.get(name)));
    }

    /** Looks up the callable tool the model asked for, failing with a descriptive message if it is unknown. */
    private static CamelToolSpecification findCallableTool(ToolPair toolPair, String toolName) {
        return toolPair.callableTools().stream()
                .filter(candidate -> candidate.getToolSpecification().name().equals(toolName))
                .findFirst()
                .orElseThrow(() -> new NoSuchElementException(
                        "The model requested tool '" + toolName
                                                              + "', which is not among the tools available to this endpoint"));
    }

    /**
     * Sends the conversation and the candidate tool specifications to the model.
     * <p>
     * When the response requests tools, the AI message is appended to {@code chatMessages}, so the following tool
     * results refer to it.
     */
    private Response<AiMessage> chatWithLLM(List<ChatMessage> chatMessages, ToolPair toolPair) {
        Response<AiMessage> response = chatLanguageModel.generate(chatMessages, toolPair.toolSpecifications());
        if (hasToolExecutionRequests(response)) {
            chatMessages.add(response.content());
        }
        return response;
    }

    /**
     * Collects the tools whose tag matches one of the endpoint's tags.
     *
     * @return the matching specifications and callable tools; {@code null} if none match, in which case the
     *         no-tools-called header is set to {@code true}
     */
    private ToolPair computeCandidates(CamelToolExecutorCache toolCache, Exchange exchange) {
        List<ToolSpecification> toolSpecifications = new ArrayList<>();
        List<CamelToolSpecification> callableTools = new ArrayList<>();
        Map<String, Set<CamelToolSpecification>> tools = toolCache.getTools();
        String[] tags = TagsHelper.splitTags(endpoint.getTags());
        for (Map.Entry<String, Set<CamelToolSpecification>> entry : tools.entrySet()) {
            if (!isMatch(tags, entry)) {
                continue;
            }
            // Single pass keeps both lists in the same order and index-aligned.
            for (CamelToolSpecification tool : entry.getValue()) {
                callableTools.add(tool);
                toolSpecifications.add(tool.getToolSpecification());
            }
        }
        if (toolSpecifications.isEmpty()) {
            exchange.getMessage().setHeader(NO_TOOLS_CALLED_HEADER, Boolean.TRUE);
            return null;
        }
        return new ToolPair(toolSpecifications, callableTools);
    }

    /** Returns the text of the model's message, or {@code null} if the response has no content. */
    private String extractAiResponse(Response<AiMessage> response) {
        AiMessage message = response.content();
        return message == null ? null : message.text();
    }

    /**
     * Tools offered to the model for one exchange: the specifications sent to the LLM and the Camel tools that
     * execute them.
     */
    private record ToolPair(List<ToolSpecification> toolSpecifications, List<CamelToolSpecification> callableTools) {
    }
}

