From 37cdb92db132732e7e014d472c4cdc1872265c60 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 30 Jun 2025 14:17:30 -0700 Subject: [PATCH 1/2] Support falling back to OIDC metadata for auth --- src/mcp/client/auth.py | 140 ++++++++++++++++++++++--------- tests/client/test_auth.py | 172 +++++++++++++++++++++++++++++++++----- 2 files changed, 252 insertions(+), 60 deletions(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 769e9b4c8..d0f2d3f5a 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -175,6 +175,9 @@ def should_include_resource_param(self, protocol_version: str | None = None) -> return protocol_version >= "2025-06-18" +OAuthDiscoveryStack = list[Callable[[], Awaitable[httpx.Request]]] + + class OAuthClientProvider(httpx.Auth): """ OAuth2 authentication for httpx. @@ -221,24 +224,52 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> except ValidationError: pass - def _build_well_known_path(self, pathname: str) -> str: + def _build_well_known_path(self, pathname: str, well_known_endpoint: str) -> str: """Construct well-known path for OAuth metadata discovery.""" - well_known_path = f"/.well-known/oauth-authorization-server{pathname}" + well_known_path = f"/.well-known/{well_known_endpoint}{pathname}" if pathname.endswith("/"): # Strip trailing slash from pathname to avoid double slashes well_known_path = well_known_path[:-1] return well_known_path - def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool: - """Determine if fallback to root discovery should be attempted.""" - return response_status == 404 and pathname != "/" + def _build_well_known_fallback_url(self, well_known_endpoint: str) -> str: + """Construct fallback well-known URL for OAuth metadata discovery in legacy servers.""" + base_url = getattr(self.context, "discovery_base_url", "") + if not base_url: + raise OAuthFlowError("No base URL available for fallback discovery") + + # Fallback to root discovery for legacy servers + return urljoin(base_url, f"/.well-known/{well_known_endpoint}") + + def _build_oidc_fallback_path(self, pathname: str, well_known_endpoint: str) -> str: + """Construct fallback well-known path for OIDC metadata discovery in legacy servers.""" + # Strip trailing slash from pathname to avoid double slashes + clean_pathname = pathname[:-1] if pathname.endswith("/") else pathname + # OIDC 1.0 appends the well-known path to the full AS URL + return f"{clean_pathname}/.well-known/{well_known_endpoint}" + + def _build_oidc_fallback_url(self, well_known_endpoint: str) -> str: + """Construct fallback well-known URL for OIDC metadata discovery in legacy servers.""" + if self.context.auth_server_url: + auth_server_url = self.context.auth_server_url + else: + auth_server_url = self.context.server_url + + parsed = urlparse(auth_server_url) + well_known_path = self._build_oidc_fallback_path(parsed.path, well_known_endpoint) + base_url = f"{parsed.scheme}://{parsed.netloc}" + return urljoin(base_url, well_known_path) + + def _should_attempt_fallback(self, response_status: int, discovery_stack: OAuthDiscoveryStack) -> bool: + """Determine if further fallback should be attempted.""" + return response_status == 404 and len(discovery_stack) > 0 async def _try_metadata_discovery(self, url: str) -> httpx.Request: """Build metadata discovery request for a specific URL.""" return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - async def _discover_oauth_metadata(self) -> httpx.Request: - """Build OAuth metadata discovery request with fallback support.""" + async def _discover_well_known_metadata(self, well_known_endpoint: str) -> httpx.Request: + """Build .well-known metadata discovery request with fallback support.""" if self.context.auth_server_url: auth_server_url = self.context.auth_server_url else: @@ -246,7 +277,7 @@ async def _discover_oauth_metadata(self) -> httpx.Request: # Per RFC 8414, try path-aware discovery first parsed = urlparse(auth_server_url) - well_known_path = self._build_well_known_path(parsed.path) + well_known_path = self._build_well_known_path(parsed.path, well_known_endpoint) base_url = f"{parsed.scheme}://{parsed.netloc}" url = urljoin(base_url, well_known_path) @@ -256,17 +287,37 @@ async def _discover_oauth_metadata(self) -> httpx.Request: return await self._try_metadata_discovery(url) + async def _discover_well_known_metadata_fallback(self, well_known_endpoint: str) -> httpx.Request: + """Build fallback OAuth metadata discovery request for legacy servers.""" + url = self._build_well_known_fallback_url(well_known_endpoint) + return await self._try_metadata_discovery(url) + + async def _discover_oauth_metadata(self) -> httpx.Request: + """Build OAuth metadata discovery request with fallback support.""" + return await self._discover_well_known_metadata("oauth-authorization-server") + async def _discover_oauth_metadata_fallback(self) -> httpx.Request: """Build fallback OAuth metadata discovery request for legacy servers.""" - base_url = getattr(self.context, "discovery_base_url", "") - if not base_url: - raise OAuthFlowError("No base URL available for fallback discovery") + return await self._discover_well_known_metadata_fallback("oauth-authorization-server") - # Fallback to root discovery for legacy servers - url = urljoin(base_url, "/.well-known/oauth-authorization-server") + async def _discover_oidc_metadata(self) -> httpx.Request: + """ + Build fallback OIDC metadata discovery request. + See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 + """ + return await self._discover_well_known_metadata("openid-configuration") + + async def _discover_oidc_metadata_fallback(self) -> httpx.Request: + """ + Build fallback OIDC metadata discovery request for legacy servers. + See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 + """ + url = self._build_oidc_fallback_url("openid-configuration") return await self._try_metadata_discovery(url) - async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fallback: bool = False) -> bool: + async def _handle_oauth_metadata_response( + self, response: httpx.Response, discovery_stack: OAuthDiscoveryStack + ) -> bool: """Handle OAuth metadata response. Returns True if handled successfully.""" if response.status_code == 200: try: @@ -280,13 +331,10 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fal except ValidationError: pass - # Check if we should attempt fallback (404 on path-aware discovery) - if not is_fallback and self._should_attempt_fallback( - response.status_code, getattr(self.context, "discovery_pathname", "/") - ): - return False # Signal that fallback should be attempted - - return True # Signal no fallback needed (either success or non-404 error) + # Check if we should attempt fallback + # True: No fallback needed (either success or non-404 error) + # False: Signal that fallback should be attempted + return not self._should_attempt_fallback(response.status_code, discovery_stack) async def _register_client(self) -> httpx.Request | None: """Build registration request or skip if already registered.""" @@ -480,6 +528,26 @@ def _add_auth_header(self, request: httpx.Request) -> None: if self.context.current_tokens and self.context.current_tokens.access_token: request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" + def _create_oauth_discovery_stack(self) -> OAuthDiscoveryStack: + """Create a stack of attempts to discover OAuth metadata.""" + discovery_attempts: OAuthDiscoveryStack = [ + # Start with path-aware OAuth discovery + self._discover_oauth_metadata, + # If path-aware discovery fails with 404, try fallback to root + self._discover_oauth_metadata_fallback, + # If root discovery fails with 404, fall back to OIDC 1.0 following + # RFC 8414 path-aware semantics (see RFC 8414 section 5) + self._discover_oidc_metadata, + # If path-aware OIDC discovery failed with 404, fall back to OIDC 1.0 + # following OIDC 1.0 semantics (see RFC 8414 section 5) + self._discover_oidc_metadata_fallback, + ] + + # Reverse the list so we can call pop() without remembering we declared + # this stack backwards for readability + discovery_attempts.reverse() + return discovery_attempts + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: """HTTPX auth flow integration.""" async with self.context.lock: @@ -499,15 +567,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. await self._handle_protected_resource_response(discovery_response) # Step 2: Discover OAuth metadata (with fallback for legacy servers) - oauth_request = await self._discover_oauth_metadata() - oauth_response = yield oauth_request - handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) - - # If path-aware discovery failed with 404, try fallback to root - if not handled: - fallback_request = await self._discover_oauth_metadata_fallback() - fallback_response = yield fallback_request - await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) + oauth_discovery_stack = self._create_oauth_discovery_stack() + while len(oauth_discovery_stack) > 0: + oauth_discovery = oauth_discovery_stack.pop() + oauth_request = await oauth_discovery() + oauth_response = yield oauth_request + await self._handle_oauth_metadata_response(oauth_response, oauth_discovery_stack) # Step 3: Register client if needed registration_request = await self._register_client() @@ -551,15 +616,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. await self._handle_protected_resource_response(discovery_response) # Step 2: Discover OAuth metadata (with fallback for legacy servers) - oauth_request = await self._discover_oauth_metadata() - oauth_response = yield oauth_request - handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) - - # If path-aware discovery failed with 404, try fallback to root - if not handled: - fallback_request = await self._discover_oauth_metadata_fallback() - fallback_response = yield fallback_request - await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) + oauth_discovery_stack = self._create_oauth_discovery_stack() + while len(oauth_discovery_stack) > 0: + oauth_discovery = oauth_discovery_stack.pop() + oauth_request = await oauth_discovery() + oauth_response = yield oauth_request + await self._handle_oauth_metadata_response(oauth_response, oauth_discovery_stack) # Step 3: Register client if needed registration_request = await self._register_client() diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 8e6b4f54d..22c6573e0 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -8,7 +8,7 @@ import pytest from pydantic import AnyHttpUrl, AnyUrl -from mcp.client.auth import OAuthClientProvider, PKCEParameters +from mcp.client.auth import OAuthClientProvider, OAuthDiscoveryStack, PKCEParameters from mcp.shared.auth import ( OAuthClientInformationFull, OAuthClientMetadata, @@ -265,6 +265,12 @@ async def callback_handler() -> tuple[str, str | None]: class TestOAuthFallback: """Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers.""" + def _create_discovery_stack(self, oauth_provider, executed_requests=0) -> OAuthDiscoveryStack: + oauth_discovery_stack = oauth_provider._create_oauth_discovery_stack() + for _ in range(executed_requests): + oauth_discovery_stack.pop() # Simulate execution of a discovery request + return oauth_discovery_stack + @pytest.mark.anyio async def test_fallback_discovery_request(self, client_metadata, mock_storage): """Test fallback discovery request building.""" @@ -294,22 +300,150 @@ async def callback_handler() -> tuple[str, str | None]: assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server" assert "mcp-protocol-version" in request.headers + @pytest.mark.anyio + async def test_fallback_oidc_discovery_request(self, client_metadata, mock_storage): + """Test fallback OIDC discovery request building.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + # Set up discovery state manually + provider.context.discovery_base_url = "https://api.example.com" + provider.context.discovery_pathname = "/v1/mcp" + + # Test fallback request building + request = await provider._discover_oidc_metadata() + + assert request.method == "GET" + assert str(request.url) == "https://api.example.com/.well-known/openid-configuration/v1/mcp" + assert "mcp-protocol-version" in request.headers + + @pytest.mark.anyio + async def test_fallback_oidc_legacy_discovery_request(self, client_metadata, mock_storage): + """Test fallback legacy OIDC discovery request building.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + # Set up discovery state manually + provider.context.discovery_base_url = "https://api.example.com" + provider.context.discovery_pathname = "/v1/mcp" + + # Test fallback request building + request = await provider._discover_oidc_metadata_fallback() + + assert request.method == "GET" + assert str(request.url) == "https://api.example.com/v1/mcp/.well-known/openid-configuration" + assert "mcp-protocol-version" in request.headers + + @pytest.mark.anyio + async def test_fallback_oidc_legacy_discovery_request_root(self, client_metadata, mock_storage): + """Test fallback legacy OIDC discovery request building at the root path.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + # Set up discovery state manually + provider.context.discovery_base_url = "https://api.example.com" + provider.context.discovery_pathname = "/" + + # Test fallback request building + request = await provider._discover_oidc_metadata_fallback() + + assert request.method == "GET" + # No prefixing or suffixing at the root + assert str(request.url) == "https://api.example.com/.well-known/openid-configuration" + assert "mcp-protocol-version" in request.headers + + @pytest.mark.anyio + async def test_oauth_discovery_fallback_order(self, oauth_provider): + """Test fallback URL construction order.""" + oauth_discovery_stack = self._create_discovery_stack(oauth_provider) + + request = await oauth_discovery_stack.pop()() + assert request.url.path == "/.well-known/oauth-authorization-server/v1/mcp" + + request = await oauth_discovery_stack.pop()() + assert request.url.path == "/.well-known/oauth-authorization-server" + + request = await oauth_discovery_stack.pop()() + assert request.url.path == "/.well-known/openid-configuration/v1/mcp" + + request = await oauth_discovery_stack.pop()() + assert request.url.path == "/v1/mcp/.well-known/openid-configuration" + + assert len(oauth_discovery_stack) == 0 + @pytest.mark.anyio async def test_should_attempt_fallback(self, oauth_provider): """Test fallback decision logic.""" - # Should attempt fallback on 404 with non-root path - assert oauth_provider._should_attempt_fallback(404, "/v1/mcp") + # Simulate path-aware OAuth discovery execution + oauth_discovery_stack = self._create_discovery_stack(oauth_provider, executed_requests=1) + + # Should attempt fallback on 404 with path-aware OAuth discovery + assert oauth_provider._should_attempt_fallback(404, oauth_discovery_stack) + oauth_discovery_stack.pop(0) # Simulate root OAuth discovery execution + + # Should attempt fallback to path-aware OIDC on 404 with root path + assert oauth_provider._should_attempt_fallback(404, oauth_discovery_stack) + oauth_discovery_stack.pop(0) # Simulate path-aware OIDC discovery execution - # Should NOT attempt fallback on 404 with root path - assert not oauth_provider._should_attempt_fallback(404, "/") + # Should attempt fallback on OIDC 404 with non-root path + assert oauth_provider._should_attempt_fallback(404, oauth_discovery_stack) + oauth_discovery_stack.pop(0) # Simulate root OIDC discovery execution - # Should NOT attempt fallback on other status codes - assert not oauth_provider._should_attempt_fallback(200, "/v1/mcp") - assert not oauth_provider._should_attempt_fallback(500, "/v1/mcp") + # Should NOT attempt fallback on OIDC 404 with root path + assert not oauth_provider._should_attempt_fallback(404, oauth_discovery_stack) + + @pytest.mark.anyio + async def test_should_attempt_fallback_by_status_codes(self, oauth_provider): + """Test fallback decision logic according to status codes.""" + # Simulate path-aware OAuth discovery execution + oauth_discovery_stack = self._create_discovery_stack(oauth_provider, executed_requests=1) + + # Should NOT attempt fallback on status codes other than 404 + assert not oauth_provider._should_attempt_fallback(200, oauth_discovery_stack) + assert not oauth_provider._should_attempt_fallback(500, oauth_discovery_stack) @pytest.mark.anyio async def test_handle_metadata_response_success(self, oauth_provider): """Test successful metadata response handling.""" + # Simulate path-aware OAuth discovery execution + oauth_discovery_stack = self._create_discovery_stack(oauth_provider, executed_requests=1) + # Create minimal valid OAuth metadata content = b"""{ "issuer": "https://auth.example.com", @@ -319,7 +453,7 @@ async def test_handle_metadata_response_success(self, oauth_provider): response = httpx.Response(200, content=content) # Should return True (success) and set metadata - result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False) + result = await oauth_provider._handle_oauth_metadata_response(response, oauth_discovery_stack) assert result is True assert oauth_provider.context.oauth_metadata is not None assert str(oauth_provider.context.oauth_metadata.issuer) == "https://auth.example.com/" @@ -327,6 +461,9 @@ async def test_handle_metadata_response_success(self, oauth_provider): @pytest.mark.anyio async def test_handle_metadata_response_404_needs_fallback(self, oauth_provider): """Test 404 response handling that should trigger fallback.""" + # Simulate path-aware OAuth discovery execution + oauth_discovery_stack = self._create_discovery_stack(oauth_provider, executed_requests=1) + # Set up discovery state for non-root path oauth_provider.context.discovery_base_url = "https://api.example.com" oauth_provider.context.discovery_pathname = "/v1/mcp" @@ -335,12 +472,15 @@ async def test_handle_metadata_response_404_needs_fallback(self, oauth_provider) response = httpx.Response(404) # Should return False (needs fallback) - result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False) + result = await oauth_provider._handle_oauth_metadata_response(response, oauth_discovery_stack) assert result is False @pytest.mark.anyio async def test_handle_metadata_response_404_no_fallback_needed(self, oauth_provider): """Test 404 response handling when no fallback is needed.""" + # Simulate exhausted discovery stack + oauth_discovery_stack = [] + # Set up discovery state for root path oauth_provider.context.discovery_base_url = "https://api.example.com" oauth_provider.context.discovery_pathname = "/" @@ -349,17 +489,7 @@ async def test_handle_metadata_response_404_no_fallback_needed(self, oauth_provi response = httpx.Response(404) # Should return True (no fallback needed) - result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False) - assert result is True - - @pytest.mark.anyio - async def test_handle_metadata_response_404_fallback_attempt(self, oauth_provider): - """Test 404 response handling during fallback attempt.""" - # Mock 404 response during fallback - response = httpx.Response(404) - - # Should return True (fallback attempt complete, no further action needed) - result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=True) + result = await oauth_provider._handle_oauth_metadata_response(response, oauth_discovery_stack) assert result is True @pytest.mark.anyio From 526dbd6694d8fdee2d39bfeda5eefb5fb63ef8f8 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 30 Jun 2025 14:33:33 -0700 Subject: [PATCH 2/2] Add basic test to ensure OIDC metadata is parsable --- tests/shared/test_auth.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/shared/test_auth.py diff --git a/tests/shared/test_auth.py b/tests/shared/test_auth.py new file mode 100644 index 000000000..fd39eb255 --- /dev/null +++ b/tests/shared/test_auth.py @@ -0,0 +1,39 @@ +"""Tests for OAuth 2.0 shared code.""" + +from mcp.shared.auth import OAuthMetadata + + +class TestOAuthMetadata: + """Tests for OAuthMetadata parsing.""" + + def test_oauth(self): + """Should not throw when parsing OAuth metadata.""" + OAuthMetadata.model_validate( + { + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/oauth2/authorize", + "token_endpoint": "https://example.com/oauth2/token", + "scopes_supported": ["read", "write"], + "response_types_supported": ["code", "token"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], + } + ) + + def test_oidc(self): + """Should not throw when parsing OIDC metadata.""" + OAuthMetadata.model_validate( + { + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/oauth2/authorize", + "token_endpoint": "https://example.com/oauth2/token", + "end_session_endpoint": "https://example.com/logout", + "id_token_signing_alg_values_supported": ["RS256"], + "jwks_uri": "https://example.com/.well-known/jwks.json", + "response_types_supported": ["code", "token"], + "revocation_endpoint": "https://example.com/oauth2/revoke", + "scopes_supported": ["openid", "read", "write"], + "subject_types_supported": ["public"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], + "userinfo_endpoint": "https://example.com/oauth2/userInfo", + } + )