Skip to content

Commit eb5146d

Browse files
authored
Implement RFC9728 - Support WWW-Authenticate header by MCP client (#1071)
1 parent bd9885f commit eb5146d

File tree

2 files changed

+195
-60
lines changed

2 files changed

+195
-60
lines changed

src/mcp/client/auth.py

Lines changed: 52 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import base64
88
import hashlib
99
import logging
10+
import re
1011
import secrets
1112
import string
1213
import time
@@ -203,10 +204,39 @@ def __init__(
203204
)
204205
self._initialized = False
205206

206-
async def _discover_protected_resource(self) -> httpx.Request:
207-
"""Build discovery request for protected resource metadata."""
208-
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
209-
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")
207+
def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None:
208+
"""
209+
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
210+
211+
Returns:
212+
Resource metadata URL if found in WWW-Authenticate header, None otherwise
213+
"""
214+
if not init_response or init_response.status_code != 401:
215+
return None
216+
217+
www_auth_header = init_response.headers.get("WWW-Authenticate")
218+
if not www_auth_header:
219+
return None
220+
221+
# Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted)
222+
pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))'
223+
match = re.search(pattern, www_auth_header)
224+
225+
if match:
226+
# Return quoted value if present, otherwise unquoted value
227+
return match.group(1) or match.group(2)
228+
229+
return None
230+
231+
async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request:
232+
# RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response
233+
url = self._extract_resource_metadata_from_www_auth(init_response)
234+
235+
if not url:
236+
# Fallback to well-known discovery
237+
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
238+
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")
239+
210240
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
211241

212242
async def _handle_protected_resource_response(self, response: httpx.Response) -> None:
@@ -490,64 +520,26 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
490520
# Capture protocol version from request headers
491521
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
492522

493-
# Perform OAuth flow if not authenticated
494-
if not self.context.is_token_valid():
495-
try:
496-
# OAuth flow must be inline due to generator constraints
497-
# Step 1: Discover protected resource metadata (spec revision 2025-06-18)
498-
discovery_request = await self._discover_protected_resource()
499-
discovery_response = yield discovery_request
500-
await self._handle_protected_resource_response(discovery_response)
501-
502-
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
503-
oauth_request = await self._discover_oauth_metadata()
504-
oauth_response = yield oauth_request
505-
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)
506-
507-
# If path-aware discovery failed with 404, try fallback to root
508-
if not handled:
509-
fallback_request = await self._discover_oauth_metadata_fallback()
510-
fallback_response = yield fallback_request
511-
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
512-
513-
# Step 3: Register client if needed
514-
registration_request = await self._register_client()
515-
if registration_request:
516-
registration_response = yield registration_request
517-
await self._handle_registration_response(registration_response)
518-
519-
# Step 4: Perform authorization
520-
auth_code, code_verifier = await self._perform_authorization()
521-
522-
# Step 5: Exchange authorization code for tokens
523-
token_request = await self._exchange_token(auth_code, code_verifier)
524-
token_response = yield token_request
525-
await self._handle_token_response(token_response)
526-
except Exception:
527-
logger.exception("OAuth flow error")
528-
raise
529-
530-
# Add authorization header and make request
531-
self._add_auth_header(request)
532-
response = yield request
533-
534-
# Handle 401 responses
535-
if response.status_code == 401 and self.context.can_refresh_token():
523+
if not self.context.is_token_valid() and self.context.can_refresh_token():
536524
# Try to refresh token
537525
refresh_request = await self._refresh_token()
538526
refresh_response = yield refresh_request
539527

540-
if await self._handle_refresh_response(refresh_response):
541-
# Retry original request with new token
542-
self._add_auth_header(request)
543-
yield request
544-
else:
528+
if not await self._handle_refresh_response(refresh_response):
545529
# Refresh failed, need full re-authentication
546530
self._initialized = False
547531

532+
if self.context.is_token_valid():
533+
self._add_auth_header(request)
534+
535+
response = yield request
536+
537+
if response.status_code == 401:
538+
# Perform full OAuth flow
539+
try:
548540
# OAuth flow must be inline due to generator constraints
549-
# Step 1: Discover protected resource metadata (spec revision 2025-06-18)
550-
discovery_request = await self._discover_protected_resource()
541+
# Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support)
542+
discovery_request = await self._discover_protected_resource(response)
551543
discovery_response = yield discovery_request
552544
await self._handle_protected_resource_response(discovery_response)
553545

@@ -575,7 +567,10 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
575567
token_request = await self._exchange_token(auth_code, code_verifier)
576568
token_response = yield token_request
577569
await self._handle_token_response(token_response)
570+
except Exception:
571+
logger.exception("OAuth flow error")
572+
raise
578573

579-
# Retry with new tokens
580-
self._add_auth_header(request)
581-
yield request
574+
# Retry with new tokens
575+
self._add_auth_header(request)
576+
yield request

tests/client/test_auth.py

Lines changed: 143 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,43 @@ class TestOAuthFlow:
197197
"""Test OAuth flow methods."""
198198

199199
@pytest.mark.anyio
200-
async def test_discover_protected_resource_request(self, oauth_provider):
201-
"""Test protected resource discovery request building."""
202-
request = await oauth_provider._discover_protected_resource()
200+
async def test_discover_protected_resource_request(self, client_metadata, mock_storage):
201+
"""Test protected resource discovery request building maintains backward compatibility."""
203202

