From 393b2d1688fbf46edd73c39f7050c7833783d3be Mon Sep 17 00:00:00 2001 From: Clare Liguori Date: Fri, 18 Jul 2025 21:24:56 -0700 Subject: [PATCH] fix: fix OAuth flow request handling --- src/mcp/client/auth.py | 10 ++-- tests/client/test_auth.py | 102 +++++++++++++++++++++++++++++++++++++- 2 files changed, 106 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index e31709e05..b00db7b9b 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -517,16 +517,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Step 2: Discover OAuth metadata (with fallback for legacy servers) discovery_urls = self._get_discovery_urls() for url in discovery_urls: - request = self._create_oauth_metadata_request(url) - response = yield request + oauth_metadata_request = self._create_oauth_metadata_request(url) + oauth_metadata_response = yield oauth_metadata_request - if response.status_code == 200: + if oauth_metadata_response.status_code == 200: try: - await self._handle_oauth_metadata_response(response) + await self._handle_oauth_metadata_response(oauth_metadata_response) break except ValidationError: continue - elif response.status_code != 404: + elif oauth_metadata_response.status_code != 404: break # Non-404 error, stop trying # Step 3: Register client if needed diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index c47007a4c..46208d69c 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -3,6 +3,7 @@ """ import time +from unittest import mock import httpx import pytest @@ -266,7 +267,7 @@ async def test_handle_metadata_response_success(self, oauth_provider): # Create minimal valid OAuth metadata content = b"""{ "issuer": "https://auth.example.com", - "authorization_endpoint": "https://auth.example.com/authorize", + "authorization_endpoint": "https://auth.example.com/authorize", "token_endpoint": "https://auth.example.com/token" }""" response = httpx.Response(200, content=content) @@ -495,6 +496,105 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v except StopAsyncIteration: pass # Expected + @pytest.mark.anyio + async def test_auth_flow_with_no_tokens(self, oauth_provider, mock_storage): + """Test auth flow when no tokens are available, triggering the full OAuth flow.""" + # Ensure no tokens are stored + oauth_provider.context.current_tokens = None + oauth_provider.context.token_expiry_time = None + oauth_provider._initialized = True + + # Create a test request + test_request = httpx.Request("GET", "https://api.example.com/mcp") + + # Mock the auth flow + auth_flow = oauth_provider.async_auth_flow(test_request) + + # First request should be the original request without auth header + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Send a 401 response to trigger the OAuth flow + response = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + }, + request=test_request, + ) + + # Next request should be to discover protected resource metadata + discovery_request = await auth_flow.asend(response) + assert discovery_request.method == "GET" + assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource" + + # Send a successful discovery response with minimal protected resource metadata + discovery_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}', + request=discovery_request, + ) + + # Next request should be to discover OAuth metadata + oauth_metadata_request = await auth_flow.asend(discovery_response) + assert oauth_metadata_request.method == "GET" + assert str(oauth_metadata_request.url).startswith("https://auth.example.com/") + assert "mcp-protocol-version" in oauth_metadata_request.headers + + # Send a successful OAuth metadata response + oauth_metadata_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register"}' + ), + request=oauth_metadata_request, + ) + + # Next request should be to register client + registration_request = await auth_flow.asend(oauth_metadata_response) + assert registration_request.method == "POST" + assert str(registration_request.url) == "https://auth.example.com/register" + + # Send a successful registration response + registration_response = httpx.Response( + 201, + content=b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}', + request=registration_request, + ) + + # Mock the authorization process + oauth_provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier")) + + # Next request should be to exchange token + token_request = await auth_flow.asend(registration_response) + assert token_request.method == "POST" + assert str(token_request.url) == "https://auth.example.com/token" + assert "code=test_auth_code" in token_request.content.decode() + + # Send a successful token response + token_response = httpx.Response( + 200, + content=( + b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, ' + b'"refresh_token": "new_refresh_token"}' + ), + request=token_request, + ) + + # Final request should be the original request with auth header + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer new_access_token" + assert final_request.method == "GET" + assert str(final_request.url) == "https://api.example.com/mcp" + + # Verify tokens were stored + assert oauth_provider.context.current_tokens is not None + assert oauth_provider.context.current_tokens.access_token == "new_access_token" + assert oauth_provider.context.token_expiry_time is not None + @pytest.mark.parametrize( (