diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 63a990e8..792583c6 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -339,3 +339,16 @@ required-features = [ ] path = "tests/test_streamable_http_stale_session.rs" +[[test]] +name = "test_streamable_http_connection_reuse" +required-features = [ + "server", + "client", + "macros", + "schemars", + "transport-streamable-http-server", + "transport-streamable-http-client", + "transport-streamable-http-client-reqwest", +] +path = "tests/test_streamable_http_connection_reuse.rs" + diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index dea98c7b..30312edc 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -262,7 +262,7 @@ impl StreamableHttpClientTransport { /// This method requires the `transport-streamable-http-client-reqwest` feature. pub fn from_uri(uri: impl Into>) -> Self { StreamableHttpClientTransport::with_client( - reqwest::Client::default(), + Self::default_http_client(), StreamableHttpClientTransportConfig { uri: uri.into(), auth_header: None, @@ -277,7 +277,19 @@ impl StreamableHttpClientTransport { /// /// * `config` - The config to use with this transport pub fn from_config(config: StreamableHttpClientTransportConfig) -> Self { - StreamableHttpClientTransport::with_client(reqwest::Client::default(), config) + StreamableHttpClientTransport::with_client(Self::default_http_client(), config) + } + + /// Build the default reqwest client for this transport. + /// + /// Disables idle connection pooling to avoid ~40 ms stalls caused by + /// TCP Delayed ACK on Linux when the previous response body was not + /// fully consumed before the pool attempts to reuse the connection. 
+ fn default_http_client() -> reqwest::Client { + reqwest::Client::builder() + .pool_max_idle_per_host(0) + .build() + .expect("failed to build default reqwest client") } } diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 980e63db..dc3915dc 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -281,6 +281,37 @@ impl StreamableHttpClientWorker { } impl StreamableHttpClientWorker { + /// Convert a raw SSE stream into a JSON-RPC message stream without + /// reconnection logic. + fn raw_sse_to_jsonrpc( + stream: BoxedSseStream, + ) -> impl Stream>> + Send + 'static + { + stream.filter_map(|event| async { + match event { + Err(e) => Some(Err(StreamableHttpError::Sse(e))), + Ok(sse) => { + let is_message = + matches!(sse.event.as_deref(), None | Some("") | Some("message")); + if !is_message { + return None; + } + let data = sse.data?; + if data.trim().is_empty() { + return None; + } + match serde_json::from_str::(&data) { + Ok(msg) => Some(Ok(msg)), + Err(e) => { + tracing::debug!("failed to deserialize server message: {e}"); + None + } + } + } + } + }) + } + async fn execute_sse_stream( sse_stream: impl Stream>> + Send @@ -303,14 +334,23 @@ impl StreamableHttpClientWorker { let Some(message) = message.transpose()? 
else { break; }; - let is_response = matches!(message, ServerJsonRpcMessage::Response(_)); + let is_response = matches!( + message, + ServerJsonRpcMessage::Response(_) | ServerJsonRpcMessage::Error(_) + ); let yield_result = sse_worker_tx.send(message).await; if yield_result.is_err() { tracing::trace!("streamable http transport worker dropped, exiting"); break; } if close_on_response && is_response { - tracing::debug!("got response, closing sse stream"); + tracing::debug!("got response, draining sse stream for connection reuse"); + // Consume the remaining stream so the HTTP/1.1 connection + // returns to the pool cleanly. + let _ = tokio::time::timeout(std::time::Duration::from_millis(50), async { + while sse_stream.next().await.is_some() {} + }) + .await; break; } } @@ -718,38 +758,12 @@ impl Worker for StreamableHttpClientWorker { Ok(()) } Ok(StreamableHttpPostResponse::Sse(stream, ..)) => { - if let Some(sid) = &session_id { - let sse_stream = SseAutoReconnectStream::new( - stream, - StreamableHttpClientReconnect { - client: self.client.clone(), - session_id: sid.clone(), - uri: config.uri.clone(), - auth_header: config.auth_header.clone(), - custom_headers: protocol_headers - .clone(), - }, - self.config.retry_config.clone(), - ); - streams.spawn(Self::execute_sse_stream( - sse_stream, - sse_worker_tx.clone(), - true, - transport_task_ct.child_token(), - )); - } else { - let sse_stream = - SseAutoReconnectStream::never_reconnect( - stream, - StreamableHttpError::::UnexpectedEndOfStream, - ); - streams.spawn(Self::execute_sse_stream( - sse_stream, - sse_worker_tx.clone(), - true, - transport_task_ct.child_token(), - )); - } + streams.spawn(Self::execute_sse_stream( + Self::raw_sse_to_jsonrpc(stream), + sse_worker_tx.clone(), + true, + transport_task_ct.child_token(), + )); tracing::trace!("got new sse stream after re-init"); Ok(()) } @@ -769,36 +783,12 @@ impl Worker for StreamableHttpClientWorker { Ok(()) } Ok(StreamableHttpPostResponse::Sse(stream, ..)) => { 
- if let Some(session_id) = &session_id { - let sse_stream = SseAutoReconnectStream::new( - stream, - StreamableHttpClientReconnect { - client: self.client.clone(), - session_id: session_id.clone(), - uri: config.uri.clone(), - auth_header: config.auth_header.clone(), - custom_headers: protocol_headers.clone(), - }, - self.config.retry_config.clone(), - ); - streams.spawn(Self::execute_sse_stream( - sse_stream, - sse_worker_tx.clone(), - true, - transport_task_ct.child_token(), - )); - } else { - let sse_stream = SseAutoReconnectStream::never_reconnect( - stream, - StreamableHttpError::::UnexpectedEndOfStream, - ); - streams.spawn(Self::execute_sse_stream( - sse_stream, - sse_worker_tx.clone(), - true, - transport_task_ct.child_token(), - )); - } + streams.spawn(Self::execute_sse_stream( + Self::raw_sse_to_jsonrpc(stream), + sse_worker_tx.clone(), + true, + transport_task_ct.child_token(), + )); tracing::trace!("got new sse stream"); Ok(()) } diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs index dcaf204c..814f317d 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/local.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -470,7 +470,7 @@ impl LocalSessionWorker { { OutboundChannel::RequestWise { id: *id, - close: false, + close: true, } } else { OutboundChannel::Common @@ -483,7 +483,7 @@ impl LocalSessionWorker { { OutboundChannel::RequestWise { id: *id, - close: false, + close: true, } } else { OutboundChannel::Common @@ -501,7 +501,11 @@ impl LocalSessionWorker { if let Some(request_wise) = self.tx_router.get_mut(&id) { request_wise.tx.send(message).await; if close { - self.tx_router.remove(&id); + if let Some(channel) = self.tx_router.remove(&id) { + for resource in channel.resources { + self.resource_router.remove(&resource); + } + } } } else { return Err(SessionError::ChannelClosed(Some(id))); diff --git 
a/crates/rmcp/tests/test_streamable_http_connection_reuse.rs b/crates/rmcp/tests/test_streamable_http_connection_reuse.rs new file mode 100644 index 00000000..553448ea --- /dev/null +++ b/crates/rmcp/tests/test_streamable_http_connection_reuse.rs @@ -0,0 +1,122 @@ +#![cfg(not(feature = "local"))] + +use std::time::Instant; + +use rmcp::{ + ServerHandler, ServiceExt, + handler::server::{router::tool::ToolRouter, wrapper::Parameters}, + model::{CallToolRequestParams, ClientInfo, ServerCapabilities, ServerInfo}, + schemars, tool, tool_handler, tool_router, + transport::{ + StreamableHttpClientTransport, + streamable_http_client::StreamableHttpClientTransportConfig, + streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, + }, + }, +}; +use tokio_util::sync::CancellationToken; + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct SumRequest { + a: i32, + b: i32, +} + +#[derive(Debug, Clone)] +struct SumServer { + tool_router: ToolRouter, +} + +impl SumServer { + fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } } + +#[tool_router] +impl SumServer { + #[tool(description = "Sum two numbers")] + fn sum(&self, Parameters(SumRequest { a, b }): Parameters) -> String { + (a + b).to_string() + } } + +#[tool_handler(router = self.tool_router)] +impl ServerHandler for SumServer { + fn get_info(&self) -> ServerInfo { + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + } } + +/// Verify that subsequent tool calls do not regress in latency due to +/// failed HTTP/1.1 connection reuse. Before the fix, each POST SSE +/// response was dropped without fully consuming the body, preventing +/// connection reuse and forcing a new TCP connection (~40 ms) per call. 
+#[tokio::test] +async fn test_subsequent_tool_calls_reuse_connections() -> anyhow::Result<()> { + let ct = CancellationToken::new(); + + let service: StreamableHttpService = StreamableHttpService::new( + || Ok(SumServer::new()), + Default::default(), + StreamableHttpServerConfig::default() + .with_sse_keep_alive(None) + .with_cancellation_token(ct.child_token()), + ); + + let router = axum::Router::new().nest_service("/mcp", service); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let server_handle = tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + let transport = StreamableHttpClientTransport::from_config( + StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")), + ); + let client = ClientInfo::default().serve(transport).await?; + + // Warm up: first call may include one-time setup costs. + let args: serde_json::Map = + serde_json::from_value(serde_json::json!({"a": 1, "b": 2}))?; + let _ = client + .call_tool(CallToolRequestParams::new("sum").with_arguments(args)) + .await?; + + // Measure subsequent calls. + let mut durations = Vec::new(); + for i in 0..5i32 { + let args: serde_json::Map = + serde_json::from_value(serde_json::json!({"a": i, "b": i + 1}))?; + let start = Instant::now(); + let result = client + .call_tool(CallToolRequestParams::new("sum").with_arguments(args)) + .await?; + let elapsed = start.elapsed(); + durations.push(elapsed); + + assert!(result.is_error != Some(true)); + } + + let _ = client.cancel().await; + ct.cancel(); + server_handle.await?; + + // With connection reuse, localhost calls should complete well under 20 ms. + // Before the fix, they consistently took ~42 ms due to new TCP connections. 
+ let max_allowed = std::time::Duration::from_millis(20); + for d in &durations { + assert!(*d < max_allowed); + } + + Ok(()) +}