203+
async def redirect_handler(url: str) -> None:
204+
pass
205+
206+
async def callback_handler() -> tuple[str, str | None]:
207+
return "test_auth_code", "test_state"
208+
209+
provider = OAuthClientProvider(
210+
server_url="https://api.example.com",
211+
client_metadata=client_metadata,
212+
storage=mock_storage,
213+
redirect_handler=redirect_handler,
214+
callback_handler=callback_handler,
215+
)
216+
217+
# Test without WWW-Authenticate (fallback)
218+
init_response = httpx.Response(
219+
status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com")
220+
)
221+
222+
request = await provider._discover_protected_resource(init_response)
204223
assert request.method == "GET"
205224
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
206225
assert "mcp-protocol-version" in request.headers
207226

227+
# Test with WWW-Authenticate header
228+
init_response.headers["WWW-Authenticate"] = (
229+
'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"'
230+
)
231+
232+
request = await provider._discover_protected_resource(init_response)
233+
assert request.method == "GET"
234+
assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path"
235+
assert "mcp-protocol-version" in request.headers
236+
208237
@pytest.mark.anyio
209238
async def test_discover_oauth_metadata_request(self, oauth_provider):
210239
"""Test OAuth metadata discovery request building."""
@@ -660,3 +689,114 @@ def test_build_metadata(
660689
"code_challenge_methods_supported": ["S256"],
661690
}
662691
)
692+
693+
694+
class TestProtectedResourceWWWAuthenticate:
695+
"""Test RFC9728 WWW-Authenticate header parsing functionality for protected resource."""
696+
697+
@pytest.mark.parametrize(
698+
"www_auth_header,expected_url",
699+
[
700+
# Quoted URL
701+
(
702+
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"',
703+
"https://api.example.com/.well-known/oauth-protected-resource",
704+
),
705+
# Unquoted URL
706+
(
707+
"Bearer resource_metadata=https://api.example.com/.well-known/oauth-protected-resource",
708+
"https://api.example.com/.well-known/oauth-protected-resource",
709+
),
710+
# Complex header with multiple parameters
711+
(
712+
'Bearer realm="api", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource", '
713+
'error="insufficient_scope"',
714+
"https://api.example.com/.well-known/oauth-protected-resource",
715+
),
716+
# Different URL format
717+
('Bearer resource_metadata="https://custom.domain.com/metadata"', "https://custom.domain.com/metadata"),
718+
# With path and query params
719+
(
720+
'Bearer resource_metadata="https://api.example.com/auth/metadata?version=1"',
721+
"https://api.example.com/auth/metadata?version=1",
722+
),
723+
],
724+
)
725+
def test_extract_resource_metadata_from_www_auth_valid_cases(
726+
self, client_metadata, mock_storage, www_auth_header, expected_url
727+
):
728+
"""Test extraction of resource_metadata URL from various valid WWW-Authenticate headers."""
729+
730+
async def redirect_handler(url: str) -> None:
731+
pass
732+
733+
async def callback_handler() -> tuple[str, str | None]:
734+
return "test_auth_code", "test_state"
735+
736+
provider = OAuthClientProvider(
737+
server_url="https://api.example.com/v1/mcp",
738+
client_metadata=client_metadata,
739+
storage=mock_storage,
740+
redirect_handler=redirect_handler,
741+
callback_handler=callback_handler,
742+
)
743+
744+
init_response = httpx.Response(
745+
status_code=401,
746+
headers={"WWW-Authenticate": www_auth_header},
747+
request=httpx.Request("GET", "https://api.example.com/test"),
748+
)
749+
750+
result = provider._extract_resource_metadata_from_www_auth(init_response)
751+
assert result == expected_url
752+
753+
@pytest.mark.parametrize(
754+
"status_code,www_auth_header,description",
755+
[
756+
# No header
757+
(401, None, "no WWW-Authenticate header"),
758+
# Empty header
759+
(401, "", "empty WWW-Authenticate header"),
760+
# Header without resource_metadata
761+
(401, 'Bearer realm="api", error="insufficient_scope"', "no resource_metadata parameter"),
762+
# Malformed header
763+
(401, "Bearer resource_metadata=", "malformed resource_metadata parameter"),
764+
# Non-401 status code
765+
(
766+
200,
767+
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"',
768+
"200 OK response",
769+
),
770+
(
771+
500,
772+
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"',
773+
"500 error response",
774+
),
775+
],
776+
)
777+
def test_extract_resource_metadata_from_www_auth_invalid_cases(
778+
self, client_metadata, mock_storage, status_code, www_auth_header, description
779+
):
780+
"""Test extraction returns None for invalid cases."""
781+
782+
async def redirect_handler(url: str) -> None:
783+
pass
784+
785+
async def callback_handler() -> tuple[str, str | None]:
786+
return "test_auth_code", "test_state"
787+
788+
provider = OAuthClientProvider(
789+
server_url="https://api.example.com/v1/mcp",
790+
client_metadata=client_metadata,
791+
storage=mock_storage,
792+
redirect_handler=redirect_handler,
793+
callback_handler=callback_handler,
794+
)
795+
796+
headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {}
797+
init_response = httpx.Response(
798+
status_code=status_code, headers=headers, request=httpx.Request("GET", "https://api.example.com/test")
799+
)
800+
801+
result = provider._extract_resource_metadata_from_www_auth(init_response)
802+
assert result is None, f"Should return None for {description}"

0 commit comments

Comments
 (0)