@@ -235,107 +235,30 @@ async def callback_handler() -> tuple[str, str | None]:
235
235
assert "mcp-protocol-version" in request .headers
236
236
237
237
@pytest .mark .anyio
238
- async def test_discover_oauth_metadata_request (self , oauth_provider ):
238
+ def test_create_oauth_metadata_request (self , oauth_provider ):
239
239
"""Test OAuth metadata discovery request building."""
240
- request = await oauth_provider ._discover_oauth_metadata ( )
240
+ request = oauth_provider ._create_oauth_metadata_request ( "https://example.com" )
241
241
242
+ # Ensure correct method and headers, and that the URL is unmodified
242
243
assert request .method == "GET"
243
- assert str (request .url ) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
244
- assert "mcp-protocol-version" in request .headers
245
-
246
- @pytest .mark .anyio
247
- async def test_discover_oauth_metadata_request_no_path (self , client_metadata , mock_storage ):
248
- """Test OAuth metadata discovery request building when server has no path."""
249
-
250
- async def redirect_handler (url : str ) -> None :
251
- pass
252
-
253
- async def callback_handler () -> tuple [str , str | None ]:
254
- return "test_auth_code" , "test_state"
255
-
256
- provider = OAuthClientProvider (
257
- server_url = "https://api.example.com" ,
258
- client_metadata = client_metadata ,
259
- storage = mock_storage ,
260
- redirect_handler = redirect_handler ,
261
- callback_handler = callback_handler ,
262
- )
263
-
264
- request = await provider ._discover_oauth_metadata ()
265
-
266
- assert request .method == "GET"
267
- assert str (request .url ) == "https://api.example.com/.well-known/oauth-authorization-server"
268
- assert "mcp-protocol-version" in request .headers
269
-
270
- @pytest .mark .anyio
271
- async def test_discover_oauth_metadata_request_trailing_slash (self , client_metadata , mock_storage ):
272
- """Test OAuth metadata discovery request building when server path has trailing slash."""
273
-
274
- async def redirect_handler (url : str ) -> None :
275
- pass
276
-
277
- async def callback_handler () -> tuple [str , str | None ]:
278
- return "test_auth_code" , "test_state"
279
-
280
- provider = OAuthClientProvider (
281
- server_url = "https://api.example.com/v1/mcp/" ,
282
- client_metadata = client_metadata ,
283
- storage = mock_storage ,
284
- redirect_handler = redirect_handler ,
285
- callback_handler = callback_handler ,
286
- )
287
-
288
- request = await provider ._discover_oauth_metadata ()
289
-
290
- assert request .method == "GET"
291
- assert str (request .url ) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
244
+ assert str (request .url ) == "https://example.com"
292
245
assert "mcp-protocol-version" in request .headers
293
246
294
247
295
248
class TestOAuthFallback :
296
249
"""Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers."""
297
250
298
251
@pytest .mark .anyio
299
- async def test_fallback_discovery_request (self , client_metadata , mock_storage ):
300
- """Test fallback discovery request building."""
301
-
302
- async def redirect_handler (url : str ) -> None :
303
- pass
304
-
305
- async def callback_handler () -> tuple [str , str | None ]:
306
- return "test_auth_code" , "test_state"
307
-
308
- provider = OAuthClientProvider (
309
- server_url = "https://api.example.com/v1/mcp" ,
310
- client_metadata = client_metadata ,
311
- storage = mock_storage ,
312
- redirect_handler = redirect_handler ,
313
- callback_handler = callback_handler ,
314
- )
315
-
316
- # Set up discovery state manually as if path-aware discovery was attempted
317
- provider .context .discovery_base_url = "https://api.example.com"
318
- provider .context .discovery_pathname = "/v1/mcp"
252
+ async def test_oauth_discovery_fallback_order (self , oauth_provider ):
253
+ """Test fallback URL construction order."""
254
+ discovery_urls = oauth_provider ._get_discovery_urls ()
319
255
320
- # Test fallback request building
321
- request = await provider ._discover_oauth_metadata_fallback ()
322
-
323
- assert request .method == "GET"
324
- assert str (request .url ) == "https://api.example.com/.well-known/oauth-authorization-server"
325
- assert "mcp-protocol-version" in request .headers
326
-
327
- @pytest .mark .anyio
328
- async def test_should_attempt_fallback (self , oauth_provider ):
329
- """Test fallback decision logic."""
330
- # Should attempt fallback on 404 with non-root path
331
- assert oauth_provider ._should_attempt_fallback (404 , "/v1/mcp" )
332
-
333
- # Should NOT attempt fallback on 404 with root path
334
- assert not oauth_provider ._should_attempt_fallback (404 , "/" )
335
-
336
- # Should NOT attempt fallback on other status codes
337
- assert not oauth_provider ._should_attempt_fallback (200 , "/v1/mcp" )
338
- assert not oauth_provider ._should_attempt_fallback (500 , "/v1/mcp" )
256
+ assert discovery_urls == [
257
+ "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp" ,
258
+ "https://api.example.com/.well-known/oauth-authorization-server" ,
259
+ "https://api.example.com/.well-known/openid-configuration/v1/mcp" ,
260
+ "https://api.example.com/v1/mcp/.well-known/openid-configuration" ,
261
+ ]
339
262
340
263
@pytest .mark .anyio
341
264
async def test_handle_metadata_response_success (self , oauth_provider ):
@@ -348,50 +271,11 @@ async def test_handle_metadata_response_success(self, oauth_provider):
348
271
}"""
349
272
response = httpx .Response (200 , content = content )
350
273
351
- # Should return True (success) and set metadata
352
- result = await oauth_provider ._handle_oauth_metadata_response (response , is_fallback = False )
353
- assert result is True
274
+ # Should set metadata
275
+ await oauth_provider ._handle_oauth_metadata_response (response )
354
276
assert oauth_provider .context .oauth_metadata is not None
355
277
assert str (oauth_provider .context .oauth_metadata .issuer ) == "https://auth.example.com/"
356
278
357
- @pytest .mark .anyio
358
- async def test_handle_metadata_response_404_needs_fallback (self , oauth_provider ):
359
- """Test 404 response handling that should trigger fallback."""
360
- # Set up discovery state for non-root path
361
- oauth_provider .context .discovery_base_url = "https://api.example.com"
362
- oauth_provider .context .discovery_pathname = "/v1/mcp"
363
-
364
- # Mock 404 response
365
- response = httpx .Response (404 )
366
-
367
- # Should return False (needs fallback)
368
- result = await oauth_provider ._handle_oauth_metadata_response (response , is_fallback = False )
369
- assert result is False
370
-
371
- @pytest .mark .anyio
372
- async def test_handle_metadata_response_404_no_fallback_needed (self , oauth_provider ):
373
- """Test 404 response handling when no fallback is needed."""
374
- # Set up discovery state for root path
375
- oauth_provider .context .discovery_base_url = "https://api.example.com"
376
- oauth_provider .context .discovery_pathname = "/"
377
-
378
- # Mock 404 response
379
- response = httpx .Response (404 )
380
-
381
- # Should return True (no fallback needed)
382
- result = await oauth_provider ._handle_oauth_metadata_response (response , is_fallback = False )
383
- assert result is True
384
-
385
- @pytest .mark .anyio
386
- async def test_handle_metadata_response_404_fallback_attempt (self , oauth_provider ):
387
- """Test 404 response handling during fallback attempt."""
388
- # Mock 404 response during fallback
389
- response = httpx .Response (404 )
390
-
391
- # Should return True (fallback attempt complete, no further action needed)
392
- result = await oauth_provider ._handle_oauth_metadata_response (response , is_fallback = True )
393
- assert result is True
394
-
395
279
@pytest .mark .anyio
396
280
async def test_register_client_request (self , oauth_provider ):
397
281
"""Test client registration request building."""
0 commit comments