Skip to content

fix: fix OAuth flow request object handling #1174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,16 +517,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
discovery_urls = self._get_discovery_urls()
for url in discovery_urls:
request = self._create_oauth_metadata_request(url)
response = yield request
oauth_metadata_request = self._create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request

if response.status_code == 200:
if oauth_metadata_response.status_code == 200:
try:
await self._handle_oauth_metadata_response(response)
await self._handle_oauth_metadata_response(oauth_metadata_response)
break
except ValidationError:
continue
elif response.status_code != 404:
elif oauth_metadata_response.status_code != 404:
break # Non-404 error, stop trying

# Step 3: Register client if needed
Expand Down
102 changes: 101 additions & 1 deletion tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import time
from unittest import mock

import httpx
import pytest
Expand Down Expand Up @@ -266,7 +267,7 @@ async def test_handle_metadata_response_success(self, oauth_provider):
# Create minimal valid OAuth metadata
content = b"""{
"issuer": "https://auth.example.com",
"authorization_endpoint": "https://auth.example.com/authorize",
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token"
}"""
response = httpx.Response(200, content=content)
Expand Down Expand Up @@ -495,6 +496,105 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v
except StopAsyncIteration:
pass # Expected

@pytest.mark.anyio
async def test_auth_flow_with_no_tokens(self, oauth_provider, mock_storage):
"""Test auth flow when no tokens are available, triggering the full OAuth flow."""
# Ensure no tokens are stored
oauth_provider.context.current_tokens = None
oauth_provider.context.token_expiry_time = None
oauth_provider._initialized = True

# Create a test request
test_request = httpx.Request("GET", "https://api.example.com/mcp")

# Mock the auth flow
auth_flow = oauth_provider.async_auth_flow(test_request)

# First request should be the original request without auth header
request = await auth_flow.__anext__()
assert "Authorization" not in request.headers

# Send a 401 response to trigger the OAuth flow
response = httpx.Response(
401,
headers={
"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
},
request=test_request,
)

# Next request should be to discover protected resource metadata
discovery_request = await auth_flow.asend(response)
assert discovery_request.method == "GET"
assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource"

# Send a successful discovery response with minimal protected resource metadata
discovery_response = httpx.Response(
200,
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
request=discovery_request,
)

# Next request should be to discover OAuth metadata
oauth_metadata_request = await auth_flow.asend(discovery_response)
assert oauth_metadata_request.method == "GET"
assert str(oauth_metadata_request.url).startswith("https://auth.example.com/")
assert "mcp-protocol-version" in oauth_metadata_request.headers

# Send a successful OAuth metadata response
oauth_metadata_response = httpx.Response(
200,
content=(
b'{"issuer": "https://auth.example.com", '
b'"authorization_endpoint": "https://auth.example.com/authorize", '
b'"token_endpoint": "https://auth.example.com/token", '
b'"registration_endpoint": "https://auth.example.com/register"}'
),
request=oauth_metadata_request,
)

# Next request should be to register client
registration_request = await auth_flow.asend(oauth_metadata_response)
assert registration_request.method == "POST"
assert str(registration_request.url) == "https://auth.example.com/register"

# Send a successful registration response
registration_response = httpx.Response(
201,
content=b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}',
request=registration_request,
)

# Mock the authorization process
oauth_provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier"))

# Next request should be to exchange token
token_request = await auth_flow.asend(registration_response)
assert token_request.method == "POST"
assert str(token_request.url) == "https://auth.example.com/token"
assert "code=test_auth_code" in token_request.content.decode()

# Send a successful token response
token_response = httpx.Response(
200,
content=(
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
b'"refresh_token": "new_refresh_token"}'
),
request=token_request,
)

# Final request should be the original request with auth header
final_request = await auth_flow.asend(token_response)
assert final_request.headers["Authorization"] == "Bearer new_access_token"
assert final_request.method == "GET"
assert str(final_request.url) == "https://api.example.com/mcp"

# Verify tokens were stored
assert oauth_provider.context.current_tokens is not None
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
assert oauth_provider.context.token_expiry_time is not None


@pytest.mark.parametrize(
(
Expand Down
Loading