/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.server.transport;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.json.McpJsonDefaults;
import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.server.McpStatelessServerHandler;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.server.transport.HttpServletRequestUtils;
import io.modelcontextprotocol.server.transport.ServerTransportSecurityException;
import io.modelcontextprotocol.server.transport.ServerTransportSecurityValidator;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpStatelessServerTransport;
import io.modelcontextprotocol.util.Assert;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.WebServlet;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Enumeration;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

@WebServlet(asyncSupported=true)
public class HttpServletStatelessServerTransport
extends HttpServlet
implements McpStatelessServerTransport {
    private static final Logger logger = LoggerFactory.getLogger(HttpServletStatelessServerTransport.class);
    public static final String UTF_8 = "UTF-8";
    public static final String APPLICATION_JSON = "application/json";
    public static final String TEXT_EVENT_STREAM = "text/event-stream";
    public static final String ACCEPT = "Accept";
    public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}";
    private final McpJsonMapper jsonMapper;
    private final String mcpEndpoint;
    private McpStatelessServerHandler mcpHandler;
    private McpTransportContextExtractor<HttpServletRequest> contextExtractor;
    private volatile boolean isClosing = false;
    private final ServerTransportSecurityValidator securityValidator;

    private HttpServletStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, McpTransportContextExtractor<HttpServletRequest> contextExtractor, ServerTransportSecurityValidator securityValidator) {
        Assert.notNull(jsonMapper, "jsonMapper must not be null");
        Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null");
        Assert.notNull(contextExtractor, "contextExtractor must not be null");
        Assert.notNull(securityValidator, "Security validator must not be null");
        this.jsonMapper = jsonMapper;
        this.mcpEndpoint = mcpEndpoint;
        this.contextExtractor = contextExtractor;
        this.securityValidator = securityValidator;
    }

    @Override
    public void setMcpHandler(McpStatelessServerHandler mcpHandler) {
        this.mcpHandler = mcpHandler;
    }

    @Override
    public Mono<Void> closeGracefully() {
        return Mono.fromRunnable(() -> {
            this.isClosing = true;
        });
    }

    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        String requestURI = request.getRequestURI();
        if (!requestURI.endsWith(this.mcpEndpoint)) {
            response.sendError(404);
            return;
        }
        response.sendError(405);
    }

    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        String requestURI = request.getRequestURI();
        if (!requestURI.endsWith(this.mcpEndpoint)) {
            response.sendError(404);
            return;
        }
        if (this.isClosing) {
            response.sendError(503, "Server is shutting down");
            return;
        }
        try {
            Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
            this.securityValidator.validateHeaders(headers);
        }
        catch (ServerTransportSecurityException e) {
            response.sendError(e.getStatusCode(), e.getMessage());
            return;
        }
        McpTransportContext transportContext = this.contextExtractor.extract(request);
        String accept = request.getHeader(ACCEPT);
        if (accept == null || !accept.contains(APPLICATION_JSON) || !accept.contains(TEXT_EVENT_STREAM)) {
            this.responseError(response, 400, new McpError((Object)"Both application/json and text/event-stream required in Accept header"));
            return;
        }
        try {
            String line;
            BufferedReader reader = request.getReader();
            StringBuilder body = new StringBuilder();
            while ((line = reader.readLine()) != null) {
                body.append(line);
            }
            McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, body.toString());
            if (message instanceof McpSchema.JSONRPCRequest) {
                McpSchema.JSONRPCRequest jsonrpcRequest = (McpSchema.JSONRPCRequest)message;
                try {
                    McpSchema.JSONRPCResponse jsonrpcResponse = this.mcpHandler.handleRequest(transportContext, jsonrpcRequest).contextWrite(ctx -> ctx.put("MCP_TRANSPORT_CONTEXT", transportContext)).block();
                    response.setContentType(APPLICATION_JSON);
                    response.setCharacterEncoding(UTF_8);
                    response.setStatus(200);
                    String jsonResponseText = this.jsonMapper.writeValueAsString(jsonrpcResponse);
                    PrintWriter writer = response.getWriter();
                    writer.write(jsonResponseText);
                    writer.flush();
                }
                catch (Exception e) {
                    logger.error("Failed to handle request: {}", (Object)e.getMessage());
                    this.responseError(response, 500, new McpError((Object)("Failed to handle request: " + e.getMessage())));
                }
            } else if (message instanceof McpSchema.JSONRPCNotification) {
                McpSchema.JSONRPCNotification jsonrpcNotification = (McpSchema.JSONRPCNotification)message;
                try {
                    this.mcpHandler.handleNotification(transportContext, jsonrpcNotification).contextWrite(ctx -> ctx.put("MCP_TRANSPORT_CONTEXT", transportContext)).block();
                    response.setStatus(202);
                }
                catch (Exception e) {
                    logger.error("Failed to handle notification: {}", (Object)e.getMessage());
                    this.responseError(response, 500, new McpError((Object)("Failed to handle notification: " + e.getMessage())));
                }
            } else {
                this.responseError(response, 400, new McpError((Object)"The server accepts either requests or notifications"));
            }
        }
        catch (IOException | IllegalArgumentException e) {
            logger.error("Failed to deserialize message: {}", (Object)e.getMessage());
            this.responseError(response, 400, new McpError((Object)"Invalid message format"));
        }
        catch (Exception e) {
            logger.error("Unexpected error handling message: {}", (Object)e.getMessage());
            this.responseError(response, 500, new McpError((Object)("Unexpected error: " + e.getMessage())));
        }
    }

    private void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException {
        response.setContentType(APPLICATION_JSON);
        response.setCharacterEncoding(UTF_8);
        response.setStatus(httpCode);
        String jsonError = this.jsonMapper.writeValueAsString(mcpError);
        PrintWriter writer = response.getWriter();
        writer.write(jsonError);
        writer.flush();
    }

    public void destroy() {
        this.closeGracefully().block();
        super.destroy();
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private McpJsonMapper jsonMapper;
        private String mcpEndpoint = "/mcp";
        private McpTransportContextExtractor<HttpServletRequest> contextExtractor = serverRequest -> McpTransportContext.EMPTY;
        private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;

        private Builder() {
        }

        public Builder jsonMapper(McpJsonMapper jsonMapper) {
            Assert.notNull(jsonMapper, "JsonMapper must not be null");
            this.jsonMapper = jsonMapper;
            return this;
        }

        public Builder messageEndpoint(String messageEndpoint) {
            Assert.notNull(messageEndpoint, "Message endpoint must not be null");
            this.mcpEndpoint = messageEndpoint;
            return this;
        }

        public Builder contextExtractor(McpTransportContextExtractor<HttpServletRequest> contextExtractor) {
            Assert.notNull(contextExtractor, "Context extractor must not be null");
            this.contextExtractor = contextExtractor;
            return this;
        }

        public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
            Assert.notNull(securityValidator, "Security validator must not be null");
            this.securityValidator = securityValidator;
            return this;
        }

        public HttpServletStatelessServerTransport build() {
            Assert.notNull(this.mcpEndpoint, "Message endpoint must be set");
            return new HttpServletStatelessServerTransport(this.jsonMapper == null ? McpJsonDefaults.getMapper() : this.jsonMapper, this.mcpEndpoint, this.contextExtractor, this.securityValidator);
        }
    }
}

