Skip to content

Commit 9e50aa4

Browse files
refactor: Replace mock-based test with real transport cleanup test
The previous test was self-fulfilling as it only verified that a mocked method was called. The new test uses real StreamableHTTPServerTransport instances and verifies that: - The transport's _terminated flag is set to True - The _request_streams dictionary is empty - The test fails when _terminate_session() is not called This provides a more robust regression test for PR #1116's memory leak fix. Github-Issue:#1116
1 parent 6804566 commit 9e50aa4

File tree

1 file changed

+32
-49
lines changed

1 file changed

+32
-49
lines changed

tests/server/test_streamable_http_manager.py

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Tests for StreamableHTTPSessionManager."""
22

3-
from unittest.mock import AsyncMock
3+
from unittest.mock import AsyncMock, patch
44

55
import anyio
66
import pytest
77

8+
from mcp.server import streamable_http_manager
89
from mcp.server.lowlevel import Server
910
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER
1011
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
@@ -200,49 +201,29 @@ async def mock_receive():
200201

201202

202203
@pytest.mark.anyio
203-
async def test_stateless_requests_cleanup_transport_resources():
204-
"""Test that stateless requests properly clean up transport resources."""
205-
from contextlib import asynccontextmanager
206-
from unittest.mock import patch
207-
208-
app = Server("test-stateless-cleanup-server")
204+
async def test_stateless_requests_memory_cleanup():
205+
"""Test that stateless requests actually clean up resources using real transports."""
206+
app = Server("test-stateless-real-cleanup")
209207
manager = StreamableHTTPSessionManager(app=app, stateless=True)
210208

211-
# Track created transports and their termination
209+
# Track created transport instances
212210
created_transports = []
213211

214-
# Mock the transport class to track termination
215-
with patch("mcp.server.streamable_http_manager.StreamableHTTPServerTransport") as MockTransport:
216-
# Create a mock transport instance
217-
def create_transport(*args, **kwargs):
218-
transport = AsyncMock()
219-
transport._terminated = False
220-
221-
# Track when terminate is called
222-
async def mock_terminate():
223-
transport._terminated = True
224-
225-
transport._terminate_session = mock_terminate
226-
transport.handle_request = AsyncMock()
227-
228-
# Mock the connect context manager
229-
@asynccontextmanager
230-
async def mock_connect():
231-
yield (AsyncMock(), AsyncMock())
232-
233-
transport.connect = mock_connect
212+
# Patch StreamableHTTPServerTransport constructor to track instances
234213

235-
created_transports.append(transport)
236-
return transport
214+
original_constructor = streamable_http_manager.StreamableHTTPServerTransport
237215

238-
MockTransport.side_effect = create_transport
216+
def track_transport(*args, **kwargs):
217+
transport = original_constructor(*args, **kwargs)
218+
created_transports.append(transport)
219+
return transport
239220

221+
with patch.object(streamable_http_manager, "StreamableHTTPServerTransport", side_effect=track_transport):
240222
async with manager.run():
241-
# Mock app.run to return quickly
242-
mock_mcp_run = AsyncMock(return_value=None)
243-
app.run = mock_mcp_run
223+
# Mock app.run to complete immediately
224+
app.run = AsyncMock(return_value=None)
244225

245-
# Send a stateless request
226+
# Send a simple request
246227
sent_messages = []
247228

248229
async def mock_send(message):
@@ -258,25 +239,27 @@ async def mock_send(message):
258239
],
259240
}
260241

242+
# Empty body to trigger early return
261243
async def mock_receive():
262244
return {
263245
"type": "http.request",
264-
"body": b'{"jsonrpc": "2.0", "method": "test", "id": 1}',
246+
"body": b"",
265247
"more_body": False,
266248
}
267249

268-
# Send multiple requests
269-
num_requests = 3
270-
for _ in range(num_requests):
271-
await manager.handle_request(scope, mock_receive, mock_send)
272-
# Give async tasks time to complete
273-
await anyio.sleep(0.1)
250+
# Send a request
251+
await manager.handle_request(scope, mock_receive, mock_send)
252+
253+
# Give async tasks time to complete
254+
await anyio.sleep(0.1)
255+
256+
# Verify transport was created
257+
assert len(created_transports) == 1, "Should have created one transport"
258+
259+
transport = created_transports[0]
274260

275-
# Verify each transport was created
276-
assert len(created_transports) == num_requests, (
277-
f"Expected {num_requests} transports, got {len(created_transports)}"
278-
)
261+
# The key assertion - transport should be terminated
262+
assert transport._terminated, "Transport should be terminated after stateless request"
279263

280-
# This is the key assertion - each transport should have been terminated
281-
for i, transport in enumerate(created_transports):
282-
assert transport._terminated, f"Transport {i} was not terminated"
264+
# Verify internal state is cleaned up
265+
assert len(transport._request_streams) == 0, "Transport should have no active request streams"

0 commit comments

Comments
 (0)