diff --git a/README.md b/README.md index c5fb473ca..6da12e6fc 100644 --- a/README.md +++ b/README.md @@ -1452,6 +1452,8 @@ async def main(): storage=CustomTokenStorage(), redirect_handler=lambda url: print(f"Visit: {url}"), callback_handler=lambda: ("auth_code", None), + # Optional: Initial access token for RFC 7591 Dynamic Client Registration + initial_access_token="your-initial-access-token", ) # Use with streamable HTTP client @@ -1465,6 +1467,42 @@ async def main(): For a complete working example, see [`examples/clients/simple-auth-client/`](examples/clients/simple-auth-client/). +#### Initial Access Tokens + +The SDK supports RFC 7591 Dynamic Client Registration with initial access tokens. This feature provides a multi-level fallback system for obtaining initial access tokens: + +```python +# Method 1: Explicit parameter (highest priority) +oauth_auth = OAuthClientProvider( + server_url="https://api.example.com", + client_metadata=client_metadata, + storage=storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + initial_access_token="your-token", +) + +# Method 2: Provider method override +class CustomOAuthProvider(OAuthClientProvider): + async def initial_access_token(self) -> str | None: + # Custom logic to retrieve token + return await get_token_from_secure_store() + +# Method 3: Environment variable fallback +# Set OAUTH_INITIAL_ACCESS_TOKEN environment variable +# The SDK will automatically use this if no other method provides a token + +# Method 4: No token (default behavior) +# Client registration will proceed without initial access token +``` + +The fallback order is: + +1. Explicit `initial_access_token` parameter +2. Provider's `initial_access_token()` method +3. `OAUTH_INITIAL_ACCESS_TOKEN` environment variable +4. No token (proceeds with standard registration) + ### MCP Primitives The MCP protocol defines three core primitives that servers can implement: diff --git a/examples/clients/simple-auth-client/README.md b/examples/clients/simple-auth-client/README.md index 224040712..56a25391b 100644 --- a/examples/clients/simple-auth-client/README.md +++ b/examples/clients/simple-auth-client/README.md @@ -72,3 +72,4 @@ mcp> quit - `MCP_SERVER_PORT` - Server URL (default: 8000) - `MCP_TRANSPORT_TYPE` - Transport type: `streamable_http` (default) or `sse` +- `OAUTH_INITIAL_ACCESS_TOKEN` - Initial access token for RFC 7591 Dynamic Client Registration (optional) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 06b95dcaa..2349832f3 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -7,6 +7,7 @@ import base64 import hashlib import logging +import os import re import secrets import string @@ -192,6 +193,7 @@ def __init__( redirect_handler: Callable[[str], Awaitable[None]], callback_handler: Callable[[], Awaitable[tuple[str, str | None]]], timeout: float = 300.0, + initial_access_token: str | None = None, ): """Initialize OAuth2 authentication.""" self.context = OAuthContext( @@ -203,6 +205,7 @@ def __init__( timeout=timeout, ) self._initialized = False + self._initial_access_token = initial_access_token def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: """ @@ -318,8 +321,17 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fal return True # Signal no fallback needed (either success or non-404 error) - async def _register_client(self) -> httpx.Request | None: - """Build registration request or skip if already registered.""" + async def _register_client(self, initial_access_token: str | None = None) -> httpx.Request | None: + """Build registration request or skip if already registered. + + Supports initial access tokens for OAuth 2.0 Dynamic Client Registration according to RFC 7591. + Uses multi-level fallback approach: + + 1. Explicit parameter (highest priority) + 2. Provider's initial_access_token() method + 3. OAUTH_INITIAL_ACCESS_TOKEN environment variable + 4. None (existing behavior for servers that don't require pre-authorization) + """ if self.context.client_info: return None @@ -329,11 +341,29 @@ async def _register_client(self) -> httpx.Request | None: auth_base_url = self.context.get_authorization_base_url(self.context.server_url) registration_url = urljoin(auth_base_url, "/register") + # Multi-level fallback for initial access token + # Level 1: Explicit parameter + token = initial_access_token + + # Level 2: Provider method + if not token: + token = await self.initial_access_token() + + # Level 3: Environment variable + if not token: + token = os.getenv("OAUTH_INITIAL_ACCESS_TOKEN") + + # Level 4: None (current behavior) - no token needed + registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) - return httpx.Request( - "POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"} - ) + headers = {"Content-Type": "application/json"} + + # Add initial access token if available (RFC 7591) + if token: + headers["Authorization"] = f"Bearer {token}" + + return httpx.Request("POST", registration_url, json=registration_data, headers=headers) async def _handle_registration_response(self, response: httpx.Response) -> None: """Handle registration response.""" @@ -506,6 +536,15 @@ async def _initialize(self) -> None: self.context.client_info = await self.context.storage.get_client_info() self._initialized = True + async def initial_access_token(self) -> str | None: + """Provide initial access token for OAuth 2.0 Dynamic Client Registration (RFC 7591).""" + # Return constructor parameter if available + if self._initial_access_token: + return self._initial_access_token + + # Subclasses can override this method to provide tokens from other sources + return None + def _add_auth_header(self, request: httpx.Request) -> None: """Add authorization header to request if we have valid tokens.""" if self.context.current_tokens and self.context.current_tokens.access_token: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index ea9c16c78..649673f0e 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -416,6 +416,124 @@ async def test_register_client_skip_if_registered(self, oauth_provider, mock_sto request = await oauth_provider._register_client() assert request is None + @pytest.mark.anyio + async def test_register_client_with_explicit_initial_access_token(self, oauth_provider): + """Test client registration with explicit initial access token (highest priority).""" + request = await oauth_provider._register_client(initial_access_token="explicit-token") + + assert request is not None + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/register" + assert request.headers["Content-Type"] == "application/json" + assert request.headers["Authorization"] == "Bearer explicit-token" + + @pytest.mark.anyio + async def test_register_client_with_provider_initial_access_token(self, client_metadata, mock_storage): + """Test client registration with provider method initial access token.""" + + class CustomOAuthProvider(OAuthClientProvider): + async def initial_access_token(self) -> str | None: + return "provider-token" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = CustomOAuthProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + request = await provider._register_client() + + assert request is not None + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/register" + assert request.headers["Content-Type"] == "application/json" + assert request.headers["Authorization"] == "Bearer provider-token" + + @pytest.mark.anyio + async def test_register_client_explicit_overrides_provider(self, client_metadata, mock_storage): + """Test explicit initial access token overrides provider method.""" + + class CustomOAuthProvider(OAuthClientProvider): + async def initial_access_token(self) -> str | None: + return "provider-token" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = CustomOAuthProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + request = await provider._register_client(initial_access_token="explicit-token") + + assert request is not None + assert request.headers["Authorization"] == "Bearer explicit-token" + + @pytest.mark.anyio + async def test_register_client_with_environment_variable(self, oauth_provider, monkeypatch): + """Test client registration with environment variable initial access token.""" + monkeypatch.setenv("OAUTH_INITIAL_ACCESS_TOKEN", "env-token") + + request = await oauth_provider._register_client() + + assert request is not None + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/register" + assert request.headers["Content-Type"] == "application/json" + assert request.headers["Authorization"] == "Bearer env-token" + + @pytest.mark.anyio + async def test_register_client_without_initial_access_token(self, oauth_provider): + """Test client registration without initial access token (backward compatibility).""" + request = await oauth_provider._register_client() + + assert request is not None + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/register" + assert request.headers["Content-Type"] == "application/json" + assert "Authorization" not in request.headers + + @pytest.mark.anyio + async def test_initial_access_token_constructor_parameter(self, client_metadata, mock_storage): + """Test OAuthClientProvider with initial access token constructor parameter.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + initial_access_token="constructor-token", + ) + + token = await provider.initial_access_token() + assert token == "constructor-token" + + request = await provider._register_client() + assert request is not None + assert request.headers["Authorization"] == "Bearer constructor-token" + @pytest.mark.anyio async def test_token_exchange_request(self, oauth_provider): """Test token exchange request building."""