Skip to content

Commit 99c4f3c

Browse files
Support falling back to OIDC metadata for auth (#1061)
1 parent 41184ba commit 99c4f3c

File tree

3 files changed

+100
-202
lines changed

3 files changed

+100
-202
lines changed

src/mcp/client/auth.py

Lines changed: 46 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -251,72 +251,32 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
251251
except ValidationError:
252252
pass
253253

254-
def _build_well_known_path(self, pathname: str) -> str:
255-
"""Construct well-known path for OAuth metadata discovery."""
256-
well_known_path = f"/.well-known/oauth-authorization-server{pathname}"
257-
if pathname.endswith("/"):
258-
# Strip trailing slash from pathname to avoid double slashes
259-
well_known_path = well_known_path[:-1]
260-
return well_known_path
261-
262-
def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool:
263-
"""Determine if fallback to root discovery should be attempted."""
264-
return response_status == 404 and pathname != "/"
265-
266-
async def _try_metadata_discovery(self, url: str) -> httpx.Request:
267-
"""Build metadata discovery request for a specific URL."""
268-
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
269-
270-
async def _discover_oauth_metadata(self) -> httpx.Request:
271-
"""Build OAuth metadata discovery request with fallback support."""
272-
if self.context.auth_server_url:
273-
auth_server_url = self.context.auth_server_url
274-
else:
275-
auth_server_url = self.context.server_url
276-
277-
# Per RFC 8414, try path-aware discovery first
254+
def _get_discovery_urls(self) -> list[str]:
255+
"""Generate ordered list of (url, type) tuples for discovery attempts."""
256+
urls: list[str] = []
257+
auth_server_url = self.context.auth_server_url or self.context.server_url
278258
parsed = urlparse(auth_server_url)
279-
well_known_path = self._build_well_known_path(parsed.path)
280259
base_url = f"{parsed.scheme}://{parsed.netloc}"
281-
url = urljoin(base_url, well_known_path)
282260

283-
# Store fallback info for use in response handler
284-
self.context.discovery_base_url = base_url
285-
self.context.discovery_pathname = parsed.path
261+
# RFC 8414: Path-aware OAuth discovery
262+
if parsed.path and parsed.path != "/":
263+
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
264+
urls.append(urljoin(base_url, oauth_path))
286265

287-
return await self._try_metadata_discovery(url)
266+
# OAuth root fallback
267+
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
288268

289-
async def _discover_oauth_metadata_fallback(self) -> httpx.Request:
290-
"""Build fallback OAuth metadata discovery request for legacy servers."""
291-
base_url = getattr(self.context, "discovery_base_url", "")
292-
if not base_url:
293-
raise OAuthFlowError("No base URL available for fallback discovery")
294-
295-
# Fallback to root discovery for legacy servers
296-
url = urljoin(base_url, "/.well-known/oauth-authorization-server")
297-
return await self._try_metadata_discovery(url)
298-
299-
async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fallback: bool = False) -> bool:
300-
"""Handle OAuth metadata response. Returns True if handled successfully."""
301-
if response.status_code == 200:
302-
try:
303-
content = await response.aread()
304-
metadata = OAuthMetadata.model_validate_json(content)
305-
self.context.oauth_metadata = metadata
306-
# Apply default scope if none specified
307-
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
308-
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)
309-
return True
310-
except ValidationError:
311-
pass
269+
# RFC 8414 section 5: Path-aware OIDC discovery
270+
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
271+
if parsed.path and parsed.path != "/":
272+
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
273+
urls.append(urljoin(base_url, oidc_path))
312274

313-
# Check if we should attempt fallback (404 on path-aware discovery)
314-
if not is_fallback and self._should_attempt_fallback(
315-
response.status_code, getattr(self.context, "discovery_pathname", "/")
316-
):
317-
return False # Signal that fallback should be attempted
275+
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
276+
oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration"
277+
urls.append(oidc_fallback)
318278

319-
return True # Signal no fallback needed (either success or non-404 error)
279+
return urls
320280

321281
async def _register_client(self) -> httpx.Request | None:
322282
"""Build registration request or skip if already registered."""
@@ -511,6 +471,17 @@ def _add_auth_header(self, request: httpx.Request) -> None:
511471
if self.context.current_tokens and self.context.current_tokens.access_token:
512472
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
513473

474+
def _create_oauth_metadata_request(self, url: str) -> httpx.Request:
475+
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
476+
477+
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
478+
content = await response.aread()
479+
metadata = OAuthMetadata.model_validate_json(content)
480+
self.context.oauth_metadata = metadata
481+
# Apply default scope if needed
482+
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
483+
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)
484+
514485
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
515486
"""HTTPX auth flow integration."""
516487
async with self.context.lock:
@@ -544,15 +515,19 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
544515
await self._handle_protected_resource_response(discovery_response)
545516

