Skip to content

Commit 063bbec

Browse files
committed
feature(auth): DelegatedAuthClientProvider
An optional provider that can be passed to the SSE and StreamableHttp client transports in order to completely delegate the authentication to an external system.
1 parent 2db0dbe commit 063bbec

File tree

5 files changed

+647
-79
lines changed

5 files changed

+647
-79
lines changed

src/client/auth.ts

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,42 @@ export interface OAuthClientProvider {
127127
invalidateCredentials?(scope: 'all' | 'client' | 'tokens' | 'verifier'): void | Promise<void>;
128128
}
129129

130+
/**
131+
* A provider that delegates authentication to an external system.
132+
*
133+
* This interface allows for custom authentication mechanisms that are
134+
* either already implemented on a specific platform or handled outside the
135+
* standard OAuth flow, such as API keys, custom tokens, or integration with external
136+
* authentication services.
137+
*/
138+
export interface DelegatedAuthClientProvider {
139+
/**
140+
* Returns authentication headers to be included in requests.
141+
*
142+
* These headers will be added to all HTTP requests made by the transport.
143+
* Common examples include Authorization headers, API keys, or custom
144+
* authentication tokens.
145+
*
146+
* @returns Headers to include in requests, or undefined if no authentication is available
147+
*/
148+
headers(): HeadersInit | undefined | Promise<HeadersInit | undefined>;
149+
150+
/**
151+
* Performs authentication when a 401 Unauthorized response is received.
152+
*
153+
* This method is called when the server responds with a 401 status code,
154+
* indicating that the current authentication is invalid or expired.
155+
* The implementation should attempt to refresh or re-establish authentication.
156+
*
157+
* @param context Authentication context providing server and resource information
158+
* @param context.serverUrl The URL of the MCP server being authenticated against
159+
* @param context.resourceMetadataUrl Optional URL for resource metadata, if available
160+
* @returns Promise that resolves to true if authentication was successful,
161+
* false if authentication failed
162+
*/
163+
authorize(context: { serverUrl: string | URL; resourceMetadataUrl?: URL }): boolean | Promise<boolean>;
164+
}
165+
130166
export type AuthResult = "AUTHORIZED" | "REDIRECT";
131167

