|
3 | 3 | """
|
4 | 4 |
|
5 | 5 | import time
|
| 6 | +from unittest import mock |
6 | 7 |
|
7 | 8 | import httpx
|
8 | 9 | import pytest
|
@@ -266,7 +267,7 @@ async def test_handle_metadata_response_success(self, oauth_provider):
|
266 | 267 | # Create minimal valid OAuth metadata
|
267 | 268 | content = b"""{
|
268 | 269 | "issuer": "https://auth.example.com",
|
269 |
| - "authorization_endpoint": "https://auth.example.com/authorize", |
| 270 | + "authorization_endpoint": "https://auth.example.com/authorize", |
270 | 271 | "token_endpoint": "https://auth.example.com/token"
|
271 | 272 | }"""
|
272 | 273 | response = httpx.Response(200, content=content)
|
@@ -495,6 +496,95 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v
|
495 | 496 | except StopAsyncIteration:
|
496 | 497 | pass # Expected
|
497 | 498 |
|
| 499 | + @pytest.mark.anyio |
| 500 | + async def test_auth_flow_with_no_tokens(self, oauth_provider, mock_storage): |
| 501 | + """Test auth flow when no tokens are available, triggering the full OAuth flow.""" |
| 502 | + # Ensure no tokens are stored |
| 503 | + oauth_provider.context.current_tokens = None |
| 504 | + oauth_provider.context.token_expiry_time = None |
| 505 | + oauth_provider._initialized = True |
| 506 | + |
| 507 | + # Create a test request |
| 508 | + test_request = httpx.Request("GET", "https://api.example.com/mcp") |
| 509 | + |
| 510 | + # Mock the auth flow |
| 511 | + auth_flow = oauth_provider.async_auth_flow(test_request) |
| 512 | + |
| 513 | + # First request should be the original request without auth header |
| 514 | + request = await auth_flow.__anext__() |
| 515 | + assert "Authorization" not in request.headers |
| 516 | + |
| 517 | + # Send a 401 response to trigger the OAuth flow |
| 518 | + response = httpx.Response( |
| 519 | + 401, |
| 520 | + headers={"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'}, |
| 521 | + request=test_request |
| 522 | + ) |
| 523 | + |
| 524 | + # Next request should be to discover protected resource metadata |
| 525 | + discovery_request = await auth_flow.asend(response) |
| 526 | + assert discovery_request.method == "GET" |
| 527 | + assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource" |
| 528 | + |
| 529 | + # Send a successful discovery response with minimal protected resource metadata |
| 530 | + discovery_response = httpx.Response( |
| 531 | + 200, |
| 532 | + content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}', |
| 533 | + request=discovery_request |
| 534 | + ) |
| 535 | + |
| 536 | + # Next request should be to discover OAuth metadata |
| 537 | + oauth_metadata_request = await auth_flow.asend(discovery_response) |
| 538 | + assert oauth_metadata_request.method == "GET" |
| 539 | + assert str(oauth_metadata_request.url).startswith("https://auth.example.com/") |
| 540 | + assert "mcp-protocol-version" in oauth_metadata_request.headers |
| 541 | + |
| 542 | + # Send a successful OAuth metadata response |
| 543 | + oauth_metadata_response = httpx.Response( |
| 544 | + 200, |
| 545 | + content=b'{"issuer": "https://auth.example.com", "authorization_endpoint": "https://auth.example.com/authorize", "token_endpoint": "https://auth.example.com/token", "registration_endpoint": "https://auth.example.com/register"}', |
| 546 | + request=oauth_metadata_request |
| 547 | + ) |
| 548 | + |
| 549 | + # Next request should be to register client |
| 550 | + registration_request = await auth_flow.asend(oauth_metadata_response) |
| 551 | + assert registration_request.method == "POST" |
| 552 | + assert str(registration_request.url) == "https://auth.example.com/register" |
| 553 | + |
| 554 | + # Send a successful registration response |
| 555 | + registration_response = httpx.Response( |
| 556 | + 201, |
| 557 | + content=b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}', |
| 558 | + request=registration_request |
| 559 | + ) |
| 560 | + |
| 561 | + # Mock the authorization process |
| 562 | + oauth_provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier")) |
| 563 | + |
| 564 | + # Next request should be to exchange token |
| 565 | + token_request = await auth_flow.asend(registration_response) |
| 566 | + assert token_request.method == "POST" |
| 567 | + assert str(token_request.url) == "https://auth.example.com/token" |
| 568 | + assert "code=test_auth_code" in token_request.content.decode() |
| 569 | + |
| 570 | + # Send a successful token response |
| 571 | + token_response = httpx.Response( |
| 572 | + 200, |
| 573 | + content=b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, "refresh_token": "new_refresh_token"}', |
| 574 | + request=token_request |
| 575 | + ) |
| 576 | + |
| 577 | + # Final request should be the original request with auth header |
| 578 | + final_request = await auth_flow.asend(token_response) |
| 579 | + assert final_request.headers["Authorization"] == "Bearer new_access_token" |
| 580 | + assert final_request.method == "GET" |
| 581 | + assert str(final_request.url) == "https://api.example.com/mcp" |
| 582 | + |
| 583 | + # Verify tokens were stored |
| 584 | + assert oauth_provider.context.current_tokens is not None |
| 585 | + assert oauth_provider.context.current_tokens.access_token == "new_access_token" |
| 586 | + assert oauth_provider.context.token_expiry_time is not None |
| 587 | + |
498 | 588 |
|
499 | 589 | @pytest.mark.parametrize(
|
500 | 590 | (
|
|
0 commit comments