/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.client.transport;

import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.spec.DefaultMcpTransportSession;
import io.modelcontextprotocol.spec.DefaultMcpTransportStream;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpTransportException;
import io.modelcontextprotocol.spec.McpTransportSession;
import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
import io.modelcontextprotocol.spec.McpTransportStream;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;

public class WebClientStreamableHttpTransport
implements McpClientTransport {
    /** Placeholder used in logs and error messages when the transport session has no id. */
    private static final String MISSING_SESSION_ID = "[missing_session_id]";
    private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class);
    /** Value sent in the {@code MCP-Protocol-Version} header and listed last by {@code protocolVersions()}. */
    private static final String MCP_PROTOCOL_VERSION = "2025-06-18";
    /** Endpoint path; same value the {@code Builder} starts with when no endpoint is configured. */
    private static final String DEFAULT_ENDPOINT = "/mcp";
    /** SSE event type whose data is parsed as a JSON-RPC message; other event types are skipped. */
    private static final String MESSAGE_EVENT_TYPE = "message";
    /** Type token used to decode the response body as a stream of string-valued server-sent events. */
    private static final ParameterizedTypeReference<ServerSentEvent<String>> PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference<ServerSentEvent<String>>(){};
    private final McpJsonMapper jsonMapper;
    private final WebClient webClient;
    private final String endpoint;
    /** When true, {@code connect} immediately opens the GET stream instead of waiting. */
    private final boolean openConnectionOnStartup;
    /** Passed to every {@code DefaultMcpTransportStream} this transport creates. */
    private final boolean resumableStreams;
    /** Current session; swapped for a fresh one on close and when the server reports the session as unknown. */
    private final AtomicReference<DefaultMcpTransportSession> activeSession = new AtomicReference<>();
    /** Inbound message handler registered by {@code connect}. */
    private final AtomicReference<Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>>> handler = new AtomicReference<>();
    /** Optional callback notified of every exception routed through {@code handleException}. */
    private final AtomicReference<Consumer<Throwable>> exceptionHandler = new AtomicReference<>();

    private WebClientStreamableHttpTransport(McpJsonMapper jsonMapper, WebClient.Builder webClientBuilder, String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) {
        this.jsonMapper = jsonMapper;
        this.webClient = webClientBuilder.build();
        this.endpoint = endpoint;
        this.resumableStreams = resumableStreams;
        this.openConnectionOnStartup = openConnectionOnStartup;
        this.activeSession.set(this.createTransportSession());
    }

    /**
     * Lists the protocol versions this transport reports.
     *
     * @return an immutable list ending with the version sent in the {@code MCP-Protocol-Version} header
     */
    public List<String> protocolVersions() {
        List<String> supported = List.of("2024-11-05", "2025-03-26", MCP_PROTOCOL_VERSION);
        return supported;
    }

    /**
     * Starts a new {@link Builder} for this transport.
     *
     * @param webClientBuilder builder used to create the underlying {@link WebClient}; must not be null
     * @return a builder with default settings
     */
    public static Builder builder(WebClient.Builder webClientBuilder) {
        Builder transportBuilder = new Builder(webClientBuilder);
        return transportBuilder;
    }

    /**
     * Registers the inbound message handler and, if configured, eagerly opens the GET stream.
     * The work happens on subscription, not when this method is called.
     *
     * @param handler function applied to every JSON-RPC message received from the server
     * @return a Mono that completes once the handler is stored and, when eager opening
     *         is enabled, once {@code reconnect(null)} has emitted
     */
    public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        return Mono.deferContextual(ctx -> {
            this.handler.set(handler);
            if (!this.openConnectionOnStartup) {
                return Mono.empty();
            }
            logger.debug("Eagerly opening connection on startup");
            return this.reconnect(null).then();
        });
    }

    /**
     * Builds a new transport session whose close callback issues an HTTP DELETE for the
     * session id on the configured endpoint.
     *
     * @return a fresh session
     */
    private DefaultMcpTransportSession createTransportSession() {
        // The callback is passed directly so its type is taken from the constructor parameter
        // instead of a raw Function<String, Publisher> local.
        return new DefaultMcpTransportSession(sessionId -> {
            if (sessionId == null) {
                // No id to terminate on the server.
                return Mono.empty();
            }
            return this.webClient.delete()
                .uri(this.endpoint)
                .header("Mcp-Session-Id", sessionId)
                .header("MCP-Protocol-Version", MCP_PROTOCOL_VERSION)
                .retrieve()
                .toBodilessEntity()
                // Closing is best-effort: log the failure and complete normally.
                .onErrorComplete(e -> {
                    logger.warn("Got error when closing transport", e);
                    return true;
                })
                .then();
        });
    }

    /**
     * Stores the callback that is invoked with every exception this transport handles.
     * A later call replaces the previously stored callback.
     *
     * @param handler consumer to notify
     */
    public void setExceptionHandler(Consumer<Throwable> handler) {
        logger.debug("Exception handler registered");
        exceptionHandler.set(handler);
    }

    /**
     * Central error hook for the GET and POST pipelines.
     * <p>
     * When the error says the server no longer knows the session, the active session is
     * replaced with a fresh one and the old one is closed. The registered exception
     * handler, if any, is then notified.
     *
     * @param t the error raised by a request pipeline
     */
    private void handleException(Throwable t) {
        logger.debug("Handling exception for session {}", WebClientStreamableHttpTransport.sessionIdOrPlaceholder(this.activeSession.get()), t);
        if (t instanceof McpTransportSessionNotFoundException) {
            McpTransportSession<?> invalidSession = this.activeSession.getAndSet(this.createTransportSession());
            // Log the id itself (or the placeholder), not the Optional wrapper.
            logger.warn("Server does not recognize session {}. Invalidating.", WebClientStreamableHttpTransport.sessionIdOrPlaceholder(invalidSession));
            invalidSession.close();
        }
        Consumer<Throwable> handler = this.exceptionHandler.get();
        if (handler != null) {
            handler.accept(t);
        }
    }

    /**
     * Swaps in a fresh session and gracefully closes the one that was active.
     * Nothing happens until the returned Mono is subscribed.
     *
     * @return a Mono mirroring the previous session's graceful close, or an empty Mono
     *         if no session was set
     */
    public Mono<Void> closeGracefully() {
        return Mono.defer(() -> {
            logger.debug("Graceful close triggered");
            DefaultMcpTransportSession previous = this.activeSession.getAndSet(this.createTransportSession());
            if (previous == null) {
                return Mono.empty();
            }
            return previous.closeGracefully();
        });
    }

    /**
     * Opens (or re-opens) the server-to-client SSE stream with an HTTP GET on the endpoint.
     * <p>
     * The request is subscribed eagerly inside the deferred block; the resulting
     * {@link Disposable} is registered with the session that was active at that moment
     * and emitted by the returned Mono. Errors in the pipeline are passed to
     * {@code handleException} and then swallowed, so they never reach the returned Mono.
     *
     * @param stream the stream being resumed, or {@code null} to open a brand-new one;
     *               when present, its last event id is sent as {@code Last-Event-ID}
     * @return a Mono emitting the subscription handle of the GET request
     */
    private Mono<Disposable> reconnect(McpTransportStream<Disposable> stream) {
        return Mono.deferContextual(ctx -> {
            if (stream != null) {
                logger.debug("Reconnecting stream {} with lastId {}", (Object)stream.streamId(), (Object)stream.lastId());
            } else {
                logger.debug("Reconnecting with no prior stream");
            }
            // Holder that lets doFinally find the Disposable, which only exists after subscribe() returns.
            AtomicReference<Disposable> disposableRef = new AtomicReference<Disposable>();
            // Capture the session once so header lookup and connection bookkeeping use the same instance.
            McpTransportSession transportSession = (McpTransportSession)this.activeSession.get();
            Disposable connection = this.webClient.get().uri(this.endpoint, new Object[0]).accept(new MediaType[]{MediaType.TEXT_EVENT_STREAM}).header("MCP-Protocol-Version", new String[]{MCP_PROTOCOL_VERSION}).headers(httpHeaders -> {
                transportSession.sessionId().ifPresent(id -> httpHeaders.add("Mcp-Session-Id", id));
                if (stream != null) {
                    stream.lastId().ifPresent(id -> httpHeaders.add("Last-Event-ID", id));
                }
            }).exchangeToFlux(response -> {
                // 2xx with text/event-stream: consume the SSE body.
                if (WebClientStreamableHttpTransport.isEventStream(response)) {
                    logger.debug("Established SSE stream via GET");
                    return this.eventStream(stream, (ClientResponse)response);
                }
                // 405: treated as "no GET stream available", not as an error.
                if (WebClientStreamableHttpTransport.isNotAllowed(response)) {
                    logger.debug("The server does not support SSE streams, using request-response mode.");
                    return Flux.empty();
                }
                // 404: session-not-found when a session id was sent, otherwise a generic error.
                if (WebClientStreamableHttpTransport.isNotFound(response)) {
                    if (transportSession.sessionId().isPresent()) {
                        String sessionIdRepresentation = WebClientStreamableHttpTransport.sessionIdOrPlaceholder(transportSession);
                        return WebClientStreamableHttpTransport.mcpSessionNotFoundError(sessionIdRepresentation);
                    }
                    return this.extractError((ClientResponse)response, MISSING_SESSION_ID);
                }
                // Any other response becomes an error that is logged at info level.
                return response.createError().doOnError(e -> logger.info("Opening an SSE stream failed. This can be safely ignored.", e)).flux();
            }).flatMap(jsonrpcMessage -> (Publisher)this.handler.get().apply((Mono<McpSchema.JSONRPCMessage>)Mono.just((Object)jsonrpcMessage))).onErrorComplete(t -> {
                // Report the error, then complete the pipeline instead of propagating it.
                this.handleException((Throwable)t);
                return true;
            }).doFinally(s -> {
                // Unregister the connection from the session once the pipeline terminates.
                Disposable ref = disposableRef.getAndSet(null);
                if (ref != null) {
                    transportSession.removeConnection((Object)ref);
                }
            }).contextWrite(ctx).subscribe();
            // NOTE(review): if the pipeline can terminate before subscribe() returns, doFinally
            // sees null here and the connection added below is not removed by it - confirm.
            disposableRef.set(connection);
            transportSession.addConnection((Object)connection);
            return Mono.just((Object)connection);
        });
    }

    /**
     * POSTs a JSON-RPC message to the endpoint and routes whatever the server sends back
     * to the registered handler.
     * <p>
     * The returned Mono completes as soon as a 2xx response is recognised as having no
     * body, an SSE stream or a JSON body; the body itself is consumed afterwards. Any
     * pipeline error is passed to {@code handleException} and signalled through the
     * returned Mono.
     *
     * @param message the message to send
     * @return a Mono signalling acceptance of the message or the failure
     */
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
        return Mono.create(sink -> {
            logger.debug("Sending message {}", (Object)message);
            // Holder that lets doFinally find the Disposable, which only exists after subscribe() returns.
            AtomicReference<Disposable> disposableRef = new AtomicReference<Disposable>();
            // Capture the session once so all callbacks below work on the same instance.
            McpTransportSession transportSession = (McpTransportSession)this.activeSession.get();
            Disposable connection = ((WebClient.RequestBodySpec)((WebClient.RequestBodySpec)((WebClient.RequestBodySpec)((WebClient.RequestBodySpec)this.webClient.post().uri(this.endpoint, new Object[0])).accept(new MediaType[]{MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM})).header("MCP-Protocol-Version", new String[]{MCP_PROTOCOL_VERSION})).headers(httpHeaders -> transportSession.sessionId().ifPresent(id -> httpHeaders.add("Mcp-Session-Id", id)))).bodyValue((Object)message).exchangeToFlux(response -> {
                // When markInitialized reports true for the returned session id, also open the GET stream.
                if (transportSession.markInitialized(response.headers().asHttpHeaders().getFirst("Mcp-Session-Id"))) {
                    this.reconnect(null).contextWrite(sink.contextView()).subscribe();
                }
                String sessionRepresentation = WebClientStreamableHttpTransport.sessionIdOrPlaceholder(transportSession);
                if (response.statusCode().is2xxSuccessful()) {
                    Optional contentType = response.headers().contentType();
                    long contentLength = response.headers().contentLength().orElse(-1L);
                    // No content type or an explicitly empty body: nothing to read.
                    if (contentType.isEmpty() || contentLength == 0L) {
                        logger.trace("Message was successfully sent via POST for session {}", (Object)sessionRepresentation);
                        sink.success();
                        return Flux.empty();
                    }
                    MediaType mediaType = (MediaType)contentType.get();
                    if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) {
                        logger.debug("Established SSE stream via POST");
                        sink.success();
                        return this.newEventStream((ClientResponse)response, sessionRepresentation);
                    }
                    if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
                        logger.trace("Received response to POST for session {}", (Object)sessionRepresentation);
                        sink.success();
                        return this.directResponseFlux(message, (ClientResponse)response);
                    }
                    // Report the media type itself rather than the Optional wrapper.
                    logger.warn("Unknown media type {} returned for POST in session {}", (Object)mediaType, (Object)sessionRepresentation);
                    return Flux.error((Throwable)new RuntimeException("Unknown media type returned: " + String.valueOf(mediaType)));
                }
                // 404 with a known session id means the server dropped the session.
                if (WebClientStreamableHttpTransport.isNotFound(response) && !sessionRepresentation.equals(MISSING_SESSION_ID)) {
                    return WebClientStreamableHttpTransport.mcpSessionNotFoundError(sessionRepresentation);
                }
                return this.extractError((ClientResponse)response, sessionRepresentation);
            }).flatMap(jsonRpcMessage -> (Publisher)this.handler.get().apply((Mono<McpSchema.JSONRPCMessage>)Mono.just((Object)jsonRpcMessage))).onErrorComplete(t -> {
                // Report the error, signal it to the caller, then complete the inner pipeline.
                this.handleException((Throwable)t);
                sink.error(t);
                return true;
            }).doFinally(s -> {
                // Unregister the connection from the session once the pipeline terminates.
                Disposable ref = disposableRef.getAndSet(null);
                if (ref != null) {
                    transportSession.removeConnection((Object)ref);
                }
            }).contextWrite(sink.contextView()).subscribe();
            disposableRef.set(connection);
            transportSession.addConnection((Object)connection);
        });
    }

    /**
     * Logs that the server does not know the session and produces a failing Flux.
     *
     * @param sessionRepresentation session id (or placeholder) used in the log and the exception
     * @return a Flux that terminates with {@link McpTransportSessionNotFoundException}
     */
    private static Flux<McpSchema.JSONRPCMessage> mcpSessionNotFoundError(String sessionRepresentation) {
        McpTransportSessionNotFoundException notFound = new McpTransportSessionNotFoundException(sessionRepresentation);
        logger.warn("Session {} was not found on the MCP server", sessionRepresentation);
        return Flux.error(notFound);
    }

    /**
     * Turns an unsuccessful HTTP response into a failing Flux carrying the most specific
     * error available.
     * <p>
     * The response body is parsed as a JSON-RPC response. If it carries an error object
     * that becomes an {@link McpError}; otherwise an {@link McpTransportException} is used.
     * A 400 status is additionally mapped to {@link McpTransportSessionNotFoundException}
     * when a session id is known, or wrapped in a descriptive
     * {@link McpTransportException} when it is not.
     *
     * @param response the non-successful response
     * @param sessionRepresentation session id, or the missing-session placeholder
     * @return a Flux that always terminates with an error
     */
    private Flux<McpSchema.JSONRPCMessage> extractError(ClientResponse response, String sessionRepresentation) {
        return response.<McpSchema.JSONRPCMessage>createError().onErrorResume(e -> {
            WebClientResponseException responseException = (WebClientResponseException)e;
            byte[] body = responseException.getResponseBodyAsByteArray();
            // Declared as Exception because it may hold either an McpError or an McpTransportException.
            Exception toPropagate;
            try {
                McpSchema.JSONRPCResponse jsonRpcResponse = (McpSchema.JSONRPCResponse)this.jsonMapper.readValue(body, McpSchema.JSONRPCResponse.class);
                // A body such as the JSON literal null may deserialize to null; treat it as "no error object".
                McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = jsonRpcResponse == null ? null : jsonRpcResponse.error();
                if (jsonRpcError != null) {
                    toPropagate = new McpError(jsonRpcError);
                } else {
                    toPropagate = new McpTransportException("Can't parse the jsonResponse " + String.valueOf(jsonRpcResponse));
                }
            }
            catch (IOException ex) {
                toPropagate = new McpTransportException("Sending request failed, " + e.getMessage(), e);
                // Log the body as text; the raw byte[] is not readable in a log line.
                logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), responseException.getResponseBodyAsString());
            }
            if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) {
                if (!sessionRepresentation.equals(MISSING_SESSION_ID)) {
                    // A 400 for a request that carried a session id is reported as session-not-found.
                    return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate));
                }
                return Mono.error(new McpTransportException("Received 400 BAD REQUEST for session " + sessionRepresentation + ". " + toPropagate.getMessage(), toPropagate));
            }
            return Mono.error(toPropagate);
        }).flux();
    }

    /**
     * Feeds the SSE body of a response into a transport stream and exposes the resulting
     * JSON-RPC messages.
     *
     * @param stream the stream to continue, or {@code null} to create a new one
     * @param response response whose body is a server-sent event stream
     * @return the messages produced by the transport stream
     */
    private Flux<McpSchema.JSONRPCMessage> eventStream(McpTransportStream<Disposable> stream, ClientResponse response) {
        // Typed as the interface: the value may be the caller's stream or a newly created default one.
        McpTransportStream<Disposable> sessionStream = stream;
        if (sessionStream == null) {
            sessionStream = new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect);
        }
        logger.debug("Connected stream {}", sessionStream.streamId());
        // Pair each event's optional id with the messages parsed from it.
        Flux<Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>>> idWithMessages = response.bodyToFlux(PARAMETERIZED_TYPE_REF).map(this::parse);
        return Flux.from(sessionStream.consumeSseStream(idWithMessages));
    }

    /**
     * @param response the response to inspect
     * @return {@code true} if the status code is 404
     */
    private static boolean isNotFound(ClientResponse response) {
        HttpStatusCode status = response.statusCode();
        return status.isSameCodeAs(HttpStatus.NOT_FOUND);
    }

    /**
     * @param response the response to inspect
     * @return {@code true} if the status code is 405
     */
    private static boolean isNotAllowed(ClientResponse response) {
        HttpStatusCode status = response.statusCode();
        return status.isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED);
    }

    /**
     * @param response the response to inspect
     * @return {@code true} if the status is 2xx and the declared content type is
     *         compatible with {@code text/event-stream}
     */
    private static boolean isEventStream(ClientResponse response) {
        if (!response.statusCode().is2xxSuccessful()) {
            return false;
        }
        return response.headers()
            .contentType()
            .filter(type -> type.isCompatibleWith(MediaType.TEXT_EVENT_STREAM))
            .isPresent();
    }

    /**
     * @param transportSession the session to describe
     * @return the session id if one is present, otherwise the missing-session placeholder
     */
    private static String sessionIdOrPlaceholder(McpTransportSession<?> transportSession) {
        Optional<String> id = transportSession.sessionId();
        if (id.isPresent()) {
            return id.get();
        }
        return MISSING_SESSION_ID;
    }

    /**
     * Reads a plain JSON response body and converts it into at most one JSON-RPC message.
     * <p>
     * If the message that was sent is a notification, any body is logged as
     * non-compliant and nothing is emitted.
     *
     * @param sentMessage the message the response belongs to
     * @param response the response with a JSON body
     * @return a Flux emitting the deserialized message, nothing, or an
     *         {@link McpTransportException} when the body cannot be read as JSON-RPC
     */
    private Flux<McpSchema.JSONRPCMessage> directResponseFlux(McpSchema.JSONRPCMessage sentMessage, ClientResponse response) {
        return response.bodyToMono(String.class).<McpSchema.JSONRPCMessage>handle((responseMessage, s) -> {
            if (sentMessage instanceof McpSchema.JSONRPCNotification) {
                logger.warn("Notification: {} received non-compliant response: {}", sentMessage, Utils.hasText(responseMessage) ? responseMessage : "[empty]");
                s.complete();
                return;
            }
            try {
                // Emit the single message directly; no intermediate list is needed.
                s.next(McpSchema.deserializeJsonRpcMessage(this.jsonMapper, responseMessage));
            }
            catch (IOException e) {
                s.error(new McpTransportException(e));
            }
        }).flux();
    }

    /**
     * Creates a new transport stream for an SSE response to a POST and starts consuming it.
     *
     * @param response response whose body is a server-sent event stream
     * @param sessionRepresentation session id (or placeholder), used for logging only
     * @return the messages produced by the new stream
     */
    private Flux<McpSchema.JSONRPCMessage> newEventStream(ClientResponse response, String sessionRepresentation) {
        McpTransportStream<Disposable> sessionStream = new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect);
        logger.trace("Sent POST and opened a stream ({}) for session {}", sessionStream.streamId(), sessionRepresentation);
        return this.eventStream(sessionStream, response);
    }

    /**
     * Converts an already-parsed value into the requested type using the configured JSON mapper.
     *
     * @param data the value to convert
     * @param typeRef the target type
     * @param <T> the target type
     * @return the converted value
     */
    public <T> T unmarshalFrom(Object data, TypeRef<T> typeRef) {
        T converted = (T)this.jsonMapper.convertValue(data, typeRef);
        return converted;
    }

    /**
     * Converts one server-sent event into its optional id and the JSON-RPC messages it carries.
     * <p>
     * Only events whose type equals {@code "message"} are parsed. Any other event is
     * logged at debug level and yields an empty id with no messages.
     *
     * @param event the received server-sent event
     * @return the event id (if any) paired with zero or one parsed message
     * @throws McpTransportException if the event data cannot be read as a JSON-RPC message
     */
    private Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>> parse(ServerSentEvent<String> event) {
        if (!MESSAGE_EVENT_TYPE.equals(event.event())) {
            logger.debug("Received SSE event with type: {}", event);
            return Tuples.of(Optional.empty(), List.of());
        }
        String payload = event.data();
        try {
            McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, payload);
            return Tuples.of(Optional.ofNullable(event.id()), List.of(message));
        }
        catch (IOException ioException) {
            throw new McpTransportException("Error parsing JSON-RPC message: " + payload, ioException);
        }
    }

    /**
     * Fluent builder for {@link WebClientStreamableHttpTransport}.
     * <p>
     * Defaults: endpoint {@code "/mcp"}, resumable streams enabled, no eager connection
     * on startup, and the default JSON mapper when none is supplied.
     */
    public static class Builder {
        private McpJsonMapper jsonMapper;
        private WebClient.Builder webClientBuilder;
        private String endpoint = "/mcp";
        private boolean resumableStreams = true;
        private boolean openConnectionOnStartup = false;

        /**
         * @param webClientBuilder builder for the underlying {@link WebClient}; must not be null
         */
        private Builder(WebClient.Builder webClientBuilder) {
            Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
            this.webClientBuilder = webClientBuilder;
        }

        /**
         * Sets the JSON mapper used for message (de)serialization.
         *
         * @param jsonMapper the mapper; must not be null
         * @return this builder
         */
        public Builder jsonMapper(McpJsonMapper jsonMapper) {
            Assert.notNull(jsonMapper, "JsonMapper must not be null");
            this.jsonMapper = jsonMapper;
            return this;
        }

        /**
         * Replaces the {@link WebClient.Builder} supplied at creation time.
         *
         * @param webClientBuilder the builder; must not be null
         * @return this builder
         */
        public Builder webClientBuilder(WebClient.Builder webClientBuilder) {
            Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
            this.webClientBuilder = webClientBuilder;
            return this;
        }

        /**
         * Sets the endpoint URI requests are sent to.
         *
         * @param endpoint the endpoint; must contain text
         * @return this builder
         */
        public Builder endpoint(String endpoint) {
            Assert.hasText(endpoint, "endpoint must be a non-empty String");
            this.endpoint = endpoint;
            return this;
        }

        /**
         * Sets the flag forwarded to every transport stream the transport creates.
         *
         * @param resumableStreams the flag value
         * @return this builder
         */
        public Builder resumableStreams(boolean resumableStreams) {
            this.resumableStreams = resumableStreams;
            return this;
        }

        /**
         * Sets whether {@code connect} opens the GET stream immediately.
         *
         * @param openConnectionOnStartup {@code true} to connect eagerly
         * @return this builder
         */
        public Builder openConnectionOnStartup(boolean openConnectionOnStartup) {
            this.openConnectionOnStartup = openConnectionOnStartup;
            return this;
        }

        /**
         * Creates the transport from the current settings.
         *
         * @return a new transport instance
         */
        public WebClientStreamableHttpTransport build() {
            // Fall back to the default mapper only when none was configured.
            McpJsonMapper effectiveMapper = this.jsonMapper;
            if (effectiveMapper == null) {
                effectiveMapper = McpJsonMapper.getDefault();
            }
            return new WebClientStreamableHttpTransport(effectiveMapper, this.webClientBuilder, this.endpoint, this.resumableStreams, this.openConnectionOnStartup);
        }
    }
}

