Skip to content

Commit 51961c7

Browse files
committed
fix: fix OAuth flow request handling
1 parent 0b1b52b commit 51961c7

File tree

2 files changed

+96
-6
lines changed

2 files changed

+96
-6
lines changed

src/mcp/client/auth.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -517,16 +517,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
517517
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
518518
discovery_urls = self._get_discovery_urls()
519519
for url in discovery_urls:
520-
request = self._create_oauth_metadata_request(url)
521-
response = yield request
520+
oauth_metadata_request = self._create_oauth_metadata_request(url)
521+
oauth_metadata_response = yield oauth_metadata_request
522522

523-
if response.status_code == 200:
523+
if oauth_metadata_response.status_code == 200:
524524
try:
525-
await self._handle_oauth_metadata_response(response)
525+
await self._handle_oauth_metadata_response(oauth_metadata_response)
526526
break
527527
except ValidationError:
528528
continue
529-
elif response.status_code != 404:
529+
elif oauth_metadata_response.status_code != 404:
530530
break # Non-404 error, stop trying
531531

532532
# Step 3: Register client if needed

tests/client/test_auth.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import time
6+
from unittest import mock
67

78
import httpx
89
import pytest
@@ -266,7 +267,7 @@ async def test_handle_metadata_response_success(self, oauth_provider):
266267
# Create minimal valid OAuth metadata
267268
content = b"""{
268269
"issuer": "https://auth.example.com",
269-
"authorization_endpoint": "https://auth.example.com/authorize",
270+
"authorization_endpoint": "https://auth.example.com/authorize",
270271
"token_endpoint": "https://auth.example.com/token"
271272
}"""
272273
response = httpx.Response(200, content=content)
@@ -495,6 +496,95 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v
495496
except StopAsyncIteration:
496497
pass # Expected
497498

499+
@pytest.mark.anyio
500+
async def test_auth_flow_with_no_tokens(self, oauth_provider, mock_storage):
501+
"""Test auth flow when no tokens are available, triggering the full OAuth flow."""
502+
# Ensure no tokens are stored
503+
oauth_provider.context.current_tokens = None
504+
oauth_provider.context.token_expiry_time = None
505+
oauth_provider._initialized = True
506+
507+
# Create a test request
508+
test_request = httpx.Request("GET", "https://api.example.com/mcp")
509+
510+
# Mock the auth flow
511+
auth_flow = oauth_provider.async_auth_flow(test_request)
512+
513+
# First request should be the original request without auth header
514+
request = await auth_flow.__anext__()
515+
assert "Authorization" not in request.headers
516+
517+
# Send a 401 response to trigger the OAuth flow
518+
response = httpx.Response(
519+
401,
520+
headers={"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'},
521+
request=test_request
522+
)
523+
524+
# Next request should be to discover protected resource metadata
525+
discovery_request = await auth_flow.asend(response)
526+
assert discovery_request.method == "GET"
527+
assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
528+
529+
# Send a successful discovery response with minimal protected resource metadata
530+
discovery_response = httpx.Response(
531+
200,
532+
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
533+
request=discovery_request
534+
)
535+
536+
# Next request should be to discover OAuth metadata
537+
oauth_metadata_request = await auth_flow.asend(discovery_response)
538+
assert oauth_metadata_request.method == "GET"
539+
assert str(oauth_metadata_request.url).startswith("https://auth.example.com/")
540+
assert "mcp-protocol-version" in oauth_metadata_request.headers
541+
542+
# Send a successful OAuth metadata response
543+
oauth_metadata_response = httpx.Response(
544+
200,
545+
content=b'{"issuer": "https://auth.example.com", "authorization_endpoint": "https://auth.example.com/authorize", "token_endpoint": "https://auth.example.com/token", "registration_endpoint": "https://auth.example.com/register"}',
546+
request=oauth_metadata_request
547+
)
548+
549+
# Next request should be to register client
550+
registration_request = await auth_flow.asend(oauth_metadata_response)
551+
assert registration_request.method == "POST"
552+
assert str(registration_request.url) == "https://auth.example.com/register"
553+
554+
# Send a successful registration response
555+
registration_response = httpx.Response(
556+
201,
557+
content=b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}',
558+
request=registration_request
559+
)
560+
561+
# Mock the authorization process
562+
oauth_provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier"))
563+
564+
# Next request should be to exchange token
565+
token_request = await auth_flow.asend(registration_response)
566+
assert token_request.method == "POST"
567+
assert str(token_request.url) == "https://auth.example.com/token"
568+
assert "code=test_auth_code" in token_request.content.decode()
569+
570+
# Send a successful token response
571+
token_response = httpx.Response(
572+
200,
573+
content=b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, "refresh_token": "new_refresh_token"}',
574+
request=token_request
575+
)
576+
577+
# Final request should be the original request with auth header
578+
final_request = await auth_flow.asend(token_response)
579+
assert final_request.headers["Authorization"] == "Bearer new_access_token"
580+
assert final_request.method == "GET"
581+
assert str(final_request.url) == "https://api.example.com/mcp"
582+
583+
# Verify tokens were stored
584+
assert oauth_provider.context.current_tokens is not None
585+
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
586+
assert oauth_provider.context.token_expiry_time is not None
587+
498588

499589
@pytest.mark.parametrize(
500590
(

0 commit comments

Comments
 (0)