/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.mcp.client.transport.websocket;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.mcp.client.protocol.McpClientMessage;
import dev.langchain4j.mcp.client.protocol.McpInitializationNotification;
import dev.langchain4j.mcp.client.protocol.McpInitializeRequest;
import dev.langchain4j.mcp.client.transport.McpOperationHandler;
import dev.langchain4j.mcp.client.transport.McpTransport;
import dev.langchain4j.mcp.client.transport.websocket.WebSocketMcpListener;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.net.ConnectException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.WebSocket;
import java.time.Duration;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import javax.net.ssl.SSLContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link McpTransport} that talks to an MCP server over a single {@link WebSocket} connection.
 *
 * <p>The connection is opened by {@link #start(McpOperationHandler)} and re-opened on demand when a
 * send finds no connection, or finds one whose opening failed. Once
 * {@link #initialize(McpInitializeRequest)} has been called, every newly opened connection first
 * replays the stored initialize request followed by an {@link McpInitializationNotification}
 * before it is handed out for regular traffic.
 *
 * <p>Connection creation and lookup are serialized on this instance's monitor; individual
 * {@code sendText} calls are serialized on the {@link WebSocket} they target.
 */
public class WebSocketMcpTransport implements McpTransport {

    private static final Logger DEFAULT_TRAFFIC_LOG = LoggerFactory.getLogger("MCP");
    private static final Logger LOG = LoggerFactory.getLogger(WebSocketMcpTransport.class);

    /**
     * Shared Jackson mapper used to serialize outgoing messages. Package-visible; presumably also
     * used by {@code WebSocketMcpListener} — verify before narrowing its visibility.
     */
    static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private final String url;
    private final Supplier<Map<String, String>> headersSupplier;
    private final boolean logResponses;
    private final boolean logRequests;
    private final Logger trafficLog;
    private final Duration connectTimeout;
    private final Executor executor;

    /**
     * The current, possibly still pending, connection. {@code null} when no connection has been
     * started yet or when the listener's reset callback has cleared it.
     */
    private final AtomicReference<CompletableFuture<WebSocket>> webSocketRef = new AtomicReference<>();

    private volatile McpOperationHandler operationHandler;
    /** Request replayed on every new connection; {@code null} until {@link #initialize} is called. */
    private volatile McpInitializeRequest initializeRequest;
    private volatile SSLContext sslContext;
    private volatile HttpClient httpClient;
    private volatile boolean closed = false;
    private volatile Runnable actionOnFailure;

    /**
     * Creates a transport from the given builder.
     *
     * <p>The URL is mandatory and is validated with {@code ValidationUtils.ensureNotNull}. Other
     * defaults: the traffic logger falls back to the {@code "MCP"} logger, the connect timeout
     * falls back to 60 seconds, and the headers fall back to an empty map.
     *
     * @param builder transport configuration
     */
    public WebSocketMcpTransport(Builder builder) {
        this.url = ValidationUtils.ensureNotNull(builder.url, "Missing server endpoint URL");
        this.logResponses = builder.logResponses;
        this.logRequests = builder.logRequests;
        this.trafficLog = Utils.getOrDefault(builder.logger, DEFAULT_TRAFFIC_LOG);
        this.connectTimeout = Utils.getOrDefault(builder.timeout, Duration.ofSeconds(60L));
        this.headersSupplier = builder.headersSupplier != null ? builder.headersSupplier : Map::of;
        this.executor = builder.executor;
        // Must be assigned before createHttpClient(), which reads it; otherwise the
        // SSLContext configured on the builder would be ignored.
        this.sslContext = builder.sslContext;
        this.httpClient = createHttpClient();
    }

    /** Builds an {@link HttpClient} from the current timeout, SSL context and executor settings. */
    private HttpClient createHttpClient() {
        HttpClient.Builder clientBuilder = HttpClient.newBuilder().connectTimeout(connectTimeout);
        if (sslContext != null) {
            clientBuilder.sslContext(sslContext);
        }
        if (executor != null) {
            clientBuilder.executor(executor);
        }
        return clientBuilder.build();
    }

    /**
     * Returns an open WebSocket, blocking until one is available.
     *
     * <p>Uses the current connection if there is one; otherwise starts a new one. If that first
     * attempt fails, exactly one fresh connection attempt is made before giving up.
     *
     * @return the connected WebSocket
     * @throws RuntimeException wrapping the {@link ExecutionException} of the retry, or the
     *     {@link InterruptedException} (with the interrupt flag restored)
     */
    private synchronized WebSocket getWebSocket() {
        CompletableFuture<WebSocket> current = webSocketRef.get();
        try {
            try {
                return (current != null ? current : startWebSocket()).get();
            } catch (ExecutionException ignored) {
                // The first attempt failed; the failed future is done, so this opens a new connection.
                return startWebSocket().get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Remembers the operation handler and begins opening the connection asynchronously.
     *
     * @param operationHandler handler that in-flight requests are registered with
     */
    @Override
    public void start(McpOperationHandler operationHandler) {
        this.operationHandler = operationHandler;
        startWebSocket();
    }

    /**
     * Starts a new connection unless an attempt is already in flight.
     *
     * <p>If an initialize request has been recorded, the returned future only completes after that
     * request and the initialized notification have been re-sent on the new socket.
     *
     * @return the in-flight or newly created connection future, which is also stored in
     *     {@link #webSocketRef}
     */
    private synchronized CompletableFuture<WebSocket> startWebSocket() {
        CompletableFuture<WebSocket> current = webSocketRef.get();
        if (current != null && !current.isDone()) {
            // Share the pending attempt instead of opening a second connection.
            return current;
        }
        WebSocket.Builder wsBuilder = httpClient.newWebSocketBuilder();
        Map<String, String> headers = headersSupplier.get();
        if (headers != null) {
            headers.forEach(wsBuilder::header);
        }
        wsBuilder.connectTimeout(connectTimeout);
        URI uri = URI.create(url);
        WebSocketMcpListener listener = new WebSocketMcpListener(
                operationHandler, trafficLog, logResponses, () -> webSocketRef.set(null), actionOnFailure);
        CompletableFuture<WebSocket> connection = wsBuilder.buildAsync(uri, listener);
        // Snapshot the volatile field once so the replay uses one consistent request.
        McpInitializeRequest replay = initializeRequest;
        if (replay != null) {
            connection = connection.thenCompose(webSocket -> replayInitialization(replay, webSocket));
        }
        webSocketRef.set(connection);
        return connection;
    }

    /** Re-sends the initialize request and the initialized notification on a fresh socket, then yields that socket. */
    private CompletableFuture<WebSocket> replayInitialization(McpInitializeRequest request, WebSocket webSocket) {
        Optional<WebSocket> target = Optional.of(webSocket);
        return execute(request, request.getId(), target)
                .thenCompose(response -> execute(new McpInitializationNotification(), null, target))
                .thenApply(ignored -> webSocket);
    }

    /**
     * Performs the MCP initialize handshake and records the request for replay on reconnects.
     *
     * @param operation the initialize request
     * @return a future with the server's initialize response; it completes after the initialized
     *     notification has also been sent
     */
    @Override
    public CompletableFuture<JsonNode> initialize(McpInitializeRequest operation) {
        this.initializeRequest = operation;
        return execute(operation, operation.getId(), Optional.empty())
                .thenCompose(response -> execute(new McpInitializationNotification(), null, Optional.empty())
                        .thenApply(ignored -> response));
    }

    /**
     * Sends a request and returns a future for its response.
     *
     * @param request the request to send
     * @return a future for the server's response; it completes exceptionally if sending fails
     */
    @Override
    public CompletableFuture<JsonNode> executeOperationWithResponse(McpClientMessage request) {
        return execute(request, request.getId(), Optional.empty());
    }

    /**
     * Sends a message without waiting for a response. Send failures are logged at debug level.
     *
     * @param request the message to send
     */
    @Override
    public void executeOperationWithoutResponse(McpClientMessage request) {
        execute(request, null, Optional.empty()).whenComplete((ignored, error) -> {
            if (error != null) {
                LOG.debug("Failed to send MCP message", error);
            }
        });
    }

    /** No-op: this transport performs no active health check. */
    @Override
    public void checkHealth() {}

    /**
     * Records the failure callback. It is handed to the listener of every connection opened after
     * this call. How the listener uses it is defined in {@code WebSocketMcpListener}.
     *
     * @param actionOnFailure callback to pass to future listeners
     */
    @Override
    public void onFailure(Runnable actionOnFailure) {
        this.actionOnFailure = actionOnFailure;
    }

    /**
     * Closes the transport.
     *
     * <p>This marks the transport closed, so subsequent sends fail fast with
     * {@link IllegalStateException}. If a connection is already established, it sends a
     * normal-closure frame. It then closes the {@link HttpClient} where the runtime supports it.
     * A connection attempt that is still pending is not touched.
     */
    @Override
    public void close() throws IOException {
        closed = true;
        CompletableFuture<WebSocket> future = webSocketRef.get();
        if (future != null && future.isDone()) {
            sendCloseFrame(future);
        }
        closeHttpClient();
    }

    /** Sends a normal-closure frame on a completed connection future. A failed connect is ignored quietly. */
    private void sendCloseFrame(CompletableFuture<WebSocket> future) {
        try {
            WebSocket webSocket = future.get();
            webSocket.sendClose(WebSocket.NORMAL_CLOSURE, "Client closing")
                    .thenRun(() -> LOG.info("WebSocket connection closed"));
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            // A connection that never opened has nothing to close; anything else is worth reporting.
            if (!(e.getCause() instanceof ConnectException)) {
                LOG.warn("Failed to close WebSocket connection", e);
            }
        }
    }

    /**
     * Best-effort {@code HttpClient.close()}; the method only exists on Java 21+, hence reflection.
     * It is looked up on the public {@link HttpClient} type so the call does not depend on the
     * accessibility of the runtime implementation class.
     */
    private void closeHttpClient() {
        try {
            HttpClient.class.getMethod("close").invoke(httpClient);
        } catch (NoSuchMethodException ignored) {
            // Pre-Java-21 runtime: HttpClient has no close() method to call.
        } catch (IllegalAccessException | InvocationTargetException e) {
            LOG.debug("Failed to close HttpClient", e);
        }
    }

    /**
     * Serializes and sends one message.
     *
     * @param message the message to send
     * @param id request id to register with the operation handler, or {@code null} for a message
     *     that expects no response
     * @param webSocket socket to use; when empty, {@link #getWebSocket()} supplies (and may open) one
     * @return for {@code id == null}, a future that completes with {@code null} once the frame is
     *     sent; otherwise the future registered with the operation handler. Either way, it
     *     completes exceptionally if the transport is closed, serialization fails, no socket can
     *     be obtained, or the send itself fails.
     */
    private CompletableFuture<JsonNode> execute(McpClientMessage message, Long id, Optional<WebSocket> webSocket) {
        CompletableFuture<JsonNode> future = new CompletableFuture<>();
        if (closed) {
            future.completeExceptionally(new IllegalStateException("Transport is closed"));
            return future;
        }
        if (id != null) {
            operationHandler.startOperation(id, future);
        }
        try {
            String messageJson = OBJECT_MAPPER.writeValueAsString(message);
            WebSocket wsToUse = webSocket.orElseGet(this::getWebSocket);
            if (logRequests) {
                trafficLog.info("> " + messageJson);
            }
            // NOTE(review): this lock only serializes the sendText call, not completion of the send.
            // Per the WebSocket contract, a send issued while another is still pending fails with
            // IllegalStateException; such failures now reach the caller's future below.
            synchronized (wsToUse) {
                wsToUse.sendText(messageJson, true).whenComplete((ignored, error) -> {
                    if (error != null) {
                        future.completeExceptionally(error);
                    } else if (id == null) {
                        future.complete(null);
                    }
                });
            }
        } catch (Exception e) {
            future.completeExceptionally(e);
        }
        return future;
    }

    /**
     * Replaces the SSL context and rebuilds the {@link HttpClient}.
     *
     * <p>An already established connection keeps using the previous client; only connections opened
     * afterwards use the new context.
     *
     * @param sslContext the new SSL context; must not be {@code null}
     */
    public void reloadSslContext(SSLContext sslContext) {
        ValidationUtils.ensureNotNull(sslContext, "sslContext");
        this.sslContext = sslContext;
        this.httpClient = createHttpClient();
    }

    /** @return a new, empty {@link Builder} */
    public static Builder builder() {
        return new Builder();
    }

    /** Fluent builder for {@link WebSocketMcpTransport}. */
    public static class Builder {
        private boolean logResponses;
        private boolean logRequests;
        private String url;
        private Logger logger;
        private Executor executor;
        private Duration timeout;
        private SSLContext sslContext;
        private Supplier<Map<String, String>> headersSupplier;

        /** Whether the listener should log incoming traffic. */
        public Builder logResponses(boolean logResponses) {
            this.logResponses = logResponses;
            return this;
        }

        /** Whether outgoing messages are logged (prefixed with {@code "> "}) to the traffic logger. */
        public Builder logRequests(boolean logRequests) {
            this.logRequests = logRequests;
            return this;
        }

        /** WebSocket endpoint URL of the MCP server (required). */
        public Builder url(String url) {
            this.url = url;
            return this;
        }

        /** Logger for request/response traffic; defaults to the logger named {@code "MCP"}. */
        public Builder logger(Logger logger) {
            this.logger = logger;
            return this;
        }

        /** Executor handed to the underlying {@link HttpClient}; the client's default is used when unset. */
        public Builder executor(Executor executor) {
            this.executor = executor;
            return this;
        }

        /** Connect timeout for the HTTP client and the WebSocket handshake; defaults to 60 seconds. */
        public Builder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        /** SSL context used by the underlying {@link HttpClient}. */
        public Builder sslContext(SSLContext sslContext) {
            this.sslContext = sslContext;
            return this;
        }

        /** Supplies headers for each connection handshake; it is queried on every (re)connect. */
        public Builder headersSupplier(Supplier<Map<String, String>> headersSupplier) {
            this.headersSupplier = headersSupplier;
            return this;
        }

        /** @return a new transport configured from this builder */
        public WebSocketMcpTransport build() {
            return new WebSocketMcpTransport(this);
        }
    }
}

