/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.server.transport;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.json.McpJsonDefaults;
import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.server.transport.HttpServletRequestUtils;
import io.modelcontextprotocol.server.transport.ServerTransportSecurityException;
import io.modelcontextprotocol.server.transport.ServerTransportSecurityValidator;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpStreamableServerSession;
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.KeepAliveScheduler;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncEvent;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.WebServlet;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * Servlet implementation of {@link McpStreamableServerTransportProvider} for the MCP
 * Streamable HTTP transport.
 *
 * <p>All traffic is served on a single endpoint (default {@code /mcp}, matched as a suffix
 * of the request URI):
 * <ul>
 * <li>{@code POST} carries client-to-server JSON-RPC messages. An {@code initialize}
 * request starts a new session whose id is returned in the {@code Mcp-Session-Id} header;
 * every other message must carry that header. JSON-RPC requests are answered on an SSE
 * stream, while responses and notifications are acknowledged with {@code 202}.</li>
 * <li>{@code GET} opens the server-to-client SSE listening stream of a session, or replays
 * messages when the client sends {@code Last-Event-ID}.</li>
 * <li>{@code DELETE} terminates a session unless deletion was disabled in the builder.</li>
 * </ul>
 *
 * <p>Every method first checks the endpoint, rejects requests with {@code 503} once
 * {@link #closeGracefully()} has started, and runs the configured
 * {@link ServerTransportSecurityValidator} over the request headers.
 */
@WebServlet(asyncSupported=true)
public class HttpServletStreamableServerTransportProvider
extends HttpServlet
implements McpStreamableServerTransportProvider {
    private static final Logger logger = LoggerFactory.getLogger(HttpServletStreamableServerTransportProvider.class);
    /** SSE event type used for every JSON-RPC message written to a stream. */
    public static final String MESSAGE_EVENT_TYPE = "message";
    /** SSE event type for endpoint announcements; not referenced by this class. */
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";
    private static final String ACCEPT = "Accept";
    public static final String UTF_8 = "UTF-8";
    public static final String APPLICATION_JSON = "application/json";
    public static final String TEXT_EVENT_STREAM = "text/event-stream";
    public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}";
    /** Header carrying the session id issued in reply to {@code initialize}. */
    private static final String MCP_SESSION_ID = "Mcp-Session-Id";
    /** Header a reconnecting client sends to request replay of earlier events. */
    private static final String LAST_EVENT_ID = "Last-Event-ID";
    /** Reactor context key under which the extracted {@link McpTransportContext} is published. */
    private static final String TRANSPORT_CONTEXT_KEY = "MCP_TRANSPORT_CONTEXT";
    /** JSON-RPC 2.0 "Invalid Request" error code. */
    private static final int INVALID_REQUEST = -32600;
    /** JSON-RPC 2.0 "Method not found" error code; this class also uses it for rejected HTTP headers. */
    private static final int METHOD_NOT_FOUND = -32601;
    /** JSON-RPC 2.0 "Internal error" error code. */
    private static final int INTERNAL_ERROR = -32603;
    /** Any line terminator the SSE format recognises: CRLF, CR or LF. */
    private static final Pattern LINE_BREAK = Pattern.compile("\r\n|\r|\n");

    private final String mcpEndpoint;
    private final boolean disallowDelete;
    private final McpJsonMapper jsonMapper;
    private McpStreamableServerSession.Factory sessionFactory;
    /** Active sessions keyed by the id handed out in the {@code Mcp-Session-Id} header. */
    private final ConcurrentHashMap<String, McpStreamableServerSession> sessions = new ConcurrentHashMap<>();
    private final McpTransportContextExtractor<HttpServletRequest> contextExtractor;
    private volatile boolean isClosing = false;
    /** Started in the constructor when a keep-alive interval is configured, otherwise {@code null}. */
    private final KeepAliveScheduler keepAliveScheduler;
    private final ServerTransportSecurityValidator securityValidator;

    private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, boolean disallowDelete, McpTransportContextExtractor<HttpServletRequest> contextExtractor, Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator) {
        Assert.notNull(jsonMapper, "JsonMapper must not be null");
        Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
        Assert.notNull(contextExtractor, "Context extractor must not be null");
        Assert.notNull(securityValidator, "Security validator must not be null");
        this.jsonMapper = jsonMapper;
        this.mcpEndpoint = mcpEndpoint;
        this.disallowDelete = disallowDelete;
        this.contextExtractor = contextExtractor;
        this.securityValidator = securityValidator;
        if (keepAliveInterval != null) {
            // Targets all active sessions; yields nothing once shutdown has begun.
            this.keepAliveScheduler = KeepAliveScheduler.builder(() -> this.isClosing ? Flux.empty() : Flux.fromIterable(this.sessions.values())).initialDelay(keepAliveInterval).interval(keepAliveInterval).build();
            this.keepAliveScheduler.start();
        } else {
            this.keepAliveScheduler = null;
        }
    }

    @Override
    public List<String> protocolVersions() {
        return List.of("2024-11-05", "2025-03-26", "2025-06-18", "2025-11-25");
    }

    @Override
    public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    /**
     * Sends a notification to every active session.
     *
     * <p>Each session's send is blocked on individually. A failure is logged and does not
     * stop delivery to the other sessions, and it is not propagated to the caller.
     *
     * @param method the notification method name
     * @param params the notification parameters
     * @return a {@link Mono} that performs the broadcast when subscribed
     */
    @Override
    public Mono<Void> notifyClients(String method, Object params) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
        return Mono.fromRunnable(() -> this.sessions.values().parallelStream().forEach(session -> {
            try {
                session.sendNotification(method, params).block();
            }
            catch (Exception e) {
                logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage());
            }
        }));
    }

    /**
     * Stops accepting requests, closes every active session and stops the keep-alive
     * scheduler.
     *
     * <p>Failures closing individual sessions are logged and do not stop the shutdown. The
     * keep-alive scheduler is shut down even if closing the sessions throws.
     *
     * @return a {@link Mono} that performs the shutdown when subscribed
     */
    @Override
    public Mono<Void> closeGracefully() {
        return Mono.fromRunnable(() -> {
            // New requests are answered with 503 from this point on.
            this.isClosing = true;
            logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
            try {
                this.sessions.values().parallelStream().forEach(session -> {
                    try {
                        session.closeGracefully().block();
                    }
                    catch (Exception e) {
                        logger.error("Failed to close session {}: {}", session.getId(), e.getMessage());
                    }
                });
                this.sessions.clear();
            }
            finally {
                if (this.keepAliveScheduler != null) {
                    this.keepAliveScheduler.shutdown();
                }
            }
            logger.debug("Graceful shutdown completed");
        });
    }

    /**
     * Opens the server-to-client SSE stream of an existing session.
     *
     * <p>Requires {@code text/event-stream} in {@code Accept} and a non-blank
     * {@code Mcp-Session-Id}; otherwise it answers {@code 400}. An unknown session id yields
     * {@code 404}. When {@code Last-Event-ID} is present, the session's replayed messages are
     * written to the stream. Otherwise the stream becomes the session's listening stream and
     * is closed when the async request completes, times out or fails.
     */
    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        if (!this.checkPreconditions(request, response)) {
            return;
        }
        List<String> badRequestErrors = new ArrayList<>();
        String accept = request.getHeader(ACCEPT);
        if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) {
            badRequestErrors.add("text/event-stream required in Accept header");
        }
        String sessionId = request.getHeader(MCP_SESSION_ID);
        if (sessionId == null || sessionId.isBlank()) {
            badRequestErrors.add("Session ID required in mcp-session-id header");
        }
        if (!badRequestErrors.isEmpty()) {
            this.respondBadRequest(response, badRequestErrors);
            return;
        }
        McpStreamableServerSession session = this.sessions.get(sessionId);
        if (session == null) {
            response.sendError(404);
            return;
        }
        logger.debug("Handling GET request for session: {}", sessionId);
        McpTransportContext transportContext = this.contextExtractor.extract(request);
        try {
            this.applySseHeaders(response);
            AsyncContext asyncContext = request.startAsync();
            // 0 disables the container timeout: the stream lives until it is closed explicitly.
            asyncContext.setTimeout(0L);
            HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport(sessionId, asyncContext, response.getWriter());
            String lastEventId = request.getHeader(LAST_EVENT_ID);
            if (lastEventId != null) {
                this.replayMessages(session, lastEventId, sessionTransport, transportContext);
            } else {
                this.attachListeningStream(session, sessionId, sessionTransport, asyncContext);
            }
        }
        catch (Exception e) {
            logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage());
            if (!response.isCommitted()) {
                response.sendError(500);
            } else if (request.isAsyncStarted()) {
                // The SSE headers are already on the wire, so an error status can no longer
                // be sent; release the connection instead of leaving it open forever.
                request.getAsyncContext().complete();
            }
        }
    }

    /**
     * Writes the messages returned by {@code session.replay(lastEventId)} to the stream.
     *
     * <p>Stops at the first failure and closes the transport. The transport's close is
     * idempotent, so it never completes the async context twice.
     *
     * <p>NOTE(review): after a successful replay the stream stays open without a listening
     * stream attached and without a completion listener. Confirm this is intended.
     */
    private void replayMessages(McpStreamableServerSession session, String lastEventId, HttpServletStreamableMcpSessionTransport sessionTransport, McpTransportContext transportContext) {
        try {
            Iterable<McpSchema.JSONRPCMessage> replayed = session.replay(lastEventId).contextWrite(ctx -> ctx.put(TRANSPORT_CONTEXT_KEY, transportContext)).toIterable();
            for (McpSchema.JSONRPCMessage message : replayed) {
                try {
                    sessionTransport.sendMessage(message).contextWrite(ctx -> ctx.put(TRANSPORT_CONTEXT_KEY, transportContext)).block();
                }
                catch (Exception e) {
                    logger.error("Failed to replay message: {}", e.getMessage());
                    sessionTransport.close();
                    return;
                }
            }
        }
        catch (Exception e) {
            logger.error("Failed to replay messages: {}", e.getMessage());
            sessionTransport.close();
        }
    }

    /**
     * Registers {@code sessionTransport} as the session's listening stream. It also adds an
     * async listener that closes that stream when the request completes, times out or fails.
     */
    private void attachListeningStream(McpStreamableServerSession session, String sessionId, HttpServletStreamableMcpSessionTransport sessionTransport, AsyncContext asyncContext) {
        final McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session.listeningStream(sessionTransport);
        asyncContext.addListener(new AsyncListener(){

            @Override
            public void onComplete(AsyncEvent event) throws IOException {
                logger.debug("SSE connection completed for session: {}", sessionId);
                listeningStream.close();
            }

            @Override
            public void onTimeout(AsyncEvent event) throws IOException {
                logger.debug("SSE connection timed out for session: {}", sessionId);
                listeningStream.close();
            }

            @Override
            public void onError(AsyncEvent event) throws IOException {
                logger.debug("SSE connection error for session: {}", sessionId);
                listeningStream.close();
            }

            @Override
            public void onStartAsync(AsyncEvent event) throws IOException {
            }
        });
    }

    /**
     * Handles one client-to-server JSON-RPC message.
     *
     * <p>{@code Accept} must list both {@code text/event-stream} and
     * {@code application/json}.
     * <ul>
     * <li>An {@code initialize} request creates a session and is answered with plain JSON
     * plus the {@code Mcp-Session-Id} header.</li>
     * <li>Any other message needs a known session id (else {@code 400}/{@code 404}).</li>
     * <li>Responses and notifications are acknowledged with {@code 202}.</li>
     * <li>Requests are answered on an SSE stream.</li>
     * </ul>
     * Messages that cannot be deserialized, and {@link IOException}/
     * {@link IllegalArgumentException} raised while handling, yield {@code 400}. Any other
     * failure yields {@code 500}.
     */
    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        if (!this.checkPreconditions(request, response)) {
            return;
        }
        List<String> badRequestErrors = new ArrayList<>();
        String accept = request.getHeader(ACCEPT);
        if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) {
            badRequestErrors.add("text/event-stream required in Accept header");
        }
        if (accept == null || !accept.contains(APPLICATION_JSON)) {
            badRequestErrors.add("application/json required in Accept header");
        }
        McpTransportContext transportContext = this.contextExtractor.extract(request);
        try {
            McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, readBody(request));
            if (message instanceof McpSchema.JSONRPCRequest && "initialize".equals(((McpSchema.JSONRPCRequest)message).method())) {
                if (!badRequestErrors.isEmpty()) {
                    this.respondBadRequest(response, badRequestErrors);
                    return;
                }
                this.handleInitialize((McpSchema.JSONRPCRequest)message, response);
                return;
            }
            String sessionId = request.getHeader(MCP_SESSION_ID);
            if (sessionId == null || sessionId.isBlank()) {
                badRequestErrors.add("Session ID required in mcp-session-id header");
            }
            if (!badRequestErrors.isEmpty()) {
                this.respondBadRequest(response, badRequestErrors);
                return;
            }
            McpStreamableServerSession session = this.sessions.get(sessionId);
            if (session == null) {
                this.responseError(response, 404, McpError.builder(INTERNAL_ERROR).message("Session not found: " + sessionId).build());
                return;
            }
            if (message instanceof McpSchema.JSONRPCResponse) {
                session.accept((McpSchema.JSONRPCResponse)message).contextWrite(ctx -> ctx.put(TRANSPORT_CONTEXT_KEY, transportContext)).block();
                response.setStatus(202);
            } else if (message instanceof McpSchema.JSONRPCNotification) {
                session.accept((McpSchema.JSONRPCNotification)message).contextWrite(ctx -> ctx.put(TRANSPORT_CONTEXT_KEY, transportContext)).block();
                response.setStatus(202);
            } else if (message instanceof McpSchema.JSONRPCRequest) {
                McpSchema.JSONRPCRequest jsonrpcRequest = (McpSchema.JSONRPCRequest)message;
                this.applySseHeaders(response);
                AsyncContext asyncContext = request.startAsync();
                asyncContext.setTimeout(0L);
                HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport(sessionId, asyncContext, response.getWriter());
                try {
                    session.responseStream(jsonrpcRequest, sessionTransport).contextWrite(ctx -> ctx.put(TRANSPORT_CONTEXT_KEY, transportContext)).block();
                }
                catch (Exception e) {
                    logger.error("Failed to handle request stream: {}", e.getMessage());
                    // Idempotent: safe even if a failed write already closed the transport.
                    sessionTransport.close();
                }
            } else {
                this.responseError(response, 500, McpError.builder(INVALID_REQUEST).message("Unknown message type").build());
            }
        }
        catch (IOException | IllegalArgumentException e) {
            logger.error("Failed to deserialize message: {}", e.getMessage());
            this.responseError(response, 400, McpError.builder(INVALID_REQUEST).message("Invalid message format: " + e.getMessage()).build());
        }
        catch (Exception e) {
            logger.error("Error handling message: {}", e.getMessage());
            try {
                this.responseError(response, 500, McpError.builder(INTERNAL_ERROR).message("Error processing message: " + e.getMessage()).build());
            }
            catch (IOException ex) {
                logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage());
                response.sendError(500, "Error processing message");
            }
        }
    }

    /**
     * Starts a session for an {@code initialize} request and writes the JSON-RPC response.
     *
     * <p>The session is registered before its initialization result is awaited. If awaiting
     * or writing the result fails, the session is unregistered again, because the client never
     * received its id. Failures in parameter conversion or session creation propagate to the
     * caller.
     */
    private void handleInitialize(McpSchema.JSONRPCRequest jsonrpcRequest, HttpServletResponse response) throws IOException {
        McpSchema.InitializeRequest initializeRequest = this.jsonMapper.convertValue(jsonrpcRequest.params(), new TypeRef<McpSchema.InitializeRequest>(){});
        McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory.startSession(initializeRequest);
        String sessionId = init.session().getId();
        this.sessions.put(sessionId, init.session());
        try {
            McpSchema.InitializeResult initResult = (McpSchema.InitializeResult)init.initResult().block();
            response.setContentType(APPLICATION_JSON);
            response.setCharacterEncoding(UTF_8);
            response.setHeader(MCP_SESSION_ID, sessionId);
            response.setStatus(200);
            String jsonResponse = this.jsonMapper.writeValueAsString(new McpSchema.JSONRPCResponse("2.0", jsonrpcRequest.id(), initResult, null));
            PrintWriter writer = response.getWriter();
            writer.write(jsonResponse);
            writer.flush();
        }
        catch (Exception e) {
            this.sessions.remove(sessionId);
            logger.error("Failed to initialize session: {}", e.getMessage());
            this.responseError(response, 500, McpError.builder(INTERNAL_ERROR).message("Failed to initialize session: " + e.getMessage()).build());
        }
    }

    /**
     * Deletes a session: {@code 405} when deletion is disabled, {@code 400} without a
     * non-blank {@code Mcp-Session-Id}, {@code 404} for an unknown session and {@code 200}
     * once the session has been deleted and unregistered.
     */
    @Override
    protected void doDelete(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        if (!this.checkPreconditions(request, response)) {
            return;
        }
        if (this.disallowDelete) {
            response.sendError(405);
            return;
        }
        McpTransportContext transportContext = this.contextExtractor.extract(request);
        String sessionId = request.getHeader(MCP_SESSION_ID);
        if (sessionId == null || sessionId.isBlank()) {
            this.responseError(response, 400, McpError.builder(METHOD_NOT_FOUND).message("Session ID required in mcp-session-id header").build());
            return;
        }
        McpStreamableServerSession session = this.sessions.get(sessionId);
        if (session == null) {
            response.sendError(404);
            return;
        }
        try {
            session.delete().contextWrite(ctx -> ctx.put(TRANSPORT_CONTEXT_KEY, transportContext)).block();
            this.sessions.remove(sessionId);
            response.setStatus(200);
        }
        catch (Exception e) {
            logger.error("Failed to delete session {}: {}", sessionId, e.getMessage());
            try {
                this.responseError(response, 500, McpError.builder(INTERNAL_ERROR).message(e.getMessage()).build());
            }
            catch (IOException ex) {
                logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage());
                response.sendError(500, "Error deleting session");
            }
        }
    }

    /**
     * Applies the checks shared by every HTTP method, in this order:
     * <ol>
     * <li>The request URI must end with the configured endpoint ({@code 404}).</li>
     * <li>The provider must not be shutting down ({@code 503}).</li>
     * <li>The security validator must accept the headers (the validator's status).</li>
     * </ol>
     *
     * @return {@code true} if processing may continue; {@code false} if an error response has
     * already been sent
     */
    private boolean checkPreconditions(HttpServletRequest request, HttpServletResponse response) throws IOException {
        if (!request.getRequestURI().endsWith(this.mcpEndpoint)) {
            response.sendError(404);
            return false;
        }
        if (this.isClosing) {
            response.sendError(503, "Server is shutting down");
            return false;
        }
        try {
            Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
            this.securityValidator.validateHeaders(headers);
        }
        catch (ServerTransportSecurityException e) {
            response.sendError(e.getStatusCode(), e.getMessage());
            return false;
        }
        return true;
    }

    /** Answers {@code 400} with all collected problems joined by {@code "; "}. */
    private void respondBadRequest(HttpServletResponse response, List<String> errors) throws IOException {
        this.responseError(response, 400, McpError.builder(METHOD_NOT_FOUND).message(String.join("; ", errors)).build());
    }

    /** Sets the content type, encoding and caching/CORS headers used for SSE responses. */
    private void applySseHeaders(HttpServletResponse response) {
        response.setContentType(TEXT_EVENT_STREAM);
        response.setCharacterEncoding(UTF_8);
        response.setHeader("Cache-Control", "no-cache");
        response.setHeader("Connection", "keep-alive");
        response.setHeader("Access-Control-Allow-Origin", "*");
    }

    /**
     * Reads the request body line by line and concatenates the lines without separators.
     * Line breaks are therefore dropped. That is harmless for valid JSON, where raw line
     * breaks can only appear as insignificant whitespace.
     */
    private static String readBody(HttpServletRequest request) throws IOException {
        BufferedReader reader = request.getReader();
        StringBuilder body = new StringBuilder();
        String line;
        while ((line = reader.readLine()) != null) {
            body.append(line);
        }
        return body.toString();
    }

    /**
     * Writes {@code mcpError} as a JSON body with the given HTTP status.
     *
     * @param response the response to write to
     * @param httpCode the HTTP status code to set
     * @param mcpError the error to serialize with the configured JSON mapper
     * @throws IOException if serialization or obtaining the writer fails
     */
    public void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException {
        response.setContentType(APPLICATION_JSON);
        response.setCharacterEncoding(UTF_8);
        response.setStatus(httpCode);
        String jsonError = this.jsonMapper.writeValueAsString(mcpError);
        PrintWriter writer = response.getWriter();
        writer.write(jsonError);
        writer.flush();
    }

    /**
     * Writes one SSE event and flushes it.
     *
     * <p>Every line of {@code data} gets its own {@code data:} field. Otherwise a line break
     * inside the payload would end the field and corrupt the event. Single-line payloads
     * produce exactly {@code data: <payload>\n\n}.
     *
     * @throws IOException if the writer reports an error after flushing, which is treated as
     * a client disconnect
     */
    private void sendEvent(PrintWriter writer, String eventType, String data, String id) throws IOException {
        StringBuilder frame = new StringBuilder();
        if (id != null) {
            frame.append("id: ").append(id).append('\n');
        }
        frame.append("event: ").append(eventType).append('\n');
        for (String dataLine : LINE_BREAK.split(data, -1)) {
            frame.append("data: ").append(dataLine).append('\n');
        }
        frame.append('\n');
        writer.write(frame.toString());
        writer.flush();
        // PrintWriter swallows IOExceptions; checkError() is the only way to observe them.
        if (writer.checkError()) {
            throw new IOException("Client disconnected");
        }
    }

    /** Closes all sessions (blocking) before the servlet is taken out of service. */
    @Override
    public void destroy() {
        this.closeGracefully().block();
        super.destroy();
    }

    public static Builder builder() {
        return new Builder();
    }

    /**
     * Session transport that writes messages as SSE events to one async servlet response.
     *
     * <p>Writes and {@link #close()} are serialized by a {@link ReentrantLock}. Once closed, or
     * after a failed write, further messages are dropped with a debug log.
     */
    private class HttpServletStreamableMcpSessionTransport
    implements McpStreamableServerTransport {
        private final String sessionId;
        private final AsyncContext asyncContext;
        private final PrintWriter writer;
        private volatile boolean closed = false;
        private final ReentrantLock lock = new ReentrantLock();

        HttpServletStreamableMcpSessionTransport(String sessionId, AsyncContext asyncContext, PrintWriter writer) {
            this.sessionId = sessionId;
            this.asyncContext = asyncContext;
            this.writer = writer;
            logger.debug("Streamable session transport {} initialized with SSE writer", sessionId);
        }

        @Override
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return this.sendMessage(message, null);
        }

        /**
         * Serializes {@code message} and writes it as a {@code message} event. The event id is
         * {@code messageId}, or the session id when {@code messageId} is {@code null}.
         *
         * <p>On failure the error is logged, the whole session is removed from the provider's
         * registry and this transport is closed. The failure is not propagated to subscribers.
         */
        @Override
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message, String messageId) {
            return Mono.fromRunnable(() -> {
                if (this.closed) {
                    logger.debug("Attempted to send message to closed session: {}", this.sessionId);
                    return;
                }
                this.lock.lock();
                try {
                    // Re-check under the lock: close() may have run since the first check.
                    if (this.closed) {
                        logger.debug("Session {} was closed during message send attempt", this.sessionId);
                        return;
                    }
                    String jsonText = HttpServletStreamableServerTransportProvider.this.jsonMapper.writeValueAsString(message);
                    HttpServletStreamableServerTransportProvider.this.sendEvent(this.writer, MESSAGE_EVENT_TYPE, jsonText, messageId != null ? messageId : this.sessionId);
                    logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId);
                }
                catch (Exception e) {
                    logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
                    HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId);
                    // Mark closed so later sends are skipped and the async context is
                    // completed exactly once (the lock is reentrant).
                    this.close();
                }
                finally {
                    this.lock.unlock();
                }
            });
        }

        @Override
        public <T> T unmarshalFrom(Object data, TypeRef<T> typeRef) {
            return HttpServletStreamableServerTransportProvider.this.jsonMapper.convertValue(data, typeRef);
        }

        @Override
        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(this::close);
        }

        /**
         * Marks the transport closed and completes the async context. Repeated calls only log.
         * Failures while completing are logged as warnings and not rethrown.
         */
        @Override
        public void close() {
            this.lock.lock();
            try {
                if (this.closed) {
                    logger.debug("Session transport {} already closed", this.sessionId);
                    return;
                }
                this.closed = true;
                this.asyncContext.complete();
                logger.debug("Successfully completed async context for session {}", this.sessionId);
            }
            catch (Exception e) {
                logger.warn("Failed to complete async context for session {}: {}", this.sessionId, e.getMessage());
            }
            finally {
                this.lock.unlock();
            }
        }
    }

    /**
     * Builder for {@link HttpServletStreamableServerTransportProvider}.
     *
     * <p>Defaults:
     * <ul>
     * <li>endpoint {@code /mcp};</li>
     * <li>DELETE allowed;</li>
     * <li>empty transport context;</li>
     * <li>no keep-alive;</li>
     * <li>{@link ServerTransportSecurityValidator#NOOP};</li>
     * <li>{@link McpJsonDefaults#getMapper()} when no JSON mapper is set.</li>
     * </ul>
     */
    public static class Builder {
        private McpJsonMapper jsonMapper;
        private String mcpEndpoint = "/mcp";
        private boolean disallowDelete = false;
        private McpTransportContextExtractor<HttpServletRequest> contextExtractor = serverRequest -> McpTransportContext.EMPTY;
        private Duration keepAliveInterval;
        private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;

        public Builder jsonMapper(McpJsonMapper jsonMapper) {
            Assert.notNull(jsonMapper, "JsonMapper must not be null");
            this.jsonMapper = jsonMapper;
            return this;
        }

        public Builder mcpEndpoint(String mcpEndpoint) {
            Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
            this.mcpEndpoint = mcpEndpoint;
            return this;
        }

        public Builder disallowDelete(boolean disallowDelete) {
            this.disallowDelete = disallowDelete;
            return this;
        }

        public Builder contextExtractor(McpTransportContextExtractor<HttpServletRequest> contextExtractor) {
            Assert.notNull(contextExtractor, "Context extractor must not be null");
            this.contextExtractor = contextExtractor;
            return this;
        }

        /** Sets the keep-alive interval; {@code null} (the default) disables keep-alive. */
        public Builder keepAliveInterval(Duration keepAliveInterval) {
            this.keepAliveInterval = keepAliveInterval;
            return this;
        }

        public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
            Assert.notNull(securityValidator, "Security validator must not be null");
            this.securityValidator = securityValidator;
            return this;
        }

        public HttpServletStreamableServerTransportProvider build() {
            Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set");
            McpJsonMapper mapper = this.jsonMapper == null ? McpJsonDefaults.getMapper() : this.jsonMapper;
            return new HttpServletStreamableServerTransportProvider(mapper, this.mcpEndpoint, this.disallowDelete, this.contextExtractor, this.keepAliveInterval, this.securityValidator);
        }
    }
}

