1
1
"""Tests for StreamableHTTPSessionManager."""
2
2
3
- from unittest .mock import AsyncMock
3
+ from unittest .mock import AsyncMock , patch
4
4
5
5
import anyio
6
6
import pytest
7
7
8
+ from mcp .server import streamable_http_manager
8
9
from mcp .server .lowlevel import Server
9
10
from mcp .server .streamable_http import MCP_SESSION_ID_HEADER
10
11
from mcp .server .streamable_http_manager import StreamableHTTPSessionManager
@@ -200,49 +201,29 @@ async def mock_receive():
200
201
201
202
202
203
@pytest .mark .anyio
203
- async def test_stateless_requests_cleanup_transport_resources ():
204
- """Test that stateless requests properly clean up transport resources."""
205
- from contextlib import asynccontextmanager
206
- from unittest .mock import patch
207
-
208
- app = Server ("test-stateless-cleanup-server" )
204
+ async def test_stateless_requests_memory_cleanup ():
205
+ """Test that stateless requests actually clean up resources using real transports."""
206
+ app = Server ("test-stateless-real-cleanup" )
209
207
manager = StreamableHTTPSessionManager (app = app , stateless = True )
210
208
211
- # Track created transports and their termination
209
+ # Track created transport instances
212
210
created_transports = []
213
211
214
- # Mock the transport class to track termination
215
- with patch ("mcp.server.streamable_http_manager.StreamableHTTPServerTransport" ) as MockTransport :
216
- # Create a mock transport instance
217
- def create_transport (* args , ** kwargs ):
218
- transport = AsyncMock ()
219
- transport ._terminated = False
220
-
221
- # Track when terminate is called
222
- async def mock_terminate ():
223
- transport ._terminated = True
224
-
225
- transport ._terminate_session = mock_terminate
226
- transport .handle_request = AsyncMock ()
227
-
228
- # Mock the connect context manager
229
- @asynccontextmanager
230
- async def mock_connect ():
231
- yield (AsyncMock (), AsyncMock ())
232
-
233
- transport .connect = mock_connect
212
+ # Patch StreamableHTTPServerTransport constructor to track instances
234
213
235
- created_transports .append (transport )
236
- return transport
214
+ original_constructor = streamable_http_manager .StreamableHTTPServerTransport
237
215
238
- MockTransport .side_effect = create_transport
216
+ def track_transport (* args , ** kwargs ):
217
+ transport = original_constructor (* args , ** kwargs )
218
+ created_transports .append (transport )
219
+ return transport
239
220
221
+ with patch .object (streamable_http_manager , "StreamableHTTPServerTransport" , side_effect = track_transport ):
240
222
async with manager .run ():
241
- # Mock app.run to return quickly
242
- mock_mcp_run = AsyncMock (return_value = None )
243
- app .run = mock_mcp_run
223
+ # Mock app.run to complete immediately
224
+ app .run = AsyncMock (return_value = None )
244
225
245
- # Send a stateless request
226
+ # Send a simple request
246
227
sent_messages = []
247
228
248
229
async def mock_send (message ):
@@ -258,25 +239,27 @@ async def mock_send(message):
258
239
],
259
240
}
260
241
242
+ # Empty body to trigger early return
261
243
async def mock_receive ():
262
244
return {
263
245
"type" : "http.request" ,
264
- "body" : b'{"jsonrpc": "2.0", "method": "test", "id": 1}' ,
246
+ "body" : b"" ,
265
247
"more_body" : False ,
266
248
}
267
249
268
- # Send multiple requests
269
- num_requests = 3
270
- for _ in range (num_requests ):
271
- await manager .handle_request (scope , mock_receive , mock_send )
272
- # Give async tasks time to complete
273
- await anyio .sleep (0.1 )
250
+ # Send a request
251
+ await manager .handle_request (scope , mock_receive , mock_send )
252
+
253
+ # Give async tasks time to complete
254
+ await anyio .sleep (0.1 )
255
+
256
+ # Verify transport was created
257
+ assert len (created_transports ) == 1 , "Should have created one transport"
258
+
259
+ transport = created_transports [0 ]
274
260
275
- # Verify each transport was created
276
- assert len (created_transports ) == num_requests , (
277
- f"Expected { num_requests } transports, got { len (created_transports )} "
278
- )
261
+ # The key assertion - transport should be terminated
262
+ assert transport ._terminated , "Transport should be terminated after stateless request"
279
263
280
- # This is the key assertion - each transport should have been terminated
281
- for i , transport in enumerate (created_transports ):
282
- assert transport ._terminated , f"Transport { i } was not terminated"
264
+ # Verify internal state is cleaned up
265
+ assert len (transport ._request_streams ) == 0 , "Transport should have no active request streams"
0 commit comments