/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.mcp.client;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.mcp.client.McpClient;
import dev.langchain4j.mcp.client.ResourceRef;
import dev.langchain4j.mcp.client.ResourceResponse;
import dev.langchain4j.mcp.client.ResourceTemplateRef;
import dev.langchain4j.mcp.client.ResourcesHelper;
import dev.langchain4j.mcp.client.ToolExecutionHelper;
import dev.langchain4j.mcp.client.ToolSpecificationHelper;
import dev.langchain4j.mcp.client.logging.DefaultMcpLogMessageHandler;
import dev.langchain4j.mcp.client.logging.McpLogMessageHandler;
import dev.langchain4j.mcp.client.protocol.CancellationNotification;
import dev.langchain4j.mcp.client.protocol.InitializeParams;
import dev.langchain4j.mcp.client.protocol.McpCallToolRequest;
import dev.langchain4j.mcp.client.protocol.McpInitializeRequest;
import dev.langchain4j.mcp.client.protocol.McpListResourceTemplatesRequest;
import dev.langchain4j.mcp.client.protocol.McpListResourcesRequest;
import dev.langchain4j.mcp.client.protocol.McpListToolsRequest;
import dev.langchain4j.mcp.client.protocol.McpReadResourceRequest;
import dev.langchain4j.mcp.client.transport.McpOperationHandler;
import dev.langchain4j.mcp.client.transport.McpTransport;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link McpClient} implementation that sends operations to an MCP server through a
 * {@link McpTransport}.
 *
 * <p>Constructing an instance starts the transport and blocks until the server has answered the
 * {@code initialize} request. Every operation gets its id from a shared, incrementing counter. Once a
 * call finishes (with a result, with a failure, or by timing out), its entry is removed from
 * {@code pendingOperations}. That map is shared with the {@link McpOperationHandler} created by the
 * constructor.
 *
 * <p>Timeouts configured as a zero {@link Duration} are treated as effectively unbounded
 * ({@link Integer#MAX_VALUE} milliseconds). The resource and resource-template listings are fetched once
 * on first use and then cached for the lifetime of the client.
 */
public class DefaultMcpClient implements McpClient {
    private static final Logger log = LoggerFactory.getLogger(DefaultMcpClient.class);
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private final AtomicLong idGenerator = new AtomicLong(0L);
    private final McpTransport transport;
    private final String clientName;
    private final String clientVersion;
    private final String protocolVersion;
    private final Duration toolExecutionTimeout;
    private final Duration resourcesTimeout;
    /** Synthetic tool response (text content = the timeout error message) returned when a tool call times out. */
    private final JsonNode timeoutResult;
    private final String toolExecutionTimeoutErrorMessage;
    /** In-flight operations keyed by id; shared with {@link #messageHandler} (presumably it completes them). */
    private final Map<Long, CompletableFuture<JsonNode>> pendingOperations = new ConcurrentHashMap<>();
    private final McpOperationHandler messageHandler;
    private final McpLogMessageHandler logHandler;
    /** Cached resource list; {@code null} until first fetched. */
    private final AtomicReference<List<ResourceRef>> resourceRefs = new AtomicReference<>();
    /** Cached resource-template list; {@code null} until first fetched. */
    private final AtomicReference<List<ResourceTemplateRef>> resourceTemplateRefs = new AtomicReference<>();

    /**
     * Creates the client, applies defaults for unset builder values, then starts the transport and
     * performs the MCP initialize handshake.
     *
     * @param builder configuration; {@code builder.transport} is required
     * @throws RuntimeException if the handshake fails (see {@link #initialize()})
     */
    public DefaultMcpClient(Builder builder) {
        this.transport = ValidationUtils.ensureNotNull(builder.transport, "transport");
        this.clientName = Utils.getOrDefault(builder.clientName, "langchain4j");
        this.clientVersion = Utils.getOrDefault(builder.clientVersion, "1.0");
        this.protocolVersion = Utils.getOrDefault(builder.protocolVersion, "2024-11-05");
        this.toolExecutionTimeout = Utils.getOrDefault(builder.toolExecutionTimeout, Duration.ofSeconds(60L));
        this.resourcesTimeout = Utils.getOrDefault(builder.resourcesTimeout, Duration.ofSeconds(60L));
        this.logHandler = Utils.getOrDefault(builder.logHandler, new DefaultMcpLogMessageHandler());
        this.toolExecutionTimeoutErrorMessage = Utils.getOrDefault(
                builder.toolExecutionTimeoutErrorMessage, "There was a timeout executing the tool");
        this.timeoutResult = createTimeoutResult(this.toolExecutionTimeoutErrorMessage);
        this.messageHandler = new McpOperationHandler(
                this.pendingOperations, this.transport, this.logHandler::handleLogMessage);
        this.initialize();
    }

    /**
     * Starts the transport and sends the MCP {@code initialize} request. It then waits, without a timeout,
     * for the reply and logs the returned {@code result} (the server capabilities) at debug level.
     *
     * @throws RuntimeException wrapping any failure raised while waiting for the reply
     */
    private void initialize() {
        this.transport.start(this.messageHandler);
        long operationId = this.idGenerator.getAndIncrement();
        McpInitializeRequest request = new McpInitializeRequest(operationId);
        request.setParams(this.createInitializeParams());
        try {
            JsonNode capabilities = this.transport.initialize(request).get();
            log.debug("MCP server capabilities: {}", capabilities.get("result"));
        } catch (Exception e) {
            throw propagate(e);
        } finally {
            this.pendingOperations.remove(operationId);
        }
    }

    /**
     * Builds the {@code initialize} parameters: protocol version, client name/version, and a
     * {@code roots} capability with {@code listChanged = false}.
     */
    private InitializeParams createInitializeParams() {
        InitializeParams params = new InitializeParams();
        params.setProtocolVersion(this.protocolVersion);

        InitializeParams.ClientInfo clientInfo = new InitializeParams.ClientInfo();
        clientInfo.setName(this.clientName);
        clientInfo.setVersion(this.clientVersion);
        params.setClientInfo(clientInfo);

        InitializeParams.Capabilities.Roots roots = new InitializeParams.Capabilities.Roots();
        roots.setListChanged(false);
        InitializeParams.Capabilities capabilities = new InitializeParams.Capabilities();
        capabilities.setRoots(roots);
        params.setCapabilities(capabilities);
        return params;
    }

    /**
     * Asks the server for its tools and converts them to {@link ToolSpecification}s.
     * Waits for the response without a timeout.
     *
     * @return the tool specifications parsed from {@code result.tools}
     * @throws RuntimeException if waiting is interrupted or the operation fails
     */
    @Override
    public List<ToolSpecification> listTools() {
        long operationId = this.idGenerator.getAndIncrement();
        McpListToolsRequest operation = new McpListToolsRequest(operationId);
        CompletableFuture<JsonNode> resultFuture = this.transport.executeOperationWithResponse(operation);
        JsonNode result;
        try {
            result = resultFuture.get();
        } catch (InterruptedException | ExecutionException e) {
            throw propagate(e);
        } finally {
            this.pendingOperations.remove(operationId);
        }
        return ToolSpecificationHelper.toolSpecificationListFromMcpResponse(
                (ArrayNode) result.get("result").get("tools"));
    }

    /**
     * Calls a tool on the server and returns its textual result.
     *
     * <p>A {@code null} or blank arguments string is sent as an empty JSON object. If no response arrives
     * within the tool execution timeout, a cancellation notification is sent for the operation and
     * the configured timeout error message is returned instead of throwing.
     *
     * @param executionRequest tool name and JSON-encoded arguments
     * @return the extracted tool result, or the timeout error message on timeout
     * @throws RuntimeException if the arguments are not a JSON object, waiting is interrupted, or the
     *     operation fails
     */
    @Override
    public String executeTool(ToolExecutionRequest executionRequest) {
        ObjectNode arguments = parseArguments(executionRequest.arguments());
        long operationId = this.idGenerator.getAndIncrement();
        McpCallToolRequest operation = new McpCallToolRequest(operationId, executionRequest.name(), arguments);
        long timeoutMillis = toTimeoutMillis(this.toolExecutionTimeout);
        JsonNode result;
        try {
            CompletableFuture<JsonNode> resultFuture = this.transport.executeOperationWithResponse(operation);
            result = resultFuture.get(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (TimeoutException timeout) {
            this.transport.executeOperationWithoutResponse(new CancellationNotification(operationId, "Timeout"));
            return ToolExecutionHelper.extractResult(this.timeoutResult);
        } catch (InterruptedException | ExecutionException e) {
            throw propagate(e);
        } finally {
            this.pendingOperations.remove(operationId);
        }
        return ToolExecutionHelper.extractResult(result);
    }

    /**
     * Returns the server's resources, fetching them on the first call and returning the cached list after.
     *
     * @throws RuntimeException if the initial fetch fails or times out
     */
    @Override
    public List<ResourceRef> listResources() {
        if (this.resourceRefs.get() == null) {
            this.obtainResourceList();
        }
        return this.resourceRefs.get();
    }

    /**
     * Reads a single resource, waiting at most the resources timeout. Results are not cached.
     *
     * @param uri the resource URI
     * @return the parsed resource contents
     * @throws RuntimeException if waiting is interrupted, the operation fails, or it times out
     */
    @Override
    public ResourceResponse readResource(String uri) {
        long operationId = this.idGenerator.getAndIncrement();
        McpReadResourceRequest operation = new McpReadResourceRequest(operationId, uri);
        long timeoutMillis = toTimeoutMillis(this.resourcesTimeout);
        try {
            CompletableFuture<JsonNode> resultFuture = this.transport.executeOperationWithResponse(operation);
            JsonNode result = resultFuture.get(timeoutMillis, TimeUnit.MILLISECONDS);
            return ResourcesHelper.parseResourceContents(result);
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            throw propagate(e);
        } finally {
            this.pendingOperations.remove(operationId);
        }
    }

    /**
     * Returns the server's resource templates, fetching them on the first call and returning the cached
     * list after.
     *
     * @throws RuntimeException if the initial fetch fails or times out
     */
    @Override
    public List<ResourceTemplateRef> listResourceTemplates() {
        if (this.resourceTemplateRefs.get() == null) {
            this.obtainResourceTemplateList();
        }
        return this.resourceTemplateRefs.get();
    }

    /**
     * Fetches and caches the resource list. The method is {@code synchronized} and re-checks the cache, so
     * concurrent first callers trigger at most one successful fetch.
     */
    private synchronized void obtainResourceList() {
        if (this.resourceRefs.get() != null) {
            return;
        }
        long operationId = this.idGenerator.getAndIncrement();
        McpListResourcesRequest operation = new McpListResourcesRequest(operationId);
        long timeoutMillis = toTimeoutMillis(this.resourcesTimeout);
        try {
            CompletableFuture<JsonNode> resultFuture = this.transport.executeOperationWithResponse(operation);
            JsonNode result = resultFuture.get(timeoutMillis, TimeUnit.MILLISECONDS);
            this.resourceRefs.set(ResourcesHelper.parseResourceRefs(result));
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            throw propagate(e);
        } finally {
            this.pendingOperations.remove(operationId);
        }
    }

    /**
     * Fetches and caches the resource-template list. It works like {@link #obtainResourceList()} and uses
     * the resources timeout, which was previously (incorrectly) the tool execution timeout.
     */
    private synchronized void obtainResourceTemplateList() {
        if (this.resourceTemplateRefs.get() != null) {
            return;
        }
        long operationId = this.idGenerator.getAndIncrement();
        McpListResourceTemplatesRequest operation = new McpListResourceTemplatesRequest(operationId);
        long timeoutMillis = toTimeoutMillis(this.resourcesTimeout);
        try {
            CompletableFuture<JsonNode> resultFuture = this.transport.executeOperationWithResponse(operation);
            JsonNode result = resultFuture.get(timeoutMillis, TimeUnit.MILLISECONDS);
            this.resourceTemplateRefs.set(ResourcesHelper.parseResourceTemplateRefs(result));
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            throw propagate(e);
        } finally {
            this.pendingOperations.remove(operationId);
        }
    }

    /** Closes the underlying transport; failures are logged as warnings rather than thrown. */
    @Override
    public void close() {
        try {
            this.transport.close();
        } catch (Exception e) {
            log.warn("Cannot close MCP transport", e);
        }
    }

    /** Converts a configured timeout to milliseconds; a zero duration means "effectively unbounded". */
    private static long toTimeoutMillis(Duration timeout) {
        long millis = timeout.toMillis();
        return millis == 0L ? Integer.MAX_VALUE : millis;
    }

    /** Parses tool arguments into a JSON object; null or blank input yields an empty object. */
    private static ObjectNode parseArguments(String rawArguments) {
        if (rawArguments == null || rawArguments.trim().isEmpty()) {
            return JsonNodeFactory.instance.objectNode();
        }
        try {
            return OBJECT_MAPPER.readValue(rawArguments, ObjectNode.class);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /** Builds {@code {"result":{"content":[{"type":"text","text":<message>}]}}} for tool timeouts. */
    private static JsonNode createTimeoutResult(String errorMessage) {
        ObjectNode node = JsonNodeFactory.instance.objectNode();
        node.putObject("result").putArray("content").addObject().put("type", "text").put("text", errorMessage);
        return node;
    }

    /**
     * Wraps a failure in an unchecked exception. If the failure is an interruption, it first restores the
     * thread's interrupt flag so callers can still observe the interruption.
     */
    private static RuntimeException propagate(Exception e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
        return new RuntimeException(e);
    }

    /**
     * Fluent builder for {@link DefaultMcpClient}. Only {@link #transport(McpTransport)} is required;
     * every other setting falls back to the default noted on its setter.
     */
    public static class Builder {
        private String toolExecutionTimeoutErrorMessage;
        private McpTransport transport;
        private String clientName;
        private String clientVersion;
        private String protocolVersion;
        private Duration toolExecutionTimeout;
        private Duration resourcesTimeout;
        private McpLogMessageHandler logHandler;

        /** Sets the transport used to talk to the server (required). */
        public Builder transport(McpTransport transport) {
            this.transport = transport;
            return this;
        }

        /** Sets the client name reported during initialization (default {@code "langchain4j"}). */
        public Builder clientName(String clientName) {
            this.clientName = clientName;
            return this;
        }

        /** Sets the client version reported during initialization (default {@code "1.0"}). */
        public Builder clientVersion(String clientVersion) {
            this.clientVersion = clientVersion;
            return this;
        }

        /** Sets the MCP protocol version requested during initialization (default {@code "2024-11-05"}). */
        public Builder protocolVersion(String protocolVersion) {
            this.protocolVersion = protocolVersion;
            return this;
        }

        /** Sets the tool call timeout (default 60 seconds; zero means effectively unbounded). */
        public Builder toolExecutionTimeout(Duration toolExecutionTimeout) {
            this.toolExecutionTimeout = toolExecutionTimeout;
            return this;
        }

        /** Sets the timeout for resource operations (default 60 seconds; zero means effectively unbounded). */
        public Builder resourcesTimeout(Duration resourcesTimeout) {
            this.resourcesTimeout = resourcesTimeout;
            return this;
        }

        /**
         * Sets the text returned by a tool call that timed out
         * (default {@code "There was a timeout executing the tool"}).
         */
        public Builder toolExecutionTimeoutErrorMessage(String toolExecutionTimeoutErrorMessage) {
            this.toolExecutionTimeoutErrorMessage = toolExecutionTimeoutErrorMessage;
            return this;
        }

        /** Sets the handler for server log messages (default {@link DefaultMcpLogMessageHandler}). */
        public Builder logHandler(McpLogMessageHandler logHandler) {
            this.logHandler = logHandler;
            return this;
        }

        /** Builds the client; this starts the transport and performs the initialize handshake. */
        public DefaultMcpClient build() {
            return new DefaultMcpClient(this);
        }
    }
}

