diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index 218669f7..c54cf289 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -1,6 +1,6 @@ import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions, StartSSEOptions } from "./streamableHttp.js"; import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; -import { JSONRPCMessage } from "../types.js"; +import { JSONRPCMessage, JSONRPCRequest } from "../types.js"; import { InvalidClientError, InvalidGrantError, UnauthorizedClientError } from "../server/auth/errors.js"; @@ -594,6 +594,111 @@ describe("StreamableHTTPClientTransport", () => { await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); }); + + describe('Reconnection Logic', () => { + let transport: StreamableHTTPClientTransport; + + // Use fake timers to control setTimeout and make the test instant. + beforeEach(() => jest.useFakeTimers()); + afterEach(() => jest.useRealTimers()); + + it('should reconnect a GET-initiated notification stream that fails', async () => { + // ARRANGE + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + reconnectionOptions: { + initialReconnectionDelay: 10, + maxRetries: 1, + maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely + reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity + } + }); + + const errorSpy = jest.fn(); + transport.onerror = errorSpy; + + const failingStream = new ReadableStream({ + start(controller) { controller.error(new Error("Network failure")); } + }); + + const fetchMock = global.fetch as jest.Mock; + // Mock the initial GET request, which will fail. + fetchMock.mockResolvedValueOnce({ + ok: true, status: 200, + headers: new Headers({ "content-type": "text/event-stream" }), + body: failingStream, + }); + // Mock the reconnection GET request, which will succeed. + fetchMock.mockResolvedValueOnce({ + ok: true, status: 200, + headers: new Headers({ "content-type": "text/event-stream" }), + body: new ReadableStream(), + }); + + // ACT + await transport.start(); + // Trigger the GET stream directly using the internal method for a clean test. + await transport["_startOrAuthSse"]({}); + await jest.advanceTimersByTimeAsync(20); // Trigger reconnection timeout + + // ASSERT + expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({ + message: expect.stringContaining('SSE stream disconnected: Error: Network failure'), + })); + // THE KEY ASSERTION: A second fetch call proves reconnection was attempted. + expect(fetchMock).toHaveBeenCalledTimes(2); + expect(fetchMock.mock.calls[0][1]?.method).toBe('GET'); + expect(fetchMock.mock.calls[1][1]?.method).toBe('GET'); + }); + + it('should NOT reconnect a POST-initiated stream that fails', async () => { + // ARRANGE + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + reconnectionOptions: { + initialReconnectionDelay: 10, + maxRetries: 1, + maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely + reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity + } + }); + + const errorSpy = jest.fn(); + transport.onerror = errorSpy; + + const failingStream = new ReadableStream({ + start(controller) { controller.error(new Error("Network failure")); } + }); + + const fetchMock = global.fetch as jest.Mock; + // Mock the POST request. It returns a streaming content-type but a failing body. + fetchMock.mockResolvedValueOnce({ + ok: true, status: 200, + headers: new Headers({ "content-type": "text/event-stream" }), + body: failingStream, + }); + + // A dummy request message to trigger the `send` logic. + const requestMessage: JSONRPCRequest = { + jsonrpc: '2.0', + method: 'long_running_tool', + id: 'request-1', + params: {}, + }; + + // ACT + await transport.start(); + // Use the public `send` method to initiate a POST that gets a stream response. + await transport.send(requestMessage); + await jest.advanceTimersByTimeAsync(20); // Advance time to check for reconnections + + // ASSERT + expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({ + message: expect.stringContaining('SSE stream disconnected: Error: Network failure'), + })); + // THE KEY ASSERTION: Fetch was only called ONCE. No reconnection was attempted. + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(fetchMock.mock.calls[0][1]?.method).toBe('POST'); + }); + }); it("invalidates all credentials on InvalidClientError during auth", async () => { const message: JSONRPCMessage = { diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index b81f1a5d..b0894fce 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -231,7 +231,7 @@ const response = await (this._fetch ?? fetch)(this._url, { ); } - this._handleSseStream(response.body, options); + this._handleSseStream(response.body, options, true); } catch (error) { this.onerror?.(error as Error); throw error; @@ -300,7 +300,11 @@ const response = await (this._fetch ?? fetch)(this._url, { }, delay); } - private _handleSseStream(stream: ReadableStream | null, options: StartSSEOptions): void { + private _handleSseStream( + stream: ReadableStream | null, + options: StartSSEOptions, + isReconnectable: boolean, + ): void { if (!stream) { return; } @@ -347,20 +351,22 @@ const response = await (this._fetch ?? fetch)(this._url, { this.onerror?.(new Error(`SSE stream disconnected: ${error}`)); // Attempt to reconnect if the stream disconnects unexpectedly and we aren't closing - if (this._abortController && !this._abortController.signal.aborted) { + if ( + isReconnectable && + this._abortController && + !this._abortController.signal.aborted + ) { // Use the exponential backoff reconnection strategy - if (lastEventId !== undefined) { - try { - this._scheduleReconnection({ - resumptionToken: lastEventId, - onresumptiontoken, - replayMessageId - }, 0); - } - catch (error) { - this.onerror?.(new Error(`Failed to reconnect: ${error instanceof Error ? error.message : String(error)}`)); + try { + this._scheduleReconnection({ + resumptionToken: lastEventId, + onresumptiontoken, + replayMessageId + }, 0); + } + catch (error) { + this.onerror?.(new Error(`Failed to reconnect: ${error instanceof Error ? error.message : String(error)}`)); - } } } } @@ -473,7 +479,7 @@ const response = await (this._fetch ?? fetch)(this._url, init); // Handle SSE stream responses for requests // We use the same handler as standalone streams, which now supports // reconnection with the last event ID - this._handleSseStream(response.body, { onresumptiontoken }); + this._handleSseStream(response.body, { onresumptiontoken }, false); } else if (contentType?.includes("application/json")) { // For non-streaming servers, we might get direct JSON responses const data = await response.json();