Skip to content

Commit 393b2d1

Browse files
committed
fix: fix OAuth flow request handling
1 parent 0b1b52b commit 393b2d1

File tree

2 files changed

+106
-6
lines changed

2 files changed

+106
-6
lines changed

src/mcp/client/auth.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -517,16 +517,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
517517
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
518518
discovery_urls = self._get_discovery_urls()
519519
for url in discovery_urls:
520-
request = self._create_oauth_metadata_request(url)
521-
response = yield request
520+
oauth_metadata_request = self._create_oauth_metadata_request(url)
521+
oauth_metadata_response = yield oauth_metadata_request
522522

523-
if response.status_code == 200:
523+
if oauth_metadata_response.status_code == 200:
524524
try:
525-
await self._handle_oauth_metadata_response(response)
525+
await self._handle_oauth_metadata_response(oauth_metadata_response)
526526
break
527527
except ValidationError:
528528
continue
529-
elif response.status_code != 404:
529+
elif oauth_metadata_response.status_code != 404:
530530
break # Non-404 error, stop trying
531531

532532
# Step 3: Register client if needed

tests/client/test_auth.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import time
6+
from unittest import mock
67

78
import httpx
89
import pytest
@@ -266,7 +267,7 @@ async def test_handle_metadata_response_success(self, oauth_provider):
266267
# Create minimal valid OAuth metadata
267268
content = b"""{
268269
"issuer": "https://auth.example.com",
269-
"authorization_endpoint": "https://auth.example.com/authorize",
270+
"authorization_endpoint": "https://auth.example.com/authorize",
270271
"token_endpoint": "https://auth.example.com/token"
271272
}"""
272273
response = httpx.Response(200, content=content)
@@ -495,6 +496,105 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v
495496
except StopAsyncIteration:
496497
pass # Expected
497498

499+
@pytest.mark.anyio
500+
async def test_auth_flow_with_no_tokens(self, oauth_provider, mock_storage):
501+
"""Test auth flow when no tokens are available, triggering the full OAuth flow."""
502+
# Ensure no tokens are stored
503+
oauth_provider.context.current_tokens = None
504+
oauth_provider.context.token_expiry_time = None
505+
oauth_provider._initialized = True
506+
507+
# Create a test request
508+
test_request = httpx.Request("GET", "https://api.example.com/mcp")
509+
510+
# Mock the auth flow
511+
auth_flow = oauth_provider.async_auth_flow(test_request)
512+
513+
# First request should be the original request without auth header
514+
request = await auth_flow.__anext__()
515+
assert "Authorization" not in request.headers
516+
517+
# Send a 401 response to trigger the OAuth flow
518+
response = httpx.Response(
519+
401,
520+
headers={
521+
"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
522+
},
523+
request=test_request,
524+
)
525+
526+
# Next request should be to discover protected resource metadata
527+
discovery_request = await auth_flow.asend(response)
528+
assert discovery_request.method == "GET"
529+
assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
530+
531+
# Send a successful discovery response with minimal protected resource metadata
532+
discovery_response = httpx.Response(
533+
200,
534+
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
535+
request=discovery_request,
536+
)
537+
538+
# Next request should be to discover OAuth metadata
539+
oauth_metadata_request = await auth_flow.asend(discovery_response)
540+
assert oauth_metadata_request.method == "GET"
541+
assert str(oauth_metadata_request.url).startswith("https://auth.example.com/")
542+
assert "mcp-protocol-version" in oauth_metadata_request.headers
543+
544+
# Send a successful OAuth metadata response
545+
oauth_metadata_response = httpx.Response(
546+
200,
547+
content=(
548+
b'{"issuer": "https://auth.example.com", '
549+
b'"authorization_endpoint": "https://auth.example.com/authorize", '
550+
b'"token_endpoint": "https://auth.example.com/token", '
551+
b'"registration_endpoint": "https://auth.example.com/register"}'
552+
),
553+
request=oauth_metadata_request,
554+
)
555+
556+
# Next request should be to register client
557+
registration_request = await auth_flow.asend(oauth_metadata_response)
558+
assert registration_request.method == "POST"
559+
assert str(registration_request.url) == "https://auth.example.com/register"
560+
561+
# Send a successful registration response
562+
registration_response = httpx.Response(
563+
201,
564+
content=b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}',
565+
request=registration_request,
566+
)
567+
568+
# Mock the authorization process
569+
oauth_provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier"))
570+
571+
# Next request should be to exchange token
572+
token_request = await auth_flow.asend(registration_response)
573+
assert token_request.method == "POST"
574+
assert str(token_request.url) == "https://auth.example.com/token"
575+
assert "code=test_auth_code" in token_request.content.decode()
576+
577+
# Send a successful token response
578+
token_response = httpx.Response(
579+
200,
580+
content=(
581+
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
582+
b'"refresh_token": "new_refresh_token"}'
583+
),
584+
request=token_request,
585+
)
586+
587+
# Final request should be the original request with auth header
588+
final_request = await auth_flow.asend(token_response)
589+
assert final_request.headers["Authorization"] == "Bearer new_access_token"
590+
assert final_request.method == "GET"
591+
assert str(final_request.url) == "https://api.example.com/mcp"
592+
593+
# Verify tokens were stored
594+
assert oauth_provider.context.current_tokens is not None
595+
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
596+
assert oauth_provider.context.token_expiry_time is not None
597+
498598

499599
@pytest.mark.parametrize(
500600
(

0 commit comments

Comments
 (0)