546517
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
547-
oauth_request = await self._discover_oauth_metadata()
548-
oauth_response = yield oauth_request
549-
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)
550-
551-
# If path-aware discovery failed with 404, try fallback to root
552-
if not handled:
553-
fallback_request = await self._discover_oauth_metadata_fallback()
554-
fallback_response = yield fallback_request
555-
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
518+
discovery_urls = self._get_discovery_urls()
519+
for url in discovery_urls:
520+
request = self._create_oauth_metadata_request(url)
521+
response = yield request
522+
523+
if response.status_code == 200:
524+
try:
525+
await self._handle_oauth_metadata_response(response)
526+
break
527+
except ValidationError:
528+
continue
529+
elif response.status_code != 404:
530+
break # Non-404 error, stop trying
556531

557532
# Step 3: Register client if needed
558533
registration_request = await self._register_client()
@@ -571,6 +546,6 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
571546
logger.exception("OAuth flow error")
572547
raise
573548

574-
# Retry with new tokens
575-
self._add_auth_header(request)
576-
yield request
549+
# Retry with new tokens
550+
self._add_auth_header(request)
551+
yield request

tests/client/test_auth.py

Lines changed: 15 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -235,107 +235,30 @@ async def callback_handler() -> tuple[str, str | None]:
235235
assert "mcp-protocol-version" in request.headers
236236

237237
@pytest.mark.anyio
238-
async def test_discover_oauth_metadata_request(self, oauth_provider):
238+
def test_create_oauth_metadata_request(self, oauth_provider):
239239
"""Test OAuth metadata discovery request building."""
240-
request = await oauth_provider._discover_oauth_metadata()
240+
request = oauth_provider._create_oauth_metadata_request("https://example.com")
241241

242+
# Ensure correct method and headers, and that the URL is unmodified
242243
assert request.method == "GET"
243-
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
244-
assert "mcp-protocol-version" in request.headers
245-
246-
@pytest.mark.anyio
247-
async def test_discover_oauth_metadata_request_no_path(self, client_metadata, mock_storage):
248-
"""Test OAuth metadata discovery request building when server has no path."""
249-
250-
async def redirect_handler(url: str) -> None:
251-
pass
252-
253-
async def callback_handler() -> tuple[str, str | None]:
254-
return "test_auth_code", "test_state"
255-
256-
provider = OAuthClientProvider(
257-
server_url="https://api.example.com",
258-
client_metadata=client_metadata,
259-
storage=mock_storage,
260-
redirect_handler=redirect_handler,
261-
callback_handler=callback_handler,
262-
)
263-
264-
request = await provider._discover_oauth_metadata()
265-
266-
assert request.method == "GET"
267-
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server"
268-
assert "mcp-protocol-version" in request.headers
269-
270-
@pytest.mark.anyio
271-
async def test_discover_oauth_metadata_request_trailing_slash(self, client_metadata, mock_storage):
272-
"""Test OAuth metadata discovery request building when server path has trailing slash."""
273-
274-
async def redirect_handler(url: str) -> None:
275-
pass
276-
277-
async def callback_handler() -> tuple[str, str | None]:
278-
return "test_auth_code", "test_state"
279-
280-
provider = OAuthClientProvider(
281-
server_url="https://api.example.com/v1/mcp/",
282-
client_metadata=client_metadata,
283-
storage=mock_storage,
284-
redirect_handler=redirect_handler,
285-
callback_handler=callback_handler,
286-
)
287-
288-
request = await provider._discover_oauth_metadata()
289-
290-
assert request.method == "GET"
291-
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
244+
assert str(request.url) == "https://example.com"
292245
assert "mcp-protocol-version" in request.headers
293246

294247

295248
class TestOAuthFallback:
296249
"""Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers."""
297250

298251
@pytest.mark.anyio
299-
async def test_fallback_discovery_request(self, client_metadata, mock_storage):
300-
"""Test fallback discovery request building."""
301-
302-
async def redirect_handler(url: str) -> None:
303-
pass
304-
305-
async def callback_handler() -> tuple[str, str | None]:
306-
return "test_auth_code", "test_state"
307-
308-
provider = OAuthClientProvider(
309-
server_url="https://api.example.com/v1/mcp",
310-
client_metadata=client_metadata,
311-
storage=mock_storage,
312-
redirect_handler=redirect_handler,
313-
callback_handler=callback_handler,
314-
)
315-
316-
# Set up discovery state manually as if path-aware discovery was attempted
317-
provider.context.discovery_base_url = "https://api.example.com"
318-
provider.context.discovery_pathname = "/v1/mcp"
252+
async def test_oauth_discovery_fallback_order(self, oauth_provider):
253+
"""Test fallback URL construction order."""
254+
discovery_urls = oauth_provider._get_discovery_urls()
319255

