Skip to content

feat: add multi-level initial access token support for OAuth 2.0 Dynamic Client Registration (RFC 7591) #1154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1452,6 +1452,8 @@ async def main():
storage=CustomTokenStorage(),
redirect_handler=lambda url: print(f"Visit: {url}"),
callback_handler=lambda: ("auth_code", None),
# Optional: Initial access token for RFC 7591 Dynamic Client Registration
initial_access_token="your-initial-access-token",
)

# Use with streamable HTTP client
Expand All @@ -1465,6 +1467,42 @@ async def main():

For a complete working example, see [`examples/clients/simple-auth-client/`](examples/clients/simple-auth-client/).

#### Initial Access Tokens

The SDK supports RFC 7591 Dynamic Client Registration with initial access tokens. This feature provides a multi-level fallback system for obtaining initial access tokens:

```python
# Method 1: Explicit parameter (highest priority)
oauth_auth = OAuthClientProvider(
server_url="https://api.example.com",
client_metadata=client_metadata,
storage=storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
initial_access_token="your-token",
)

# Method 2: Provider method override
class CustomOAuthProvider(OAuthClientProvider):
async def initial_access_token(self) -> str | None:
# Custom logic to retrieve token
return await get_token_from_secure_store()

# Method 3: Environment variable fallback
# Set OAUTH_INITIAL_ACCESS_TOKEN environment variable
# The SDK will automatically use this if no other method provides a token

# Method 4: No token (default behavior)
# Client registration will proceed without initial access token
```

The fallback order is:

1. Explicit `initial_access_token` parameter
2. Provider's `initial_access_token()` method
3. `OAUTH_INITIAL_ACCESS_TOKEN` environment variable
4. No token (proceeds with standard registration)

### MCP Primitives

The MCP protocol defines three core primitives that servers can implement:
Expand Down
1 change: 1 addition & 0 deletions examples/clients/simple-auth-client/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,4 @@ mcp> quit

