diff --git a/README.md b/README.md index d76d3d267..c2ff39f33 100644 --- a/README.md +++ b/README.md @@ -814,7 +814,7 @@ async def main(): The SDK includes [authorization support](https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization) for connecting to protected MCP servers: ```python -from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.auth import OAuthClientProvider, ClientCredentialsProvider, TokenStorage from mcp.client.session import ClientSession from mcp.client.streamable_http import streamablehttp_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken @@ -851,6 +851,9 @@ async def main(): callback_handler=lambda: ("auth_code", None), ) + # For machine-to-machine scenarios, use ClientCredentialsProvider + # instead of OAuthClientProvider. + # Use with streamable HTTP client async with streamablehttp_client( "https://api.example.com/mcp", auth=oauth_auth diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index fc6c96a43..ead270e55 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -499,3 +499,207 @@ async def _refresh_access_token(self) -> bool: except Exception: logger.exception("Token refresh failed") return False + + +class ClientCredentialsProvider(httpx.Auth): + """HTTPX auth using the OAuth2 client credentials grant.""" + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + timeout: float = 300.0, + ): + self.server_url = server_url + self.client_metadata = client_metadata + self.storage = storage + self.timeout = timeout + + self._current_tokens: OAuthToken | None = None + self._metadata: OAuthMetadata | None = None + self._client_info: OAuthClientInformationFull | None = None + self._token_expiry_time: float | None = None + + self._token_lock = anyio.Lock() + + def _get_authorization_base_url(self, server_url: str) -> str: + from urllib.parse import urlparse, urlunparse + + parsed = urlparse(server_url) + return urlunparse((parsed.scheme, parsed.netloc, "", "", "", "")) + + async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None: + auth_base_url = self._get_authorization_base_url(server_url) + url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server") + headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION} + + async with httpx.AsyncClient() as client: + try: + response = await client.get(url, headers=headers) + if response.status_code == 404: + return None + response.raise_for_status() + return OAuthMetadata.model_validate(response.json()) + except Exception: + try: + response = await client.get(url) + if response.status_code == 404: + return None + response.raise_for_status() + return OAuthMetadata.model_validate(response.json()) + except Exception: + logger.exception("Failed to discover OAuth metadata") + return None + + async def _register_oauth_client( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + metadata: OAuthMetadata | None = None, + ) -> OAuthClientInformationFull: + if not metadata: + metadata = await self._discover_oauth_metadata(server_url) + + if metadata and metadata.registration_endpoint: + registration_url = str(metadata.registration_endpoint) + else: + auth_base_url = self._get_authorization_base_url(server_url) + registration_url = urljoin(auth_base_url, "/register") + + if ( + client_metadata.scope is None + and metadata + and metadata.scopes_supported is not None + ): + client_metadata.scope = " ".join(metadata.scopes_supported) + + registration_data = client_metadata.model_dump( + by_alias=True, mode="json", exclude_none=True + ) + + async with httpx.AsyncClient() as client: + response = await client.post( + registration_url, + json=registration_data, + headers={"Content-Type": "application/json"}, + ) + + if response.status_code not in (200, 201): + raise httpx.HTTPStatusError( + f"Registration failed: {response.status_code}", + request=response.request, + response=response, + ) + + return OAuthClientInformationFull.model_validate(response.json()) + + def _has_valid_token(self) -> bool: + if not self._current_tokens or not self._current_tokens.access_token: + return False + + if self._token_expiry_time and time.time() > self._token_expiry_time: + return False + return True + + async def _validate_token_scopes(self, token_response: OAuthToken) -> None: + if not token_response.scope: + return + + requested_scopes: set[str] = set() + if self.client_metadata.scope: + requested_scopes = set(self.client_metadata.scope.split()) + returned_scopes = set(token_response.scope.split()) + unauthorized_scopes = returned_scopes - requested_scopes + if unauthorized_scopes: + raise Exception( + f"Server granted unauthorized scopes: {unauthorized_scopes}." + ) + else: + granted = set(token_response.scope.split()) + logger.debug( + "No explicit scopes requested, accepting server-granted scopes: %s", + granted, + ) + + async def initialize(self) -> None: + self._current_tokens = await self.storage.get_tokens() + self._client_info = await self.storage.get_client_info() + + async def _get_or_register_client(self) -> OAuthClientInformationFull: + if not self._client_info: + self._client_info = await self._register_oauth_client( + self.server_url, self.client_metadata, self._metadata + ) + await self.storage.set_client_info(self._client_info) + return self._client_info + + async def _request_token(self) -> None: + if not self._metadata: + self._metadata = await self._discover_oauth_metadata(self.server_url) + + client_info = await self._get_or_register_client() + + if self._metadata and self._metadata.token_endpoint: + token_url = str(self._metadata.token_endpoint) + else: + auth_base_url = self._get_authorization_base_url(self.server_url) + token_url = urljoin(auth_base_url, "/token") + + token_data = { + "grant_type": "client_credentials", + "client_id": client_info.client_id, + } + + if client_info.client_secret: + token_data["client_secret"] = client_info.client_secret + + if self.client_metadata.scope: + token_data["scope"] = self.client_metadata.scope + + async with httpx.AsyncClient() as client: + response = await client.post( + token_url, + data=token_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, + ) + + if response.status_code != 200: + raise Exception( + f"Token request failed: {response.status_code} {response.text}" + ) + + token_response = OAuthToken.model_validate(response.json()) + await self._validate_token_scopes(token_response) + + if token_response.expires_in: + self._token_expiry_time = time.time() + token_response.expires_in + else: + self._token_expiry_time = None + + await self.storage.set_tokens(token_response) + self._current_tokens = token_response + + async def ensure_token(self) -> None: + async with self._token_lock: + if self._has_valid_token(): + return + await self._request_token() + + async def async_auth_flow( + self, request: httpx.Request + ) -> AsyncGenerator[httpx.Request, httpx.Response]: + if not self._has_valid_token(): + await self.initialize() + await self.ensure_token() + + if self._current_tokens and self._current_tokens.access_token: + request.headers["Authorization"] = ( + f"Bearer {self._current_tokens.access_token}" + ) + + response = yield request + + if response.status_code == 401: + self._current_tokens = None diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 94a5c4de3..0005b38a1 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -47,16 +47,25 @@ class RefreshTokenRequest(BaseModel): client_secret: str | None = None +class ClientCredentialsRequest(BaseModel): + """Token request for the client credentials grant.""" + + grant_type: Literal["client_credentials"] + scope: str | None = Field(None, description="Optional scope parameter") + client_id: str + client_secret: str | None = None + + class TokenRequest( RootModel[ Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, + AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest, Field(discriminator="grant_type"), ] ] ): root: Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, + AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest, Field(discriminator="grant_type"), ] @@ -204,6 +213,26 @@ async def handle(self, request: Request): ) ) + case ClientCredentialsRequest(): + scopes = ( + token_request.scope.split(" ") + if token_request.scope + else client_info.scope.split(" ") + if client_info.scope + else [] + ) + try: + tokens = await self.provider.exchange_client_credentials( + client_info, scopes + ) + except TokenError as e: + return self.response( + TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + ) + case RefreshTokenRequest(): refresh_token = await self.provider.load_refresh_token( client_info, token_request.refresh_token diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index be1ac1dbc..86d445086 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -247,6 +247,12 @@ async def exchange_refresh_token( """ ... + async def exchange_client_credentials( + self, client: OAuthClientInformationFull, scopes: list[str] + ) -> OAuthToken: + """Exchange client credentials for an access token.""" + ... + async def load_access_token(self, token: str) -> AccessTokenT | None: """ Loads an access token by its token. diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index d588d78ee..4809029ac 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -164,7 +164,11 @@ def build_metadata( scopes_supported=client_registration_options.valid_scopes, response_types_supported=["code"], response_modes_supported=None, - grant_types_supported=["authorization_code", "refresh_token"], + grant_types_supported=[ + "authorization_code", + "refresh_token", + "client_credentials", + ], token_endpoint_auth_methods_supported=["client_secret_post"], token_endpoint_auth_signing_alg_values_supported=None, service_documentation=service_documentation_url, diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 22f8a971d..90835bb2d 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -39,8 +39,10 @@ class OAuthClientMetadata(BaseModel): token_endpoint_auth_method: Literal["none", "client_secret_post"] = ( "client_secret_post" ) - # grant_types: this implementation only supports authorization_code & refresh_token - grant_types: list[Literal["authorization_code", "refresh_token"]] = [ + # grant_types: support authorization_code, refresh_token, client_credentials + grant_types: list[ + Literal["authorization_code", "refresh_token", "client_credentials"] + ] = [ "authorization_code", "refresh_token", ] @@ -114,7 +116,14 @@ class OAuthMetadata(BaseModel): response_types_supported: list[Literal["code"]] = ["code"] response_modes_supported: list[Literal["query", "fragment"]] | None = None grant_types_supported: ( - list[Literal["authorization_code", "refresh_token"]] | None + list[ + Literal[ + "authorization_code", + "refresh_token", + "client_credentials", + ] + ] + | None ) = None token_endpoint_auth_methods_supported: ( list[Literal["none", "client_secret_post"]] | None diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 2edaff946..f41dddb61 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -13,7 +13,7 @@ from inline_snapshot import snapshot from pydantic import AnyHttpUrl -from mcp.client.auth import OAuthClientProvider +from mcp.client.auth import ClientCredentialsProvider, OAuthClientProvider from mcp.server.auth.routes import build_metadata from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions from mcp.shared.auth import ( @@ -60,6 +60,18 @@ def client_metadata(): ) +@pytest.fixture +def client_credentials_metadata(): + return OAuthClientMetadata( + redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")], + client_name="CC Client", + grant_types=["client_credentials"], + response_types=["code"], + scope="read write", + token_endpoint_auth_method="client_secret_post", + ) + + @pytest.fixture def oauth_metadata(): return OAuthMetadata( @@ -69,7 +81,11 @@ def oauth_metadata(): registration_endpoint=AnyHttpUrl("https://auth.example.com/register"), scopes_supported=["read", "write", "admin"], response_types_supported=["code"], - grant_types_supported=["authorization_code", "refresh_token"], + grant_types_supported=[ + "authorization_code", + "refresh_token", + "client_credentials", + ], code_challenge_methods_supported=["S256"], ) @@ -115,6 +131,14 @@ async def mock_callback_handler() -> tuple[str, str | None]: ) +@pytest.fixture +async def client_credentials_provider(client_credentials_metadata, mock_storage): + return ClientCredentialsProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_credentials_metadata, + storage=mock_storage, + ) + class TestOAuthClientProvider: """Test OAuth client provider functionality.""" @@ -975,7 +999,11 @@ def test_build_metadata( token_endpoint=AnyHttpUrl(token_endpoint), registration_endpoint=AnyHttpUrl(registration_endpoint), scopes_supported=["read", "write", "admin"], - grant_types_supported=["authorization_code", "refresh_token"], + grant_types_supported=[ + "authorization_code", + "refresh_token", + "client_credentials", + ], token_endpoint_auth_methods_supported=["client_secret_post"], service_documentation=AnyHttpUrl(service_documentation_url), revocation_endpoint=AnyHttpUrl(revocation_endpoint), @@ -983,3 +1011,56 @@ def test_build_metadata( code_challenge_methods_supported=["S256"], ) ) + + +class TestClientCredentialsProvider: + @pytest.mark.anyio + async def test_request_token_success( + self, + client_credentials_provider, + oauth_metadata, + oauth_client_info, + oauth_token, + ): + client_credentials_provider._metadata = oauth_metadata + client_credentials_provider._client_info = oauth_client_info + + token_json = oauth_token.model_dump(by_alias=True, mode="json") + token_json.pop("refresh_token", None) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = token_json + mock_client.post.return_value = mock_response + + await client_credentials_provider.ensure_token() + + mock_client.post.assert_called_once() + assert ( + client_credentials_provider._current_tokens.access_token + == oauth_token.access_token + ) + + @pytest.mark.anyio + async def test_async_auth_flow(self, client_credentials_provider, oauth_token): + client_credentials_provider._current_tokens = oauth_token + client_credentials_provider._token_expiry_time = time.time() + 3600 + + request = httpx.Request("GET", "https://api.example.com/data") + mock_response = Mock() + mock_response.status_code = 200 + + auth_flow = client_credentials_provider.async_auth_flow(request) + updated_request = await auth_flow.__anext__() + assert ( + updated_request.headers["Authorization"] + == f"Bearer {oauth_token.access_token}" + ) + try: + await auth_flow.asend(mock_response) + except StopAsyncIteration: + pass diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index d237e860e..a22662045 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -166,6 +166,23 @@ async def exchange_refresh_token( refresh_token=new_refresh_token, ) + async def exchange_client_credentials( + self, client: OAuthClientInformationFull, scopes: list[str] + ) -> OAuthToken: + access_token = f"access_{secrets.token_hex(32)}" + self.tokens[access_token] = AccessToken( + token=access_token, + client_id=client.client_id, + scopes=scopes, + expires_at=int(time.time()) + 3600, + ) + return OAuthToken( + access_token=access_token, + token_type="bearer", + expires_in=3600, + scope=" ".join(scopes), + ) + async def load_access_token(self, token: str) -> AccessToken | None: token_info = self.tokens.get(token) @@ -370,6 +387,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): assert metadata["grant_types_supported"] == [ "authorization_code", "refresh_token", + "client_credentials", ] assert metadata["service_documentation"] == "https://docs.example.com/" @@ -1265,3 +1283,25 @@ async def test_authorize_invalid_scope( # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["client_credentials"]}], + indirect=True, + ) + async def test_client_credentials_token( + self, test_client: httpx.AsyncClient, registered_client + ): + response = await test_client.post( + "/token", + data={ + "grant_type": "client_credentials", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "scope": "read write", + }, + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data