diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 8bafe18eb..891de9c33 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -7,6 +7,7 @@ import base64 import hashlib import logging +import re import secrets import string import time @@ -203,10 +204,39 @@ def __init__( ) self._initialized = False - async def _discover_protected_resource(self) -> httpx.Request: - """Build discovery request for protected resource metadata.""" - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: + """ + Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. + + Returns: + Resource metadata URL if found in WWW-Authenticate header, None otherwise + """ + if not init_response or init_response.status_code != 401: + return None + + www_auth_header = init_response.headers.get("WWW-Authenticate") + if not www_auth_header: + return None + + # Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted) + pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' + match = re.search(pattern, www_auth_header) + + if match: + # Return quoted value if present, otherwise unquoted value + return match.group(1) or match.group(2) + + return None + + async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request: + # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response + url = self._extract_resource_metadata_from_www_auth(init_response) + + if not url: + # Fallback to well-known discovery + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) + url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) async def _handle_protected_resource_response(self, response: httpx.Response) -> None: @@ -490,92 +520,60 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) - # Perform OAuth flow if not authenticated - if not self.context.is_token_valid(): - try: - # OAuth flow must be inline due to generator constraints - # Step 1: Discover protected resource metadata (spec revision 2025-06-18) - discovery_request = await self._discover_protected_resource() - discovery_response = yield discovery_request - await self._handle_protected_resource_response(discovery_response) - - # Step 2: Discover OAuth metadata (with fallback for legacy servers) - oauth_request = await self._discover_oauth_metadata() - oauth_response = yield oauth_request - handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) - - # If path-aware discovery failed with 404, try fallback to root - if not handled: - fallback_request = await self._discover_oauth_metadata_fallback() - fallback_response = yield fallback_request - await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) - - # Step 3: Register client if needed - registration_request = await self._register_client() - if registration_request: - registration_response = yield registration_request - await self._handle_registration_response(registration_response) - - # Step 4: Perform authorization - auth_code, code_verifier = await self._perform_authorization() - - # Step 5: Exchange authorization code for tokens - token_request = await self._exchange_token(auth_code, code_verifier) - token_response = yield token_request - await self._handle_token_response(token_response) - except Exception: - logger.exception("OAuth flow error") - raise - - # Add authorization header and make request - self._add_auth_header(request) + if self.context.is_token_valid(): + self._add_auth_header(request) + response = yield request - # Handle 401 responses - if response.status_code == 401 and self.context.can_refresh_token(): - # Try to refresh token - refresh_request = await self._refresh_token() - refresh_response = yield refresh_request + if response.status_code == 401: + if self.context.can_refresh_token(): + # Try to refresh token + refresh_request = await self._refresh_token() + refresh_response = yield refresh_request - if await self._handle_refresh_response(refresh_response): - # Retry original request with new token - self._add_auth_header(request) - yield request + if not await self._handle_refresh_response(refresh_response): + # Refresh failed, need full re-authentication + self._initialized = False else: - # Refresh failed, need full re-authentication - self._initialized = False - - # OAuth flow must be inline due to generator constraints - # Step 1: Discover protected resource metadata (spec revision 2025-06-18) - discovery_request = await self._discover_protected_resource() - discovery_response = yield discovery_request - await self._handle_protected_resource_response(discovery_response) - - # Step 2: Discover OAuth metadata (with fallback for legacy servers) - oauth_request = await self._discover_oauth_metadata() - oauth_response = yield oauth_request - handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) - - # If path-aware discovery failed with 404, try fallback to root - if not handled: - fallback_request = await self._discover_oauth_metadata_fallback() - fallback_response = yield fallback_request - await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) - - # Step 3: Register client if needed - registration_request = await self._register_client() - if registration_request: - registration_response = yield registration_request - await self._handle_registration_response(registration_response) - - # Step 4: Perform authorization - auth_code, code_verifier = await self._perform_authorization() - - # Step 5: Exchange authorization code for tokens - token_request = await self._exchange_token(auth_code, code_verifier) - token_response = yield token_request - await self._handle_token_response(token_response) - - # Retry with new tokens - self._add_auth_header(request) - yield request + self.context.clear_tokens() + + # If we don't have valid tokens after refresh, perform OAuth flow + if not self.context.is_token_valid(): + try: + # OAuth flow must be inline due to generator constraints + # Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support) + discovery_request = await self._discover_protected_resource(response) + discovery_response = yield discovery_request + await self._handle_protected_resource_response(discovery_response) + + # Step 2: Discover OAuth metadata (with fallback for legacy servers) + oauth_request = await self._discover_oauth_metadata() + oauth_response = yield oauth_request + handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) + + # If path-aware discovery failed with 404, try fallback to root + if not handled: + fallback_request = await self._discover_oauth_metadata_fallback() + fallback_response = yield fallback_request + await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) + + # Step 3: Register client if needed + registration_request = await self._register_client() + if registration_request: + registration_response = yield registration_request + await self._handle_registration_response(registration_response) + + # Step 4: Perform authorization + auth_code, code_verifier = await self._perform_authorization() + + # Step 5: Exchange authorization code for tokens + token_request = await self._exchange_token(auth_code, code_verifier) + token_response = yield token_request + await self._handle_token_response(token_response) + except Exception as e: + logger.exception("OAuth flow error") + raise + + # Retry with new tokens + self._add_auth_header(request) + yield request diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index affcaa276..74253a94f 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -196,14 +196,43 @@ class TestOAuthFlow: """Test OAuth flow methods.""" @pytest.mark.anyio - async def test_discover_protected_resource_request(self, oauth_provider): - """Test protected resource discovery request building.""" - request = await oauth_provider._discover_protected_resource() + async def test_discover_protected_resource_request(self, client_metadata, mock_storage): + """Test protected resource discovery request building maintains backward compatibility.""" + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + # Test without WWW-Authenticate (fallback) + init_response = httpx.Response( + status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com") + ) + + request = await provider._discover_protected_resource(init_response) assert request.method == "GET" assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource" assert "mcp-protocol-version" in request.headers + # Test with WWW-Authenticate header + init_response.headers["WWW-Authenticate"] = ( + 'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"' + ) + + request = await provider._discover_protected_resource(init_response) + assert request.method == "GET" + assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path" + assert "mcp-protocol-version" in request.headers + @pytest.mark.anyio async def test_discover_oauth_metadata_request(self, oauth_provider): """Test OAuth metadata discovery request building.""" @@ -580,3 +609,114 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v await auth_flow.asend(response) except StopAsyncIteration: pass # Expected + + +class TestProtectedResourceWWWAuthenticate: + """Test RFC9728 WWW-Authenticate header parsing functionality for protected resource.""" + + @pytest.mark.parametrize( + "www_auth_header,expected_url", + [ + # Quoted URL + ( + 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', + "https://api.example.com/.well-known/oauth-protected-resource", + ), + # Unquoted URL + ( + "Bearer resource_metadata=https://api.example.com/.well-known/oauth-protected-resource", + "https://api.example.com/.well-known/oauth-protected-resource", + ), + # Complex header with multiple parameters + ( + 'Bearer realm="api", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource", ' + 'error="insufficient_scope"', + "https://api.example.com/.well-known/oauth-protected-resource", + ), + # Different URL format + ('Bearer resource_metadata="https://custom.domain.com/metadata"', "https://custom.domain.com/metadata"), + # With path and query params + ( + 'Bearer resource_metadata="https://api.example.com/auth/metadata?version=1"', + "https://api.example.com/auth/metadata?version=1", + ), + ], + ) + def test_extract_resource_metadata_from_www_auth_valid_cases( + self, client_metadata, mock_storage, www_auth_header, expected_url + ): + """Test extraction of resource_metadata URL from various valid WWW-Authenticate headers.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + init_response = httpx.Response( + status_code=401, + headers={"WWW-Authenticate": www_auth_header}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + result = provider._extract_resource_metadata_from_www_auth(init_response) + assert result == expected_url + + @pytest.mark.parametrize( + "status_code,www_auth_header,description", + [ + # No header + (401, None, "no WWW-Authenticate header"), + # Empty header + (401, "", "empty WWW-Authenticate header"), + # Header without resource_metadata + (401, 'Bearer realm="api", error="insufficient_scope"', "no resource_metadata parameter"), + # Malformed header + (401, "Bearer resource_metadata=", "malformed resource_metadata parameter"), + # Non-401 status code + ( + 200, + 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', + "200 OK response", + ), + ( + 500, + 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', + "500 error response", + ), + ], + ) + def test_extract_resource_metadata_from_www_auth_invalid_cases( + self, client_metadata, mock_storage, status_code, www_auth_header, description + ): + """Test extraction returns None for invalid cases.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {} + init_response = httpx.Response( + status_code=status_code, headers=headers, request=httpx.Request("GET", "https://api.example.com/test") + ) + + result = provider._extract_resource_metadata_from_www_auth(init_response) + assert result is None, f"Should return None for {description}"