320-
# Test fallback request building
321-
request = await provider._discover_oauth_metadata_fallback()
322-
323-
assert request.method == "GET"
324-
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server"
325-
assert "mcp-protocol-version" in request.headers
326-
327-
@pytest.mark.anyio
328-
async def test_should_attempt_fallback(self, oauth_provider):
329-
"""Test fallback decision logic."""
330-
# Should attempt fallback on 404 with non-root path
331-
assert oauth_provider._should_attempt_fallback(404, "/v1/mcp")
332-
333-
# Should NOT attempt fallback on 404 with root path
334-
assert not oauth_provider._should_attempt_fallback(404, "/")
335-
336-
# Should NOT attempt fallback on other status codes
337-
assert not oauth_provider._should_attempt_fallback(200, "/v1/mcp")
338-
assert not oauth_provider._should_attempt_fallback(500, "/v1/mcp")
256+
assert discovery_urls == [
257+
"https://api.example.com/.well-known/oauth-authorization-server/v1/mcp",
258+
"https://api.example.com/.well-known/oauth-authorization-server",
259+
"https://api.example.com/.well-known/openid-configuration/v1/mcp",
260+
"https://api.example.com/v1/mcp/.well-known/openid-configuration",
261+
]
339262

340263
@pytest.mark.anyio
341264
async def test_handle_metadata_response_success(self, oauth_provider):
@@ -348,50 +271,11 @@ async def test_handle_metadata_response_success(self, oauth_provider):
348271
}"""
349272
response = httpx.Response(200, content=content)
350273

351-
# Should return True (success) and set metadata
352-
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False)
353-
assert result is True
274+
# Should set metadata
275+
await oauth_provider._handle_oauth_metadata_response(response)
354276
assert oauth_provider.context.oauth_metadata is not None
355277
assert str(oauth_provider.context.oauth_metadata.issuer) == "https://auth.example.com/"
356278

357-
@pytest.mark.anyio
358-
async def test_handle_metadata_response_404_needs_fallback(self, oauth_provider):
359-
"""Test 404 response handling that should trigger fallback."""
360-
# Set up discovery state for non-root path
361-
oauth_provider.context.discovery_base_url = "https://api.example.com"
362-
oauth_provider.context.discovery_pathname = "/v1/mcp"
363-
364-
# Mock 404 response
365-
response = httpx.Response(404)
366-
367-
# Should return False (needs fallback)
368-
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False)
369-
assert result is False
370-
371-
@pytest.mark.anyio
372-
async def test_handle_metadata_response_404_no_fallback_needed(self, oauth_provider):
373-
"""Test 404 response handling when no fallback is needed."""
374-
# Set up discovery state for root path
375-
oauth_provider.context.discovery_base_url = "https://api.example.com"
376-
oauth_provider.context.discovery_pathname = "/"
377-
378-
# Mock 404 response
379-
response = httpx.Response(404)
380-
381-
# Should return True (no fallback needed)
382-
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False)
383-
assert result is True
384-
385-
@pytest.mark.anyio
386-
async def test_handle_metadata_response_404_fallback_attempt(self, oauth_provider):
387-
"""Test 404 response handling during fallback attempt."""
388-
# Mock 404 response during fallback
389-
response = httpx.Response(404)
390-
391-
# Should return True (fallback attempt complete, no further action needed)
392-
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=True)
393-
assert result is True
394-
395279
@pytest.mark.anyio
396280
async def test_register_client_request(self, oauth_provider):
397281
"""Test client registration request building."""

tests/shared/test_auth.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Tests for OAuth 2.0 shared code."""
2+
3+
from mcp.shared.auth import OAuthMetadata
4+
5+
6+
class TestOAuthMetadata:
7+
"""Tests for OAuthMetadata parsing."""
8+
9+
def test_oauth(self):
10+
"""Should not throw when parsing OAuth metadata."""
11+
OAuthMetadata.model_validate(
12+
{
13+
"issuer": "https://example.com",
14+
"authorization_endpoint": "https://example.com/oauth2/authorize",
15+
"token_endpoint": "https://example.com/oauth2/token",
16+
"scopes_supported": ["read", "write"],
17+
"response_types_supported": ["code", "token"],
18+
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
19+
}
20+
)
21+
22+
def test_oidc(self):
23+
"""Should not throw when parsing OIDC metadata."""
24+
OAuthMetadata.model_validate(
25+
{
26+
"issuer": "https://example.com",
27+
"authorization_endpoint": "https://example.com/oauth2/authorize",
28+
"token_endpoint": "https://example.com/oauth2/token",
29+
"end_session_endpoint": "https://example.com/logout",
30+
"id_token_signing_alg_values_supported": ["RS256"],
31+
"jwks_uri": "https://example.com/.well-known/jwks.json",
32+
"response_types_supported": ["code", "token"],
33+
"revocation_endpoint": "https://example.com/oauth2/revoke",
34+
"scopes_supported": ["openid", "read", "write"],
35+
"subject_types_supported": ["public"],
36+
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
37+
"userinfo_endpoint": "https://example.com/oauth2/userInfo",
38+
}
39+
)

0 commit comments

Comments
 (0)