From b609fc5025149f44c8ae0ff56a430b5aa2a4def4 Mon Sep 17 00:00:00 2001 From: Andor Markus Date: Wed, 16 Jul 2025 12:51:31 +0200 Subject: [PATCH 1/3] feat: add initial access token support for OAuth 2.0 Dynamic Client Registration (RFC 7591) - Add initial_access_token parameter to OAuthClientProvider constructor - Implement multi-level fallback for token resolution: 1. Explicit parameter (highest priority) 2. Provider method (initial_access_token()) 3. Environment variable (OAUTH_INITIAL_ACCESS_TOKEN) 4. No token (existing behavior) - Add Authorization Bearer header to registration requests when token available - Add comprehensive test coverage for all fallback scenarios - Update documentation with usage examples and configuration details - Maintain full backward compatibility with existing OAuth flows This enables clients to register with protected OAuth endpoints that require initial access tokens per RFC 7591 Dynamic Client Registration specification. --- README.md | 37 ++ diff_with_main.txt | 577 ++++++++++++++++++ examples/clients/simple-auth-client/README.md | 1 + src/mcp/client/auth.py | 49 +- tests/client/test_auth.py | 118 ++++ 5 files changed, 777 insertions(+), 5 deletions(-) create mode 100644 diff_with_main.txt diff --git a/README.md b/README.md index c5fb473ca..05b5a2732 100644 --- a/README.md +++ b/README.md @@ -1452,6 +1452,8 @@ async def main(): storage=CustomTokenStorage(), redirect_handler=lambda url: print(f"Visit: {url}"), callback_handler=lambda: ("auth_code", None), + # Optional: Initial access token for RFC 7591 Dynamic Client Registration + initial_access_token="your-initial-access-token", ) # Use with streamable HTTP client @@ -1465,6 +1467,41 @@ async def main(): For a complete working example, see [`examples/clients/simple-auth-client/`](examples/clients/simple-auth-client/). +#### Initial Access Tokens + +The SDK supports RFC 7591 Dynamic Client Registration with initial access tokens. This feature provides a multi-level fallback system for obtaining initial access tokens: + +```python +# Method 1: Explicit parameter (highest priority) +oauth_auth = OAuthClientProvider( + server_url="https://api.example.com", + client_metadata=client_metadata, + storage=storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + initial_access_token="your-token", +) + +# Method 2: Provider method override +class CustomOAuthProvider(OAuthClientProvider): + async def initial_access_token(self) -> str | None: + # Custom logic to retrieve token + return await get_token_from_secure_store() + +# Method 3: Environment variable fallback +# Set OAUTH_INITIAL_ACCESS_TOKEN environment variable +# The SDK will automatically use this if no other method provides a token + +# Method 4: No token (default behavior) +# Client registration will proceed without initial access token +``` + +The fallback order is: +1. Explicit `initial_access_token` parameter +2. Provider's `initial_access_token()` method +3. `OAUTH_INITIAL_ACCESS_TOKEN` environment variable +4. No token (proceeds with standard registration) + ### MCP Primitives The MCP protocol defines three core primitives that servers can implement: diff --git a/diff_with_main.txt b/diff_with_main.txt new file mode 100644 index 000000000..d91174b88 --- /dev/null +++ b/diff_with_main.txt @@ -0,0 +1,577 @@ +diff --git a/package-lock.json b/package-lock.json +index 01bc095..fa1bde0 100644 +--- a/package-lock.json ++++ b/package-lock.json +@@ -1,12 +1,12 @@ + { + "name": "@modelcontextprotocol/sdk", +- "version": "1.15.0", ++ "version": "1.15.1", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "@modelcontextprotocol/sdk", +- "version": "1.15.0", ++ "version": "1.15.1", + "license": "MIT", + "dependencies": { + "ajv": "^6.12.6", +diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts +index ce0cc70..eb26abc 100644 +--- a/src/client/auth.test.ts ++++ b/src/client/auth.test.ts +@@ -1158,6 +1158,140 @@ describe("OAuth Authorization", () => { + }) + ).rejects.toThrow("Dynamic client registration failed"); + }); ++ ++ describe("initial access token support", () => { ++ it("includes initial access token from explicit parameter", async () => { ++ mockFetch.mockResolvedValueOnce({ ++ ok: true, ++ status: 200, ++ json: async () => validClientInfo, ++ }); ++ ++ await registerClient("https://auth.example.com", { ++ clientMetadata: validClientMetadata, ++ initialAccessToken: "explicit-token", ++ }); ++ ++ expect(mockFetch).toHaveBeenCalledWith( ++ expect.objectContaining({ ++ href: "https://auth.example.com/register", ++ }), ++ expect.objectContaining({ ++ method: "POST", ++ headers: { ++ "Content-Type": "application/json", ++ "Authorization": "Bearer explicit-token", ++ }, ++ body: JSON.stringify(validClientMetadata), ++ }) ++ ); ++ }); ++ ++ it("includes initial access token from provider method", async () => { ++ const mockProvider: OAuthClientProvider = { ++ get redirectUrl() { return "http://localhost:3000/callback"; }, ++ get clientMetadata() { return validClientMetadata; }, ++ clientInformation: jest.fn(), ++ tokens: jest.fn(), ++ saveTokens: jest.fn(), ++ redirectToAuthorization: jest.fn(), ++ saveCodeVerifier: jest.fn(), ++ codeVerifier: jest.fn(), ++ initialAccessToken: jest.fn().mockResolvedValue("provider-token"), ++ }; ++ ++ mockFetch.mockResolvedValueOnce({ ++ ok: true, ++ status: 200, ++ json: async () => validClientInfo, ++ }); ++ ++ await registerClient("https://auth.example.com", { ++ clientMetadata: validClientMetadata, ++ provider: mockProvider, ++ }); ++ ++ expect(mockFetch).toHaveBeenCalledWith( ++ expect.objectContaining({ ++ href: "https://auth.example.com/register", ++ }), ++ expect.objectContaining({ ++ method: "POST", ++ headers: { ++ "Content-Type": "application/json", ++ "Authorization": "Bearer provider-token", ++ }, ++ body: JSON.stringify(validClientMetadata), ++ }) ++ ); ++ }); ++ ++ it("prioritizes explicit parameter over provider method", async () => { ++ const mockProvider: OAuthClientProvider = { ++ get redirectUrl() { return "http://localhost:3000/callback"; }, ++ get clientMetadata() { return validClientMetadata; }, ++ clientInformation: jest.fn(), ++ tokens: jest.fn(), ++ saveTokens: jest.fn(), ++ redirectToAuthorization: jest.fn(), ++ saveCodeVerifier: jest.fn(), ++ codeVerifier: jest.fn(), ++ initialAccessToken: jest.fn().mockResolvedValue("provider-token"), ++ }; ++ ++ mockFetch.mockResolvedValueOnce({ ++ ok: true, ++ status: 200, ++ json: async () => validClientInfo, ++ }); ++ ++ await registerClient("https://auth.example.com", { ++ clientMetadata: validClientMetadata, ++ initialAccessToken: "explicit-token", ++ provider: mockProvider, ++ }); ++ ++ expect(mockProvider.initialAccessToken).not.toHaveBeenCalled(); ++ expect(mockFetch).toHaveBeenCalledWith( ++ expect.objectContaining({ ++ href: "https://auth.example.com/register", ++ }), ++ expect.objectContaining({ ++ method: "POST", ++ headers: { ++ "Content-Type": "application/json", ++ "Authorization": "Bearer explicit-token", ++ }, ++ body: JSON.stringify(validClientMetadata), ++ }) ++ ); ++ }); ++ ++ it("registers without authorization header when no token available", async () => { ++ mockFetch.mockResolvedValueOnce({ ++ ok: true, ++ status: 200, ++ json: async () => validClientInfo, ++ }); ++ ++ await registerClient("https://auth.example.com", { ++ clientMetadata: validClientMetadata, ++ }); ++ ++ expect(mockFetch).toHaveBeenCalledWith( ++ expect.objectContaining({ ++ href: "https://auth.example.com/register", ++ }), ++ expect.objectContaining({ ++ method: "POST", ++ headers: { ++ "Content-Type": "application/json", ++ }, ++ body: JSON.stringify(validClientMetadata), ++ }) ++ ); ++ }); ++ }); + }); + + describe("auth function", () => { +diff --git a/src/client/auth.ts b/src/client/auth.ts +index 4a8bbe2..a3e937c 100644 +--- a/src/client/auth.ts ++++ b/src/client/auth.ts +@@ -124,6 +124,17 @@ export interface OAuthClientProvider { + * This avoids requiring the user to intervene manually. + */ + invalidateCredentials?(scope: 'all' | 'client' | 'tokens' | 'verifier'): void | Promise; ++ ++ /** ++ * If implemented, provides an initial access token for OAuth 2.0 Dynamic Client Registration ++ * according to RFC 7591. This token is used to authorize the client registration request. ++ * ++ * The initial access token allows the client to register with authorization servers that ++ * require pre-authorization for dynamic client registration. ++ * ++ * @returns The initial access token string, or undefined if none is available ++ */ ++ initialAccessToken?(): string | undefined | Promise; + } + + export type AuthResult = "AUTHORIZED" | "REDIRECT"; +@@ -281,7 +292,8 @@ export async function auth( + serverUrl: string | URL; + authorizationCode?: string; + scope?: string; +- resourceMetadataUrl?: URL }): Promise { ++ resourceMetadataUrl?: URL; ++ initialAccessToken?: string; }): Promise { + + try { + return await authInternal(provider, options); +@@ -305,12 +317,14 @@ async function authInternal( + { serverUrl, + authorizationCode, + scope, +- resourceMetadataUrl ++ resourceMetadataUrl, ++ initialAccessToken + }: { + serverUrl: string | URL; + authorizationCode?: string; + scope?: string; +- resourceMetadataUrl?: URL ++ resourceMetadataUrl?: URL; ++ initialAccessToken?: string; + }): Promise { + + let resourceMetadata: OAuthProtectedResourceMetadata | undefined; +@@ -344,6 +358,8 @@ async function authInternal( + const fullInformation = await registerClient(authorizationServerUrl, { + metadata, + clientMetadata: provider.clientMetadata, ++ initialAccessToken, ++ provider, + }); + + await provider.saveClientInformation(fullInformation); +@@ -877,15 +893,28 @@ export async function refreshAuthorization( + + /** + * Performs OAuth 2.0 Dynamic Client Registration according to RFC 7591. ++ * ++ * Supports initial access tokens for authorization servers that require ++ * pre-authorization for dynamic client registration. The initial access token ++ * is resolved using a multi-level fallback approach: ++ * ++ * 1. Explicit `initialAccessToken` parameter (highest priority) ++ * 2. Provider's `initialAccessToken()` method (if implemented) ++ * 3. `OAUTH_INITIAL_ACCESS_TOKEN` environment variable ++ * 4. None (current behavior for servers that don't require pre-authorization) + */ + export async function registerClient( + authorizationServerUrl: string | URL, + { + metadata, + clientMetadata, ++ initialAccessToken, ++ provider, + }: { + metadata?: OAuthMetadata; + clientMetadata: OAuthClientMetadata; ++ initialAccessToken?: string; ++ provider?: OAuthClientProvider; + }, + ): Promise { + let registrationUrl: URL; +@@ -900,11 +929,33 @@ export async function registerClient( + registrationUrl = new URL("/register", authorizationServerUrl); + } + ++ // Multi-level fallback for initial access token ++ let token = initialAccessToken; // Level 1: Explicit parameter ++ ++ if (!token && provider?.initialAccessToken) { ++ // Level 2: Provider method ++ token = await Promise.resolve(provider.initialAccessToken()); ++ } ++ ++ // Level 3: Environment variable (Node.js environments only) ++ if (!token && typeof globalThis !== 'undefined' && (globalThis as any).process?.env) { ++ token = (globalThis as any).process.env.OAUTH_INITIAL_ACCESS_TOKEN; ++ } ++ ++ // Level 4: None (current behavior) - no token needed ++ ++ const headers: Record = { ++ "Content-Type": "application/json", ++ }; ++ ++ // Add initial access token if available (RFC 7591) ++ if (token) { ++ headers["Authorization"] = `Bearer ${token}`; ++ } ++ + const response = await fetch(registrationUrl, { + method: "POST", +- headers: { +- "Content-Type": "application/json", +- }, ++ headers, + body: JSON.stringify(clientMetadata), + }); + +diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts +index 2cc4a1d..d8cadfb 100644 +--- a/src/client/sse.test.ts ++++ b/src/client/sse.test.ts +@@ -1107,5 +1107,80 @@ describe("SSEClientTransport", () => { + await expect(() => transport.start()).rejects.toThrow(InvalidGrantError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); + }); ++ ++ describe("initialAccessToken support", () => { ++ it("stores initialAccessToken from constructor options", () => { ++ const transport = new SSEClientTransport( ++ new URL("http://localhost:1234/mcp"), ++ { initialAccessToken: "test-initial-token" } ++ ); ++ ++ // Access private property for testing ++ const transportInstance = transport as unknown as { _initialAccessToken?: string }; ++ expect(transportInstance._initialAccessToken).toBe("test-initial-token"); ++ }); ++ ++ it("works without initialAccessToken (backward compatibility)", async () => { ++ const transport = new SSEClientTransport( ++ new URL("http://localhost:1234/mcp"), ++ { authProvider: mockAuthProvider } ++ ); ++ ++ const transportInstance = transport as unknown as { _initialAccessToken?: string }; ++ expect(transportInstance._initialAccessToken).toBeUndefined(); ++ ++ // Should not throw when no initial access token provided ++ expect(() => transport).not.toThrow(); ++ }); ++ ++ it("includes initialAccessToken in auth calls", async () => { ++ // Create a spy on the auth module ++ const authModule = await import("./auth.js"); ++ const authSpy = jest.spyOn(authModule, "auth").mockResolvedValue("REDIRECT"); ++ ++ const transport = new SSEClientTransport( ++ resourceBaseUrl, ++ { ++ authProvider: mockAuthProvider, ++ initialAccessToken: "test-initial-token" ++ } ++ ); ++ ++ // Start the transport first ++ await transport.start(); ++ ++ // Mock fetch to return 401 and trigger auth on send ++ const originalFetch = global.fetch; ++ global.fetch = jest.fn().mockResolvedValueOnce({ ++ ok: false, ++ status: 401, ++ headers: new Headers(), ++ }); ++ ++ const message = { ++ jsonrpc: "2.0" as const, ++ method: "test", ++ params: {}, ++ id: "test-id" ++ }; ++ ++ try { ++ await transport.send(message); ++ } catch { ++ // Expected to fail due to mock setup, we're just testing auth call ++ } ++ ++ expect(authSpy).toHaveBeenCalledWith( ++ mockAuthProvider, ++ expect.objectContaining({ ++ initialAccessToken: "test-initial-token" ++ }) ++ ); ++ ++ // Restore fetch and spy ++ global.fetch = originalFetch; ++ authSpy.mockRestore(); ++ }); ++ }); + }); + }); +diff --git a/src/client/sse.ts b/src/client/sse.ts +index 568a515..98484bf 100644 +--- a/src/client/sse.ts ++++ b/src/client/sse.ts +@@ -52,6 +52,16 @@ export type SSEClientTransportOptions = { + * Custom fetch implementation used for all network requests. + */ + fetch?: FetchLike; ++ ++ /** ++ * Initial access token for OAuth 2.0 Dynamic Client Registration (RFC 7591). ++ * This token is used to authorize the client registration request with authorization servers ++ * that require pre-authorization for dynamic client registration. ++ * ++ * If not provided, the system will fall back to the provider's `initialAccessToken()` method ++ * and then to the `OAUTH_INITIAL_ACCESS_TOKEN` environment variable. ++ */ ++ initialAccessToken?: string; + }; + + /** +@@ -69,6 +79,7 @@ export class SSEClientTransport implements Transport { + private _authProvider?: OAuthClientProvider; + private _fetch?: FetchLike; + private _protocolVersion?: string; ++ private _initialAccessToken?: string; + + onclose?: () => void; + onerror?: (error: Error) => void; +@@ -84,6 +95,7 @@ export class SSEClientTransport implements Transport { + this._requestInit = opts?.requestInit; + this._authProvider = opts?.authProvider; + this._fetch = opts?.fetch; ++ this._initialAccessToken = opts?.initialAccessToken; + } + + private async _authThenStart(): Promise { +@@ -93,7 +105,7 @@ export class SSEClientTransport implements Transport { + + let result: AuthResult; + try { +- result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); ++ result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, initialAccessToken: this._initialAccessToken }); + } catch (error) { + this.onerror?.(error as Error); + throw error; +@@ -218,7 +230,7 @@ export class SSEClientTransport implements Transport { + throw new UnauthorizedError("No auth provider"); + } + +- const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl }); ++ const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, initialAccessToken: this._initialAccessToken }); + if (result !== "AUTHORIZED") { + throw new UnauthorizedError("Failed to authorize"); + } +@@ -252,7 +264,7 @@ const response = await (this._fetch ?? fetch)(this._endpoint, init); + + this._resourceMetadataUrl = extractResourceMetadataUrl(response); + +- const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); ++ const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, initialAccessToken: this._initialAccessToken }); + if (result !== "AUTHORIZED") { + throw new UnauthorizedError(); + } +diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts +index c54cf28..baeb955 100644 +--- a/src/client/streamableHttp.test.ts ++++ b/src/client/streamableHttp.test.ts +@@ -855,4 +855,73 @@ describe("StreamableHTTPClientTransport", () => { + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); + }); ++ ++ describe("initialAccessToken support", () => { ++ it("stores initialAccessToken from constructor options", () => { ++ const transport = new StreamableHTTPClientTransport( ++ new URL("http://localhost:1234/mcp"), ++ { initialAccessToken: "test-initial-token" } ++ ); ++ ++ // Access private property for testing ++ const transportInstance = transport as unknown as { _initialAccessToken?: string }; ++ expect(transportInstance._initialAccessToken).toBe("test-initial-token"); ++ }); ++ ++ it("works without initialAccessToken (backward compatibility)", async () => { ++ const transport = new StreamableHTTPClientTransport( ++ new URL("http://localhost:1234/mcp"), ++ { authProvider: mockAuthProvider } ++ ); ++ ++ const transportInstance = transport as unknown as { _initialAccessToken?: string }; ++ expect(transportInstance._initialAccessToken).toBeUndefined(); ++ ++ // Should not throw when no initial access token provided ++ expect(() => transport).not.toThrow(); ++ }); ++ ++ it("includes initialAccessToken in auth calls", async () => { ++ // Create a spy on the auth module ++ const authModule = await import("./auth.js"); ++ const authSpy = jest.spyOn(authModule, "auth").mockResolvedValue("REDIRECT"); ++ ++ const transport = new StreamableHTTPClientTransport( ++ new URL("http://localhost:1234/mcp"), ++ { ++ authProvider: mockAuthProvider, ++ initialAccessToken: "test-initial-token" ++ } ++ ); ++ ++ // Mock fetch to trigger auth flow on send (401 response) ++ (global.fetch as jest.Mock).mockResolvedValueOnce({ ++ ok: false, ++ status: 401, ++ headers: new Headers(), ++ }); ++ ++ const message = { ++ jsonrpc: "2.0" as const, ++ method: "test", ++ params: {}, ++ id: "test-id" ++ }; ++ ++ try { ++ await transport.send(message); ++ } catch { ++ // Expected to fail due to mock setup, we're just testing auth call ++ } ++ ++ expect(authSpy).toHaveBeenCalledWith( ++ mockAuthProvider, ++ expect.objectContaining({ ++ initialAccessToken: "test-initial-token" ++ }) ++ ); ++ ++ authSpy.mockRestore(); ++ }); ++ }); + }); +diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts +index b0894fc..a790372 100644 +--- a/src/client/streamableHttp.ts ++++ b/src/client/streamableHttp.ts +@@ -114,6 +114,14 @@ export type StreamableHTTPClientTransportOptions = { + * When not provided and connecting to a server that supports session IDs, the server will generate a new session ID. + */ + sessionId?: string; ++ ++ /** ++ * Initial access token for OAuth 2.0 Dynamic Client Registration (RFC 7591). ++ * This token is used to authorize the client registration request with authorization servers that require pre-authorization for dynamic client registration. ++ * ++ * If not provided, the system will fall back to the provider's `initialAccessToken()` method and then to the `OAUTH_INITIAL_ACCESS_TOKEN` environment variable. ++ */ ++ initialAccessToken?: string; + }; + + /** +@@ -131,6 +139,7 @@ export class StreamableHTTPClientTransport implements Transport { + private _sessionId?: string; + private _reconnectionOptions: StreamableHTTPReconnectionOptions; + private _protocolVersion?: string; ++ private _initialAccessToken?: string; + + onclose?: () => void; + onerror?: (error: Error) => void; +@@ -147,6 +156,7 @@ export class StreamableHTTPClientTransport implements Transport { + this._fetch = opts?.fetch; + this._sessionId = opts?.sessionId; + this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; ++ this._initialAccessToken = opts?.initialAccessToken; + } + + private async _authThenStart(): Promise { +@@ -156,7 +166,7 @@ export class StreamableHTTPClientTransport implements Transport { + + let result: AuthResult; + try { +- result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); ++ result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, initialAccessToken: this._initialAccessToken }); + } catch (error) { + this.onerror?.(error as Error); + throw error; +@@ -392,7 +402,7 @@ const response = await (this._fetch ?? fetch)(this._url, { + throw new UnauthorizedError("No auth provider"); + } + +- const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl }); ++ const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, initialAccessToken: this._initialAccessToken }); + if (result !== "AUTHORIZED") { + throw new UnauthorizedError("Failed to authorize"); + } +@@ -440,7 +450,7 @@ const response = await (this._fetch ?? fetch)(this._url, init); + + this._resourceMetadataUrl = extractResourceMetadataUrl(response); + +- const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); ++ const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, initialAccessToken: this._initialAccessToken }); + if (result !== "AUTHORIZED") { + throw new UnauthorizedError(); + } diff --git a/examples/clients/simple-auth-client/README.md b/examples/clients/simple-auth-client/README.md index 224040712..56a25391b 100644 --- a/examples/clients/simple-auth-client/README.md +++ b/examples/clients/simple-auth-client/README.md @@ -72,3 +72,4 @@ mcp> quit - `MCP_SERVER_PORT` - Server URL (default: 8000) - `MCP_TRANSPORT_TYPE` - Transport type: `streamable_http` (default) or `sse` +- `OAUTH_INITIAL_ACCESS_TOKEN` - Initial access token for RFC 7591 Dynamic Client Registration (optional) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 06b95dcaa..c14f885d2 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -7,6 +7,7 @@ import base64 import hashlib import logging +import os import re import secrets import string @@ -192,6 +193,7 @@ def __init__( redirect_handler: Callable[[str], Awaitable[None]], callback_handler: Callable[[], Awaitable[tuple[str, str | None]]], timeout: float = 300.0, + initial_access_token: str | None = None, ): """Initialize OAuth2 authentication.""" self.context = OAuthContext( @@ -203,6 +205,7 @@ def __init__( timeout=timeout, ) self._initialized = False + self._initial_access_token = initial_access_token def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: """ @@ -318,8 +321,17 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fal return True # Signal no fallback needed (either success or non-404 error) - async def _register_client(self) -> httpx.Request | None: - """Build registration request or skip if already registered.""" + async def _register_client(self, initial_access_token: str | None = None) -> httpx.Request | None: + """Build registration request or skip if already registered. + + Supports initial access tokens for OAuth 2.0 Dynamic Client Registration according to RFC 7591. + Uses multi-level fallback approach: + + 1. Explicit parameter (highest priority) + 2. Provider's initial_access_token() method + 3. OAUTH_INITIAL_ACCESS_TOKEN environment variable + 4. None (existing behavior for servers that don't require pre-authorization) + """ if self.context.client_info: return None @@ -329,11 +341,29 @@ async def _register_client(self) -> httpx.Request | None: auth_base_url = self.context.get_authorization_base_url(self.context.server_url) registration_url = urljoin(auth_base_url, "/register") + # Multi-level fallback for initial access token + # Level 1: Explicit parameter + token = initial_access_token + + # Level 2: Provider method + if not token: + token = await self.initial_access_token() + + # Level 3: Environment variable + if not token: + token = os.getenv("OAUTH_INITIAL_ACCESS_TOKEN") + + # Level 4: None (current behavior) - no token needed + registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) - return httpx.Request( - "POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"} - ) + headers = {"Content-Type": "application/json"} + + # Add initial access token if available (RFC 7591) + if token: + headers["Authorization"] = f"Bearer {token}" + + return httpx.Request("POST", registration_url, json=registration_data, headers=headers) async def _handle_registration_response(self, response: httpx.Response) -> None: """Handle registration response.""" @@ -506,6 +536,15 @@ async def _initialize(self) -> None: self.context.client_info = await self.context.storage.get_client_info() self._initialized = True + async def initial_access_token(self) -> str | None: + """Provide initial access token for OAuth 2.0 Dynamic Client Registration (RFC 7591).""" + # Return constructor parameter if available + if self._initial_access_token: + return self._initial_access_token + + # Subclasses can override this method to provide tokens from other sources + return None + def _add_auth_header(self, request: httpx.Request) -> None: """Add authorization header to request if we have valid tokens.""" if self.context.current_tokens and self.context.current_tokens.access_token: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index ea9c16c78..7d5965ce3 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -416,6 +416,124 @@ async def test_register_client_skip_if_registered(self, oauth_provider, mock_sto request = await oauth_provider._register_client() assert request is None + @pytest.mark.anyio + async def test_register_client_with_explicit_initial_access_token(self, oauth_provider): + """Test client registration with explicit initial access token (highest priority).""" + request = await oauth_provider._register_client(initial_access_token="explicit-token") + + assert request is not None + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/register" + assert request.headers["Content-Type"] == "application/json" + assert request.headers["Authorization"] == "Bearer explicit-token" + + @pytest.mark.anyio + async def test_register_client_with_provider_initial_access_token(self, client_metadata, mock_storage): + """Test client registration with provider method initial access token.""" + + class CustomOAuthProvider(OAuthClientProvider): + async def initial_access_token(self) -> str | None: + return "provider-token" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = CustomOAuthProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + request = await provider._register_client() + + assert request is not None + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/register" + assert request.headers["Content-Type"] == "application/json" + assert request.headers["Authorization"] == "Bearer provider-token" + + @pytest.mark.anyio + async def test_register_client_explicit_overrides_provider(self, client_metadata, mock_storage): + """Test explicit initial access token overrides provider method.""" + + class CustomOAuthProvider(OAuthClientProvider): + async def initial_access_token(self) -> str | None: + return "provider-token" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = CustomOAuthProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + request = await provider._register_client(initial_access_token="explicit-token") + + assert request is not None + assert request.headers["Authorization"] == "Bearer explicit-token" + + @pytest.mark.anyio + async def test_register_client_with_environment_variable(self, oauth_provider, monkeypatch): + """Test client registration with environment variable initial access token.""" + monkeypatch.setenv("OAUTH_INITIAL_ACCESS_TOKEN", "env-token") + + request = await oauth_provider._register_client() + + assert request is not None + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/register" + assert request.headers["Content-Type"] == "application/json" + assert request.headers["Authorization"] == "Bearer env-token" + + @pytest.mark.anyio + async def test_register_client_without_initial_access_token(self, oauth_provider): + """Test client registration without initial access token (backward compatibility).""" + request = await oauth_provider._register_client() + + assert request is not None + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/register" + assert request.headers["Content-Type"] == "application/json" + assert "Authorization" not in request.headers + + @pytest.mark.anyio + async def test_initial_access_token_constructor_parameter(self, client_metadata, mock_storage): + """Test OAuthClientProvider with initial access token constructor parameter.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + initial_access_token="constructor-token", + ) + + token = await provider.initial_access_token() + assert token == "constructor-token" + + request = await provider._register_client() + assert request is not None + assert request.headers["Authorization"] == "Bearer constructor-token" + @pytest.mark.anyio async def test_token_exchange_request(self, oauth_provider): """Test token exchange request building.""" From 100590c65b72bb82f26af25c129cad28fd608956 Mon Sep 17 00:00:00 2001 From: Andor Markus Date: Wed, 16 Jul 2025 12:58:56 +0200 Subject: [PATCH 2/3] feat: add initial access token support for OAuth 2.0 Dynamic Client Registration (RFC 7591) - Add initial_access_token parameter to OAuthClientProvider constructor - Implement multi-level fallback for token resolution: 1. Explicit parameter (highest priority) 2. Provider method (initial_access_token()) 3. Environment variable (OAUTH_INITIAL_ACCESS_TOKEN) 4. No token (existing behavior) - Add Authorization Bearer header to registration requests when token available - Add comprehensive test coverage for all fallback scenarios - Update documentation with usage examples and configuration details - Maintain full backward compatibility with existing OAuth flows This enables clients to register with protected OAuth endpoints that require initial access tokens per RFC 7591 Dynamic Client Registration specification. --- README.md | 1 + src/mcp/client/auth.py | 6 +++--- tests/client/test_auth.py | 6 +++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 05b5a2732..6da12e6fc 100644 --- a/README.md +++ b/README.md @@ -1497,6 +1497,7 @@ class CustomOAuthProvider(OAuthClientProvider): ``` The fallback order is: + 1. Explicit `initial_access_token` parameter 2. Provider's `initial_access_token()` method 3. `OAUTH_INITIAL_ACCESS_TOKEN` environment variable diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index c14f885d2..2349832f3 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -323,10 +323,10 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fal async def _register_client(self, initial_access_token: str | None = None) -> httpx.Request | None: """Build registration request or skip if already registered. - + Supports initial access tokens for OAuth 2.0 Dynamic Client Registration according to RFC 7591. Uses multi-level fallback approach: - + 1. Explicit parameter (highest priority) 2. Provider's initial_access_token() method 3. OAUTH_INITIAL_ACCESS_TOKEN environment variable @@ -541,7 +541,7 @@ async def initial_access_token(self) -> str | None: # Return constructor parameter if available if self._initial_access_token: return self._initial_access_token - + # Subclasses can override this method to provide tokens from other sources return None diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 7d5965ce3..649673f0e 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -430,7 +430,7 @@ async def test_register_client_with_explicit_initial_access_token(self, oauth_pr @pytest.mark.anyio async def test_register_client_with_provider_initial_access_token(self, client_metadata, mock_storage): """Test client registration with provider method initial access token.""" - + class CustomOAuthProvider(OAuthClientProvider): async def initial_access_token(self) -> str | None: return "provider-token" @@ -460,7 +460,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio async def test_register_client_explicit_overrides_provider(self, client_metadata, mock_storage): """Test explicit initial access token overrides provider method.""" - + class CustomOAuthProvider(OAuthClientProvider): async def initial_access_token(self) -> str | None: return "provider-token" @@ -511,7 +511,7 @@ async def test_register_client_without_initial_access_token(self, oauth_provider @pytest.mark.anyio async def test_initial_access_token_constructor_parameter(self, client_metadata, mock_storage): """Test OAuthClientProvider with initial access token constructor parameter.""" - + async def redirect_handler(url: str) -> None: pass From 0d191a286d4cbf833a099c80309974afcba6f284 Mon Sep 17 00:00:00 2001 From: Andor Markus Date: Wed, 16 Jul 2025 13:00:52 +0200 Subject: [PATCH 3/3] feat: add initial access token support for OAuth 2.0 Dynamic Client Registration (RFC 7591) - Add initial_access_token parameter to OAuthClientProvider constructor - Implement multi-level fallback for token resolution: 1. Explicit parameter (highest priority) 2. Provider method (initial_access_token()) 3. Environment variable (OAUTH_INITIAL_ACCESS_TOKEN) 4. No token (existing behavior) - Add Authorization Bearer header to registration requests when token available - Add comprehensive test coverage for all fallback scenarios - Update documentation with usage examples and configuration details - Maintain full backward compatibility with existing OAuth flows This enables clients to register with protected OAuth endpoints that require initial access tokens per RFC 7591 Dynamic Client Registration specification. --- diff_with_main.txt | 577 --------------------------------------------- 1 file changed, 577 deletions(-) delete mode 100644 diff_with_main.txt diff --git a/diff_with_main.txt b/diff_with_main.txt deleted file mode 100644 index d91174b88..000000000 --- a/diff_with_main.txt +++ /dev/null @@ -1,577 +0,0 @@ -diff --git a/package-lock.json b/package-lock.json -index 01bc095..fa1bde0 100644 ---- a/package-lock.json -+++ b/package-lock.json -@@ -1,12 +1,12 @@ - { - "name": "@modelcontextprotocol/sdk", -- "version": "1.15.0", -+ "version": "1.15.1", - "lockfileVersion": 3, - "requires": true, - "packages": { - "": { - "name": "@modelcontextprotocol/sdk", -- "version": "1.15.0", -+ "version": "1.15.1", - "license": "MIT", - "dependencies": { - "ajv": "^6.12.6", -diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts -index ce0cc70..eb26abc 100644 ---- a/src/client/auth.test.ts -+++ b/src/client/auth.test.ts -@@ -1158,6 +1158,140 @@ describe("OAuth Authorization", () => { - }) - ).rejects.toThrow("Dynamic client registration failed"); - }); -+ -+ describe("initial access token support", () => { -+ it("includes initial access token from explicit parameter", async () => { -+ mockFetch.mockResolvedValueOnce({ -+ ok: true, -+ status: 200, -+ json: async () => validClientInfo, -+ }); -+ -+ await registerClient("https://auth.example.com", { -+ clientMetadata: validClientMetadata, -+ initialAccessToken: "explicit-token", -+ }); -+ -+ expect(mockFetch).toHaveBeenCalledWith( -+ expect.objectContaining({ -+ href: "https://auth.example.com/register", -+ }), -+ expect.objectContaining({ -+ method: "POST", -+ headers: { -+ "Content-Type": "application/json", -+ "Authorization": "Bearer explicit-token", -+ }, -+ body: JSON.stringify(validClientMetadata), -+ }) -+ ); -+ }); -+ -+ it("includes initial access token from provider method", async () => { -+ const mockProvider: OAuthClientProvider = { -+ get redirectUrl() { return "http://localhost:3000/callback"; }, -+ get clientMetadata() { return validClientMetadata; }, -+ clientInformation: jest.fn(), -+ tokens: jest.fn(), -+ saveTokens: jest.fn(), -+ redirectToAuthorization: jest.fn(), -+ saveCodeVerifier: jest.fn(), -+ codeVerifier: jest.fn(), -+ initialAccessToken: jest.fn().mockResolvedValue("provider-token"), -+ }; -+ -+ mockFetch.mockResolvedValueOnce({ -+ ok: true, -+ status: 200, -+ json: async () => validClientInfo, -+ }); -+ -+ await registerClient("https://auth.example.com", { -+ clientMetadata: validClientMetadata, -+ provider: mockProvider, -+ }); -+ -+ expect(mockFetch).toHaveBeenCalledWith( -+ expect.objectContaining({ -+ href: "https://auth.example.com/register", -+ }), -+ expect.objectContaining({ -+ method: "POST", -+ headers: { -+ "Content-Type": "application/json", -+ "Authorization": "Bearer provider-token", -+ }, -+ body: JSON.stringify(validClientMetadata), -+ }) -+ ); -+ }); -+ -+ it("prioritizes explicit parameter over provider method", async () => { -+ const mockProvider: OAuthClientProvider = { -+ get redirectUrl() { return "http://localhost:3000/callback"; }, -+ get clientMetadata() { return validClientMetadata; }, -+ clientInformation: jest.fn(), -+ tokens: jest.fn(), -+ saveTokens: jest.fn(), -+ redirectToAuthorization: jest.fn(), -+ saveCodeVerifier: jest.fn(), -+ codeVerifier: jest.fn(), -+ initialAccessToken: jest.fn().mockResolvedValue("provider-token"), -+ }; -+ -+ mockFetch.mockResolvedValueOnce({ -+ ok: true, -+ status: 200, -+ json: async () => validClientInfo, -+ }); -+ -+ await registerClient("https://auth.example.com", { -+ clientMetadata: validClientMetadata, -+ initialAccessToken: "explicit-token", -+ provider: mockProvider, -+ }); -+ -+ expect(mockProvider.initialAccessToken).not.toHaveBeenCalled(); -+ expect(mockFetch).toHaveBeenCalledWith( -+ expect.objectContaining({ -+ href: "https://auth.example.com/register", -+ }), -+ expect.objectContaining({ -+ method: "POST", -+ headers: { -+ "Content-Type": "application/json", -+ "Authorization": "Bearer explicit-token", -+ }, -+ body: JSON.stringify(validClientMetadata), -+ }) -+ ); -+ }); -+ -+ it("registers without authorization header when no token available", async () => { -+ mockFetch.mockResolvedValueOnce({ -+ ok: true, -+ status: 200, -+ json: async () => validClientInfo, -+ }); -+ -+ await registerClient("https://auth.example.com", { -+ clientMetadata: validClientMetadata, -+ }); -+ -+ expect(mockFetch).toHaveBeenCalledWith( -+ expect.objectContaining({ -+ href: "https://auth.example.com/register", -+ }), -+ expect.objectContaining({ -+ method: "POST", -+ headers: { -+ "Content-Type": "application/json", -+ }, -+ body: JSON.stringify(validClientMetadata), -+ }) -+ ); -+ }); -+ }); - }); - - describe("auth function", () => { -diff --git a/src/client/auth.ts b/src/client/auth.ts -index 4a8bbe2..a3e937c 100644 ---- a/src/client/auth.ts -+++ b/src/client/auth.ts -@@ -124,6 +124,17 @@ export interface OAuthClientProvider { - * This avoids requiring the user to intervene manually. - */ - invalidateCredentials?(scope: 'all' | 'client' | 'tokens' | 'verifier'): void | Promise; -+ -+ /** -+ * If implemented, provides an initial access token for OAuth 2.0 Dynamic Client Registration -+ * according to RFC 7591. This token is used to authorize the client registration request. -+ * -+ * The initial access token allows the client to register with authorization servers that -+ * require pre-authorization for dynamic client registration. -+ * -+ * @returns The initial access token string, or undefined if none is available -+ */ -+ initialAccessToken?(): string | undefined | Promise; - } - - export type AuthResult = "AUTHORIZED" | "REDIRECT"; -@@ -281,7 +292,8 @@ export async function auth( - serverUrl: string | URL; - authorizationCode?: string; - scope?: string; -- resourceMetadataUrl?: URL }): Promise { -+ resourceMetadataUrl?: URL; -+ initialAccessToken?: string; }): Promise { - - try { - return await authInternal(provider, options); -@@ -305,12 +317,14 @@ async function authInternal( - { serverUrl, - authorizationCode, - scope, -- resourceMetadataUrl -+ resourceMetadataUrl, -+ initialAccessToken - }: { - serverUrl: string | URL; - authorizationCode?: string; - scope?: string; -- resourceMetadataUrl?: URL -+ resourceMetadataUrl?: URL; -+ initialAccessToken?: string; - }): Promise { - - let resourceMetadata: OAuthProtectedResourceMetadata | undefined; -@@ -344,6 +358,8 @@ async function authInternal( - const fullInformation = await registerClient(authorizationServerUrl, { - metadata, - clientMetadata: provider.clientMetadata, -+ initialAccessToken, -+ provider, - }); - - await provider.saveClientInformation(fullInformation); -@@ -877,15 +893,28 @@ export async function refreshAuthorization( - - /** - * Performs OAuth 2.0 Dynamic Client Registration according to RFC 7591. -+ * -+ * Supports initial access tokens for authorization servers that require -+ * pre-authorization for dynamic client registration. The initial access token -+ * is resolved using a multi-level fallback approach: -+ * -+ * 1. Explicit `initialAccessToken` parameter (highest priority) -+ * 2. Provider's `initialAccessToken()` method (if implemented) -+ * 3. `OAUTH_INITIAL_ACCESS_TOKEN` environment variable -+ * 4. None (current behavior for servers that don't require pre-authorization) - */ - export async function registerClient( - authorizationServerUrl: string | URL, - { - metadata, - clientMetadata, -+ initialAccessToken, -+ provider, - }: { - metadata?: OAuthMetadata; - clientMetadata: OAuthClientMetadata; -+ initialAccessToken?: string; -+ provider?: OAuthClientProvider; - }, - ): Promise { - let registrationUrl: URL; -@@ -900,11 +929,33 @@ export async function registerClient( - registrationUrl = new URL("/register", authorizationServerUrl); - } - -+ // Multi-level fallback for initial access token -+ let token = initialAccessToken; // Level 1: Explicit parameter -+ -+ if (!token && provider?.initialAccessToken) { -+ // Level 2: Provider method -+ token = await Promise.resolve(provider.initialAccessToken()); -+ } -+ -+ // Level 3: Environment variable (Node.js environments only) -+ if (!token && typeof globalThis !== 'undefined' && (globalThis as any).process?.env) { -+ token = (globalThis as any).process.env.OAUTH_INITIAL_ACCESS_TOKEN; -+ } -+ -+ // Level 4: None (current behavior) - no token needed -+ -+ const headers: Record = { -+ "Content-Type": "application/json", -+ }; -+ -+ // Add initial access token if available (RFC 7591) -+ if (token) { -+ headers["Authorization"] = `Bearer ${token}`; -+ } -+ - const response = await fetch(registrationUrl, { - method: "POST", -- headers: { -- "Content-Type": "application/json", -- }, -+ headers, - body: JSON.stringify(clientMetadata), - }); - -diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts -index 2cc4a1d..d8cadfb 100644 ---- a/src/client/sse.test.ts -+++ b/src/client/sse.test.ts -@@ -1107,5 +1107,80 @@ describe("SSEClientTransport", () => { - await expect(() => transport.start()).rejects.toThrow(InvalidGrantError); - expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); - }); -+ -+ describe("initialAccessToken support", () => { -+ it("stores initialAccessToken from constructor options", () => { -+ const transport = new SSEClientTransport( -+ new URL("http://localhost:1234/mcp"), -+ { initialAccessToken: "test-initial-token" } -+ ); -+ -+ // Access private property for testing -+ const transportInstance = transport as unknown as { _initialAccessToken?: string }; -+ expect(transportInstance._initialAccessToken).toBe("test-initial-token"); -+ }); -+ -+ it("works without initialAccessToken (backward compatibility)", async () => { -+ const transport = new SSEClientTransport( -+ new URL("http://localhost:1234/mcp"), -+ { authProvider: mockAuthProvider } -+ ); -+ -+ const transportInstance = transport as unknown as { _initialAccessToken?: string }; -+ expect(transportInstance._initialAccessToken).toBeUndefined(); -+ -+ // Should not throw when no initial access token provided -+ expect(() => transport).not.toThrow(); -+ }); -+ -+ it("includes initialAccessToken in auth calls", async () => { -+ // Create a spy on the auth module -+ const authModule = await import("./auth.js"); -+ const authSpy = jest.spyOn(authModule, "auth").mockResolvedValue("REDIRECT"); -+ -+ const transport = new SSEClientTransport( -+ resourceBaseUrl, -+ { -+ authProvider: mockAuthProvider, -+ initialAccessToken: "test-initial-token" -+ } -+ ); -+ -+ // Start the transport first -+ await transport.start(); -+ -+ // Mock fetch to return 401 and trigger auth on send -+ const originalFetch = global.fetch; -+ global.fetch = jest.fn().mockResolvedValueOnce({ -+ ok: false, -+ status: 401, -+ headers: new Headers(), -+ }); -+ -+ const message = { -+ jsonrpc: "2.0" as const, -+ method: "test", -+ params: {}, -+ id: "test-id" -+ }; -+ -+ try { -+ await transport.send(message); -+ } catch { -+ // Expected to fail due to mock setup, we're just testing auth call -+ } -+ -+ expect(authSpy).toHaveBeenCalledWith( -+ mockAuthProvider, -+ expect.objectContaining({ -+ initialAccessToken: "test-initial-token" -+ }) -+ ); -+ -+ // Restore fetch and spy -+ global.fetch = originalFetch; -+ authSpy.mockRestore(); -+ }); -+ }); - }); - }); -diff --git a/src/client/sse.ts b/src/client/sse.ts -index 568a515..98484bf 100644 ---- a/src/client/sse.ts -+++ b/src/client/sse.ts -@@ -52,6 +52,16 @@ export type SSEClientTransportOptions = { - * Custom fetch implementation used for all network requests. - */ - fetch?: FetchLike; -+ -+ /** -+ * Initial access token for OAuth 2.0 Dynamic Client Registration (RFC 7591). -+ * This token is used to authorize the client registration request with authorization servers -+ * that require pre-authorization for dynamic client registration. -+ * -+ * If not provided, the system will fall back to the provider's `initialAccessToken()` method -+ * and then to the `OAUTH_INITIAL_ACCESS_TOKEN` environment variable. -+ */ -+ initialAccessToken?: string; - }; - - /** -@@ -69,6 +79,7 @@ export class SSEClientTransport implements Transport { - private _authProvider?: OAuthClientProvider; - private _fetch?: FetchLike; - private _protocolVersion?: string; -+ private _initialAccessToken?: string; - - onclose?: () => void; - onerror?: (error: Error) => void; -@@ -84,6 +95,7 @@ export class SSEClientTransport implements Transport { - this._requestInit = opts?.requestInit; - this._authProvider = opts?.authProvider; - this._fetch = opts?.fetch; -+ this._initialAccessToken = opts?.initialAccessToken; - } - - private async _authThenStart(): Promise { -@@ -93,7 +105,7 @@ export class SSEClientTransport implements Transport { - - let result: AuthResult; - try { -- result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); -+ result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, initialAccessToken: this._initialAccessToken }); - } catch (error) { - this.onerror?.(error as Error); - throw error; -@@ -218,7 +230,7 @@ export class SSEClientTransport implements Transport { - throw new UnauthorizedError("No auth provider"); - } - -- const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl }); -+ const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, initialAccessToken: this._initialAccessToken }); - if (result !== "AUTHORIZED") { - throw new UnauthorizedError("Failed to authorize"); - } -@@ -252,7 +264,7 @@ const response = await (this._fetch ?? fetch)(this._endpoint, init); - - this._resourceMetadataUrl = extractResourceMetadataUrl(response); - -- const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); -+ const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, initialAccessToken: this._initialAccessToken }); - if (result !== "AUTHORIZED") { - throw new UnauthorizedError(); - } -diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts -index c54cf28..baeb955 100644 ---- a/src/client/streamableHttp.test.ts -+++ b/src/client/streamableHttp.test.ts -@@ -855,4 +855,73 @@ describe("StreamableHTTPClientTransport", () => { - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); - expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); - }); -+ -+ describe("initialAccessToken support", () => { -+ it("stores initialAccessToken from constructor options", () => { -+ const transport = new StreamableHTTPClientTransport( -+ new URL("http://localhost:1234/mcp"), -+ { initialAccessToken: "test-initial-token" } -+ ); -+ -+ // Access private property for testing -+ const transportInstance = transport as unknown as { _initialAccessToken?: string }; -+ expect(transportInstance._initialAccessToken).toBe("test-initial-token"); -+ }); -+ -+ it("works without initialAccessToken (backward compatibility)", async () => { -+ const transport = new StreamableHTTPClientTransport( -+ new URL("http://localhost:1234/mcp"), -+ { authProvider: mockAuthProvider } -+ ); -+ -+ const transportInstance = transport as unknown as { _initialAccessToken?: string }; -+ expect(transportInstance._initialAccessToken).toBeUndefined(); -+ -+ // Should not throw when no initial access token provided -+ expect(() => transport).not.toThrow(); -+ }); -+ -+ it("includes initialAccessToken in auth calls", async () => { -+ // Create a spy on the auth module -+ const authModule = await import("./auth.js"); -+ const authSpy = jest.spyOn(authModule, "auth").mockResolvedValue("REDIRECT"); -+ -+ const transport = new StreamableHTTPClientTransport( -+ new URL("http://localhost:1234/mcp"), -+ { -+ authProvider: mockAuthProvider, -+ initialAccessToken: "test-initial-token" -+ } -+ ); -+ -+ // Mock fetch to trigger auth flow on send (401 response) -+ (global.fetch as jest.Mock).mockResolvedValueOnce({ -+ ok: false, -+ status: 401, -+ headers: new Headers(), -+ }); -+ -+ const message = { -+ jsonrpc: "2.0" as const, -+ method: "test", -+ params: {}, -+ id: "test-id" -+ }; -+ -+ try { -+ await transport.send(message); -+ } catch { -+ // Expected to fail due to mock setup, we're just testing auth call -+ } -+ -+ expect(authSpy).toHaveBeenCalledWith( -+ mockAuthProvider, -+ expect.objectContaining({ -+ initialAccessToken: "test-initial-token" -+ }) -+ ); -+ -+ authSpy.mockRestore(); -+ }); -+ }); - }); -diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts -index b0894fc..a790372 100644 ---- a/src/client/streamableHttp.ts -+++ b/src/client/streamableHttp.ts -@@ -114,6 +114,14 @@ export type StreamableHTTPClientTransportOptions = { - * When not provided and connecting to a server that supports session IDs, the server will generate a new session ID. - */ - sessionId?: string; -+ -+ /** -+ * Initial access token for OAuth 2.0 Dynamic Client Registration (RFC 7591). -+ * This token is used to authorize the client registration request with authorization servers that require pre-authorization for dynamic client registration. -+ * -+ * If not provided, the system will fall back to the provider's `initialAccessToken()` method and then to the `OAUTH_INITIAL_ACCESS_TOKEN` environment variable. -+ */ -+ initialAccessToken?: string; - }; - - /** -@@ -131,6 +139,7 @@ export class StreamableHTTPClientTransport implements Transport { - private _sessionId?: string; - private _reconnectionOptions: StreamableHTTPReconnectionOptions; - private _protocolVersion?: string; -+ private _initialAccessToken?: string; - - onclose?: () => void; - onerror?: (error: Error) => void; -@@ -147,6 +156,7 @@ export class StreamableHTTPClientTransport implements Transport { - this._fetch = opts?.fetch; - this._sessionId = opts?.sessionId; - this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; -+ this._initialAccessToken = opts?.initialAccessToken; - } - - private async _authThenStart(): Promise { -@@ -156,7 +166,7 @@ export class StreamableHTTPClientTransport implements Transport { - - let result: AuthResult; - try { -- result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); -+ result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, initialAccessToken: this._initialAccessToken }); - } catch (error) { - this.onerror?.(error as Error); - throw error; -@@ -392,7 +402,7 @@ const response = await (this._fetch ?? fetch)(this._url, { - throw new UnauthorizedError("No auth provider"); - } - -- const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl }); -+ const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, initialAccessToken: this._initialAccessToken }); - if (result !== "AUTHORIZED") { - throw new UnauthorizedError("Failed to authorize"); - } -@@ -440,7 +450,7 @@ const response = await (this._fetch ?? fetch)(this._url, init); - - this._resourceMetadataUrl = extractResourceMetadataUrl(response); - -- const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); -+ const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, initialAccessToken: this._initialAccessToken }); - if (result !== "AUTHORIZED") { - throw new UnauthorizedError(); - }