- `MCP_SERVER_PORT` - Server URL (default: 8000)
- `MCP_TRANSPORT_TYPE` - Transport type: `streamable_http` (default) or `sse`
- `OAUTH_INITIAL_ACCESS_TOKEN` - Initial access token for RFC 7591 Dynamic Client Registration (optional)
49 changes: 44 additions & 5 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import base64
import hashlib
import logging
import os
import re
import secrets
import string
Expand Down Expand Up @@ -192,6 +193,7 @@ def __init__(
redirect_handler: Callable[[str], Awaitable[None]],
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]],
timeout: float = 300.0,
initial_access_token: str | None = None,
):
"""Initialize OAuth2 authentication."""
self.context = OAuthContext(
Expand All @@ -203,6 +205,7 @@ def __init__(
timeout=timeout,
)
self._initialized = False
self._initial_access_token = initial_access_token

def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None:
"""
Expand Down Expand Up @@ -318,8 +321,17 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fal

return True # Signal no fallback needed (either success or non-404 error)

async def _register_client(self) -> httpx.Request | None:
"""Build registration request or skip if already registered."""
async def _register_client(self, initial_access_token: str | None = None) -> httpx.Request | None:
"""Build registration request or skip if already registered.

Supports initial access tokens for OAuth 2.0 Dynamic Client Registration according to RFC 7591.
Uses multi-level fallback approach:

1. Explicit parameter (highest priority)
2. Provider's initial_access_token() method
3. OAUTH_INITIAL_ACCESS_TOKEN environment variable
4. None (existing behavior for servers that don't require pre-authorization)
"""
if self.context.client_info:
return None

Expand All @@ -329,11 +341,29 @@ async def _register_client(self) -> httpx.Request | None:
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
registration_url = urljoin(auth_base_url, "/register")

# Multi-level fallback for initial access token
# Level 1: Explicit parameter
token = initial_access_token

# Level 2: Provider method
if not token:
token = await self.initial_access_token()

# Level 3: Environment variable
if not token:
token = os.getenv("OAUTH_INITIAL_ACCESS_TOKEN")

# Level 4: None (current behavior) - no token needed

registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)

return httpx.Request(
"POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}
)
headers = {"Content-Type": "application/json"}

# Add initial access token if available (RFC 7591)
if token:
headers["Authorization"] = f"Bearer {token}"

return httpx.Request("POST", registration_url, json=registration_data, headers=headers)

async def _handle_registration_response(self, response: httpx.Response) -> None:
"""Handle registration response."""
Expand Down Expand Up @@ -506,6 +536,15 @@ async def _initialize(self) -> None:
self.context.client_info = await self.context.storage.get_client_info()
self._initialized = True

async def initial_access_token(self) -> str | None:
"""Provide initial access token for OAuth 2.0 Dynamic Client Registration (RFC 7591)."""
# Return constructor parameter if available
if self._initial_access_token:
return self._initial_access_token

# Subclasses can override this method to provide tokens from other sources
return None

def _add_auth_header(self, request: httpx.Request) -> None:
"""Add authorization header to request if we have valid tokens."""
if self.context.current_tokens and self.context.current_tokens.access_token:
Expand Down
118 changes: 118 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,124 @@ async def test_register_client_skip_if_registered(self, oauth_provider, mock_sto
request = await oauth_provider._register_client()
assert request is None

@pytest.mark.anyio
async def test_register_client_with_explicit_initial_access_token(self, oauth_provider):
"""Test client registration with explicit initial access token (highest priority)."""
request = await oauth_provider._register_client(initial_access_token="explicit-token")

assert request is not None
assert request.method == "POST"
assert str(request.url) == "https://api.example.com/register"
assert request.headers["Content-Type"] == "application/json"
assert request.headers["Authorization"] == "Bearer explicit-token"

@pytest.mark.anyio
async def test_register_client_with_provider_initial_access_token(self, client_metadata, mock_storage):
"""Test client registration with provider method initial access token."""

class CustomOAuthProvider(OAuthClientProvider):
async def initial_access_token(self) -> str | None:
return "provider-token"

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = CustomOAuthProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

request = await provider._register_client()

assert request is not None
assert request.method == "POST"
assert str(request.url) == "https://api.example.com/register"
assert request.headers["Content-Type"] == "application/json"
assert request.headers["Authorization"] == "Bearer provider-token"

@pytest.mark.anyio
async def test_register_client_explicit_overrides_provider(self, client_metadata, mock_storage):
"""Test explicit initial access token overrides provider method."""

class CustomOAuthProvider(OAuthClientProvider):
async def initial_access_token(self) -> str | None:
return "provider-token"

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = CustomOAuthProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

request = await provider._register_client(initial_access_token="explicit-token")

assert request is not None
assert request.headers["Authorization"] == "Bearer explicit-token"

@pytest.mark.anyio
async def test_register_client_with_environment_variable(self, oauth_provider, monkeypatch):
"""Test client registration with environment variable initial access token."""
monkeypatch.setenv("OAUTH_INITIAL_ACCESS_TOKEN", "env-token")

request = await oauth_provider._register_client()

assert request is not None
assert request.method == "POST"
assert str(request.url) == "https://api.example.com/register"
assert request.headers["Content-Type"] == "application/json"
assert request.headers["Authorization"] == "Bearer env-token"

@pytest.mark.anyio
async def test_register_client_without_initial_access_token(self, oauth_provider):
"""Test client registration without initial access token (backward compatibility)."""
request = await oauth_provider._register_client()

assert request is not None
assert request.method == "POST"
assert str(request.url) == "https://api.example.com/register"
assert request.headers["Content-Type"] == "application/json"
assert "Authorization" not in request.headers

@pytest.mark.anyio
async def test_initial_access_token_constructor_parameter(self, client_metadata, mock_storage):
"""Test OAuthClientProvider with initial access token constructor parameter."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
initial_access_token="constructor-token",
)

token = await provider.initial_access_token()
assert token == "constructor-token"

request = await provider._register_client()
assert request is not None
assert request.headers["Authorization"] == "Bearer constructor-token"

@pytest.mark.anyio
async def test_token_exchange_request(self, oauth_provider):
"""Test token exchange request building."""
Expand Down
Loading