/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.client.transport;

import io.modelcontextprotocol.json.McpJsonDefaults;
import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.core.publisher.SynchronousSink;
import reactor.core.scheduler.Schedulers;
import reactor.util.retry.Retry;

/**
 * MCP client transport over HTTP + Server-Sent Events, built on Spring WebFlux's {@link WebClient}.
 *
 * <p>Inbound traffic arrives on a long-lived SSE stream opened with GET on {@code sseEndpoint}. The
 * server first sends an {@code endpoint} event whose data is the URI that outbound JSON-RPC messages
 * must be POSTed to. After that it sends {@code message} events carrying JSON-RPC payloads.
 */
public class WebFluxSseClientTransport implements McpClientTransport {
    private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class);
    private static final String MCP_PROTOCOL_VERSION = "2024-11-05";
    private static final String MESSAGE_EVENT_TYPE = "message";
    private static final String ENDPOINT_EVENT_TYPE = "endpoint";
    private static final String DEFAULT_SSE_ENDPOINT = "/sse";
    private static final ParameterizedTypeReference<ServerSentEvent<String>> SSE_TYPE =
            new ParameterizedTypeReference<ServerSentEvent<String>>() {};

    private final WebClient webClient;
    protected McpJsonMapper jsonMapper;
    /** Subscription to the inbound SSE pipeline; written by connect(), disposed by closeGracefully(). */
    private volatile Disposable inboundSubscription;
    private volatile boolean isClosing = false;
    /** Resolved with the message-endpoint URI once the server sends the {@code endpoint} event. */
    protected final Sinks.One<String> messageEndpointSink = Sinks.one();
    private final String sseEndpoint;

    /**
     * Decides whether the SSE stream is re-subscribed after a failure. Failures are never retried
     * while closing. {@link IOException}s are retried; anything else is fatal.
     * NOTE(review): retries are immediate (no backoff), so a persistently failing IO error will
     * reconnect in a tight loop. Consider a backoff policy.
     */
    private final BiConsumer<Retry.RetrySignal, SynchronousSink<Object>> inboundRetryHandler = (retrySpec, sink) -> {
        Throwable failure = retrySpec.failure();
        if (this.isClosing) {
            logger.debug("SSE connection closed during shutdown");
            sink.error(failure);
        } else if (failure instanceof IOException) {
            logger.debug("Retrying SSE connection after IO error");
            sink.next(retrySpec);
        } else {
            logger.error("Fatal SSE error, not retrying: {}", failure.getMessage());
            sink.error(failure);
        }
    };

    /**
     * Creates a transport that connects to the default {@code /sse} endpoint.
     *
     * @param webClientBuilder builder used to create the underlying {@link WebClient}; must not be null
     * @param jsonMapper mapper used for JSON-RPC (de)serialization; must not be null
     */
    public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper) {
        this(webClientBuilder, jsonMapper, DEFAULT_SSE_ENDPOINT);
    }

    /**
     * Creates a transport that connects to the given SSE endpoint.
     *
     * @param webClientBuilder builder used to create the underlying {@link WebClient}; must not be null
     * @param jsonMapper mapper used for JSON-RPC (de)serialization; must not be null
     * @param sseEndpoint path or URI of the SSE endpoint; must contain text
     */
    public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper, String sseEndpoint) {
        Assert.notNull(jsonMapper, "jsonMapper must not be null");
        Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
        Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty");
        this.jsonMapper = jsonMapper;
        this.webClient = webClientBuilder.build();
        this.sseEndpoint = sseEndpoint;
    }

    /** Returns the single MCP protocol version this transport speaks. */
    @Override
    public List<String> protocolVersions() {
        return List.of(MCP_PROTOCOL_VERSION);
    }

    /**
     * Opens the SSE stream and routes each decoded JSON-RPC message through {@code handler}.
     *
     * @param handler transformation applied to each inbound message
     * @return a Mono that completes once the server has sent the {@code endpoint} event. It fails
     *     if the SSE stream errors or completes before that event is received.
     */
    @Override
    public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        Flux<ServerSentEvent<String>> events = this.eventStream();
        this.inboundSubscription = events
                .concatMap(event -> Mono.just(event)
                        .<McpSchema.JSONRPCMessage>handle(this::handleSseEvent)
                        .transform(handler))
                .subscribe(
                        ignored -> { },
                        this::onInboundError,
                        this::onInboundComplete);
        // The connection counts as established once the server announces the message endpoint.
        return this.messageEndpointSink.asMono().then();
    }

    /**
     * Translates one SSE event into zero or one JSON-RPC messages.
     * The {@code endpoint} event resolves the message endpoint and emits nothing. A {@code message}
     * event is deserialized and emitted. Any other event type is logged and ignored.
     */
    private void handleSseEvent(ServerSentEvent<String> event, SynchronousSink<McpSchema.JSONRPCMessage> sink) {
        if (ENDPOINT_EVENT_TYPE.equals(event.event())) {
            String messageEndpointUri = event.data();
            if (messageEndpointUri == null || messageEndpointUri.isBlank()) {
                // tryEmitValue(null) would complete the sink empty, silently disabling sendMessage.
                sink.error(new RuntimeException("Received SSE endpoint event without an endpoint URI"));
                return;
            }
            if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) {
                sink.complete();
            } else {
                sink.error(new RuntimeException("Failed to handle SSE endpoint event"));
            }
        } else if (MESSAGE_EVENT_TYPE.equals(event.event())) {
            try {
                sink.next(McpSchema.deserializeJsonRpcMessage(this.jsonMapper, event.data()));
            } catch (IOException ioException) {
                sink.error(ioException);
            }
        } else {
            logger.debug("Received unrecognized SSE event type: {}", event);
            sink.complete();
        }
    }

    /** Logs a terminal inbound error and fails connect() if the endpoint was never received. */
    private void onInboundError(Throwable error) {
        if (this.isClosing) {
            logger.debug("SSE inbound stream terminated during shutdown: {}", error.getMessage());
        } else {
            logger.error("SSE inbound stream terminated with an error", error);
        }
        // No-op (FAIL_TERMINATED) when the endpoint URI has already been emitted.
        this.messageEndpointSink.tryEmitError(error);
    }

    /** Fails connect() if the stream ended before the endpoint event; otherwise just logs. */
    private void onInboundComplete() {
        logger.debug("SSE inbound stream completed");
        this.messageEndpointSink.tryEmitError(
                new RuntimeException("SSE stream completed before the endpoint event was received"));
    }

    /**
     * POSTs a JSON-RPC message to the endpoint announced by the server. The call waits until the
     * endpoint is known. It completes empty without sending once closing has started.
     *
     * @param message message to send
     * @return a Mono that completes when the HTTP exchange finishes. It errors with a
     *     RuntimeException if serialization fails (unless closing), or with the WebClient error.
     */
    @Override
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
        return this.messageEndpointSink.asMono().flatMap(messageEndpointUri -> {
            if (this.isClosing) {
                return Mono.empty();
            }
            try {
                String jsonText = this.jsonMapper.writeValueAsString(message);
                return this.webClient.post()
                        .uri(messageEndpointUri)
                        .contentType(MediaType.APPLICATION_JSON)
                        .header("MCP-Protocol-Version", MCP_PROTOCOL_VERSION)
                        .bodyValue(jsonText)
                        .retrieve()
                        .toBodilessEntity()
                        .doOnSuccess(response -> logger.debug("Message sent successfully"))
                        .doOnError(error -> {
                            if (!this.isClosing) {
                                logger.error("Error sending message: {}", error.getMessage());
                            }
                        });
            } catch (IOException e) {
                if (!this.isClosing) {
                    return Mono.error(new RuntimeException("Failed to serialize message", e));
                }
                return Mono.empty();
            }
        }).then();
    }

    /**
     * Opens the SSE stream: GET on the configured endpoint, decoded as {@code ServerSentEvent<String>}.
     * Failures are re-subscribed according to {@code inboundRetryHandler}.
     */
    protected Flux<ServerSentEvent<String>> eventStream() {
        return this.webClient.get()
                .uri(this.sseEndpoint)
                .accept(MediaType.TEXT_EVENT_STREAM)
                .header("MCP-Protocol-Version", MCP_PROTOCOL_VERSION)
                .retrieve()
                .bodyToFlux(SSE_TYPE)
                .retryWhen(Retry.from(retrySignal -> retrySignal.handle(this.inboundRetryHandler)));
    }

    /** Marks the transport as closing and disposes the inbound SSE subscription on boundedElastic. */
    @Override
    public Mono<Void> closeGracefully() {
        return Mono.fromRunnable(() -> {
            this.isClosing = true;
            Disposable subscription = this.inboundSubscription;
            if (subscription != null) {
                subscription.dispose();
            }
        }).then().subscribeOn(Schedulers.boundedElastic());
    }

    /** Converts {@code data} to the requested type using the configured JSON mapper. */
    @Override
    public <T> T unmarshalFrom(Object data, TypeRef<T> typeRef) {
        return this.jsonMapper.convertValue(data, typeRef);
    }

    /**
     * Creates a {@link Builder} for this transport.
     *
     * @param webClientBuilder builder used to create the underlying {@link WebClient}; must not be null
     */
    public static Builder builder(WebClient.Builder webClientBuilder) {
        return new Builder(webClientBuilder);
    }

    /** Fluent builder for {@link WebFluxSseClientTransport}. */
    public static class Builder {
        private final WebClient.Builder webClientBuilder;
        private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
        private McpJsonMapper jsonMapper;

        /** @param webClientBuilder builder for the underlying {@link WebClient}; must not be null */
        public Builder(WebClient.Builder webClientBuilder) {
            Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
            this.webClientBuilder = webClientBuilder;
        }

        /** Sets the SSE endpoint path/URI (defaults to {@code /sse}); must contain text. */
        public Builder sseEndpoint(String sseEndpoint) {
            Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
            this.sseEndpoint = sseEndpoint;
            return this;
        }

        /** Sets the JSON mapper; must not be null. If never set, {@code McpJsonDefaults.getMapper()} is used. */
        public Builder jsonMapper(McpJsonMapper jsonMapper) {
            Assert.notNull(jsonMapper, "jsonMapper must not be null");
            this.jsonMapper = jsonMapper;
            return this;
        }

        /** Builds the transport, falling back to the default JSON mapper when none was set. */
        public WebFluxSseClientTransport build() {
            McpJsonMapper mapper = this.jsonMapper == null ? McpJsonDefaults.getMapper() : this.jsonMapper;
            return new WebFluxSseClientTransport(this.webClientBuilder, mapper, this.sseEndpoint);
        }
    }
}