132168
export class UnauthorizedError extends Error {

src/client/sse.test.ts

Lines changed: 219 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { createServer, ServerResponse, type IncomingMessage, type Server } from
22
import { AddressInfo } from "net";
33
import { JSONRPCMessage } from "../types.js";
44
import { SSEClientTransport } from "./sse.js";
5-
import { OAuthClientProvider, UnauthorizedError } from "./auth.js";
5+
import { DelegatedAuthClientProvider, OAuthClientProvider, UnauthorizedError } from "./auth.js";
66
import { OAuthTokens } from "../shared/auth.js";
77
import { InvalidClientError, InvalidGrantError, UnauthorizedClientError } from "../server/auth/errors.js";
88

@@ -1140,11 +1140,11 @@ describe("SSEClientTransport", () => {
11401140

11411141
return {
11421142
get redirectUrl() { return "http://localhost/callback"; },
1143-
get clientMetadata() {
1144-
return {
1143+
get clientMetadata() {
1144+
return {
11451145
redirect_uris: ["http://localhost/callback"],
11461146
client_name: "Test Client"
1147-
};
1147+
};
11481148
},
11491149
clientInformation: jest.fn().mockResolvedValue(clientInfo),
11501150
tokens: jest.fn().mockResolvedValue(tokens),
@@ -1170,7 +1170,7 @@ describe("SSEClientTransport", () => {
11701170
}));
11711171
return;
11721172
}
1173-
1173+
11741174
if (req.url === "/token" && req.method === "POST") {
11751175
// Handle token exchange request
11761176
let body = "";
@@ -1193,7 +1193,7 @@ describe("SSEClientTransport", () => {
11931193
});
11941194
return;
11951195
}
1196-
1196+
11971197
res.writeHead(404).end();
11981198
});
11991199

@@ -1297,14 +1297,14 @@ describe("SSEClientTransport", () => {
12971297

12981298
// Verify custom fetch was used
12991299
expect(customFetch).toHaveBeenCalled();
1300-
1300+
13011301
// Verify specific OAuth endpoints were called with custom fetch
13021302
const customFetchCalls = customFetch.mock.calls;
13031303
const callUrls = customFetchCalls.map(([url]) => url.toString());
1304-
1304+
13051305
// Should have called resource metadata discovery
13061306
expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true);
1307-
1307+
13081308
// Should have called OAuth authorization server metadata discovery
13091309
expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true);
13101310

@@ -1370,19 +1370,19 @@ describe("SSEClientTransport", () => {
13701370

13711371
// Verify custom fetch was used
13721372
expect(customFetch).toHaveBeenCalled();
1373-
1373+
13741374
// Verify specific OAuth endpoints were called with custom fetch
13751375
const customFetchCalls = customFetch.mock.calls;
13761376
const callUrls = customFetchCalls.map(([url]) => url.toString());
1377-
1377+
13781378
// Should have called resource metadata discovery
13791379
expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true);
1380-
1380+
13811381
// Should have called OAuth authorization server metadata discovery
13821382
expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true);
13831383

13841384
// Should have attempted the POST request that triggered the 401
1385-
const postCalls = customFetchCalls.filter(([url, options]) =>
1385+
const postCalls = customFetchCalls.filter(([url, options]) =>
13861386
url.toString() === resourceBaseUrl.href && options?.method === "POST"
13871387
);
13881388
expect(postCalls.length).toBeGreaterThan(0);
@@ -1412,19 +1412,19 @@ describe("SSEClientTransport", () => {
14121412

14131413
// Verify custom fetch was used
14141414
expect(customFetch).toHaveBeenCalled();
1415-
1415+
14161416
// Verify specific OAuth endpoints were called with custom fetch
14171417
const customFetchCalls = customFetch.mock.calls;
14181418
const callUrls = customFetchCalls.map(([url]) => url.toString());
1419-
1419+
14201420
// Should have called resource metadata discovery
14211421
expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true);
1422-
1422+
14231423
// Should have called OAuth authorization server metadata discovery
14241424
expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true);
14251425

14261426
// Should have called token endpoint for authorization code exchange
1427-
const tokenCalls = customFetchCalls.filter(([url, options]) =>
1427+
const tokenCalls = customFetchCalls.filter(([url, options]) =>
14281428
url.toString().includes('/token') && options?.method === "POST"
14291429
);
14301430
expect(tokenCalls.length).toBeGreaterThan(0);
@@ -1441,4 +1441,206 @@ describe("SSEClientTransport", () => {
14411441
expect(globalFetchSpy).not.toHaveBeenCalled();
14421442
});
14431443
});
1444+
1445+
describe("delegated authentication", () => {
1446+
let mockDelegatedAuthProvider: jest.Mocked<DelegatedAuthClientProvider>;
1447+
1448+
beforeEach(() => {
1449+
mockDelegatedAuthProvider = {
1450+
headers: jest.fn(),
1451+
authorize: jest.fn(),
1452+
};
1453+
});
1454+
1455+
it("includes delegated auth headers in requests", async () => {
1456+
mockDelegatedAuthProvider.headers.mockResolvedValue({
1457+
"Authorization": "Bearer delegated-token",
1458+
"X-API-Key": "api-key-123"
1459+
});
1460+
1461+
transport = new SSEClientTransport(resourceBaseUrl, {
1462+
delegatedAuthProvider: mockDelegatedAuthProvider,
1463+
});
1464+
1465+
await transport.start();
1466+
1467+
expect(lastServerRequest.headers.authorization).toBe("Bearer delegated-token");
1468+
expect(lastServerRequest.headers["x-api-key"]).toBe("api-key-123");
1469+
});
1470+
1471+
it("takes precedence over OAuth provider", async () => {
1472+
const mockOAuthProvider = {
1473+
get redirectUrl() { return "http://localhost/callback"; },
1474+
get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; },
1475+
clientInformation: jest.fn(() => ({ client_id: "oauth-client", client_secret: "oauth-secret" })),
1476+
tokens: jest.fn(() => Promise.resolve({ access_token: "oauth-token", token_type: "Bearer" })),
1477+
saveTokens: jest.fn(),
1478+
redirectToAuthorization: jest.fn(),
1479+
saveCodeVerifier: jest.fn(),
1480+
codeVerifier: jest.fn(),
1481+
};
1482+
1483+
mockDelegatedAuthProvider.headers.mockResolvedValue({
1484+
"Authorization": "Bearer delegated-token"
1485+
});
1486+
1487+
transport = new SSEClientTransport(resourceBaseUrl, {
1488+
authProvider: mockOAuthProvider,
1489+
delegatedAuthProvider: mockDelegatedAuthProvider,
1490+
});
1491+
1492+
await transport.start();
1493+
1494+
expect(lastServerRequest.headers.authorization).toBe("Bearer delegated-token");
1495+
expect(mockOAuthProvider.tokens).not.toHaveBeenCalled();
1496+
});
1497+
1498+
it("handles 401 during SSE connection with successful reauth", async () => {
1499+
mockDelegatedAuthProvider.headers.mockResolvedValueOnce(undefined);
1500+
mockDelegatedAuthProvider.authorize.mockResolvedValue(true);
1501+
mockDelegatedAuthProvider.headers.mockResolvedValueOnce({
1502+
"Authorization": "Bearer new-delegated-token"
1503+
});
1504+
1505+
// Create server that returns 401 on first attempt, 200 on second
1506+
resourceServer.close();
1507+
1508+
let attemptCount = 0;
1509+
resourceServer = createServer((req, res) => {
1510+
lastServerRequest = req;
1511+
attemptCount++;
1512+
1513+
if (attemptCount === 1) {
1514+
res.writeHead(401).end();
1515+
return;
1516+
}
1517+
1518+
res.writeHead(200, {
1519+
"Content-Type": "text/event-stream",
1520+
"Cache-Control": "no-cache, no-transform",
1521+
Connection: "keep-alive",
1522+
});
1523+
res.write("event: endpoint\n");
1524+
res.write(`data: ${resourceBaseUrl.href}\n\n`);
1525+
});
1526+
1527+
await new Promise<void>((resolve) => {
1528+
resourceServer.listen(0, "127.0.0.1", () => {
1529+
const addr = resourceServer.address() as AddressInfo;
1530+
resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`);
1531+
resolve();
1532+
});
1533+
});
1534+
1535+
transport = new SSEClientTransport(resourceBaseUrl, {
1536+
delegatedAuthProvider: mockDelegatedAuthProvider,
1537+
});
1538+
1539+
await transport.start();
1540+
1541+
expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledTimes(1);
1542+
expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledWith({
1543+
serverUrl: resourceBaseUrl,
1544+
resourceMetadataUrl: undefined
1545+
});
1546+
expect(attemptCount).toBe(2);
1547+
});
1548+
1549+
it("throws UnauthorizedError when reauth fails", async () => {
1550+
mockDelegatedAuthProvider.headers.mockResolvedValue(undefined);
1551+
mockDelegatedAuthProvider.authorize.mockResolvedValue(false);
1552+
1553+
// Create server that always returns 401
1554+
resourceServer.close();
1555+
1556+
resourceServer = createServer((req, res) => {
1557+
res.writeHead(401).end();
1558+
});
1559+
1560+
await new Promise<void>((resolve) => {
1561+
resourceServer.listen(0, "127.0.0.1", () => {
1562+
const addr = resourceServer.address() as AddressInfo;
1563+
resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`);
1564+
resolve();
1565+
});
1566+
});
1567+
1568+
transport = new SSEClientTransport(resourceBaseUrl, {
1569+
delegatedAuthProvider: mockDelegatedAuthProvider,
1570+
});
1571+
1572+
await expect(transport.start()).rejects.toThrow(UnauthorizedError);
1573+
expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledTimes(1);
1574+
expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledWith({
1575+
serverUrl: resourceBaseUrl,
1576+
resourceMetadataUrl: undefined
1577+
});
1578+
});
1579+
1580+
it("handles 401 during POST request with successful reauth", async () => {
1581+
mockDelegatedAuthProvider.headers.mockResolvedValue({
1582+
"Authorization": "Bearer delegated-token"
1583+
});
1584+
mockDelegatedAuthProvider.authorize.mockResolvedValue(true);
1585+
1586+
// Create server that accepts SSE but returns 401 on first POST, 200 on second
1587+
resourceServer.close();
1588+
1589+
let postAttempts = 0;
1590+
resourceServer = createServer((req, res) => {
1591+
lastServerRequest = req;
1592+
1593+
switch (req.method) {
1594+
case "GET":
1595+
res.writeHead(200, {
1596+
"Content-Type": "text/event-stream",
1597+
"Cache-Control": "no-cache, no-transform",
1598+
Connection: "keep-alive",
1599+
});
1600+
res.write("event: endpoint\n");
1601+
res.write(`data: ${resourceBaseUrl.href}\n\n`);
1602+
break;
1603+
1604+
case "POST":
1605+
postAttempts++;
1606+
if (postAttempts === 1) {
1607+
res.writeHead(401).end();
1608+
} else {
1609+
res.writeHead(200).end();
1610+
}
1611+
break;
1612+
}
1613+
});
1614+
1615+
await new Promise<void>((resolve) => {
1616+
resourceServer.listen(0, "127.0.0.1", () => {
1617+
const addr = resourceServer.address() as AddressInfo;
1618+
resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`);
1619+
resolve();
1620+
});
1621+
});
1622+
1623+
transport = new SSEClientTransport(resourceBaseUrl, {
1624+
delegatedAuthProvider: mockDelegatedAuthProvider,
1625+
});
1626+
1627+
await transport.start();
1628+
1629+
const message: JSONRPCMessage = {
1630+
jsonrpc: "2.0",
1631+
id: "1",
1632+
method: "test",
1633+
params: {},
1634+
};
1635+
1636+
await transport.send(message);
1637+
1638+
expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledTimes(1);
1639+
expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledWith({
1640+
serverUrl: resourceBaseUrl,
1641+
resourceMetadataUrl: undefined
1642+
});
1643+
expect(postAttempts).toBe(2);
1644+
});
1645+
});
14441646
});

0 commit comments

Comments
 (0)