diff --git a/README.md b/README.md index 3cad2aade..ace68d8e8 100644 --- a/README.md +++ b/README.md @@ -1285,7 +1285,11 @@ This ensures your client UI shows the most user-friendly names that servers prov The SDK includes [authorization support](https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization) for connecting to protected MCP servers: ```python -from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.auth import ( + OAuthClientProvider, + TokenExchangeProvider, + TokenStorage, +) from mcp.client.session import ClientSession from mcp.client.streamable_http import streamablehttp_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken @@ -1322,6 +1326,24 @@ async def main(): callback_handler=lambda: ("auth_code", None), ) + # For machine-to-machine scenarios, use ClientCredentialsProvider + # instead of OAuthClientProvider. + + # If you already have a user token from another provider, you can + # exchange it for an MCP token using the token_exchange grant + # implemented by TokenExchangeProvider. + token_exchange_auth = TokenExchangeProvider( + server_url="https://api.example.com", + client_metadata=OAuthClientMetadata( + client_name="My Client", + redirect_uris=["http://localhost:3000/callback"], + grant_types=["client_credentials", "token_exchange"], + response_types=["code"], + ), + storage=CustomTokenStorage(), + subject_token_supplier=lambda: "user_token", + ) + # Use with streamable HTTP client async with streamablehttp_client( "https://api.example.com/mcp", auth=oauth_auth diff --git a/docs/api.md b/docs/api.md index 3f696af54..3291f5c01 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1 +1,5 @@ +The Python SDK exposes the entire `mcp` package for use in your own projects. +It includes an OAuth server implementation with support for the RFC 8693 +`token_exchange` grant type. + ::: mcp diff --git a/docs/index.md b/docs/index.md index 42ad9ca0c..dc0ffea32 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,3 +3,7 @@ This is the MCP Server implementation in Python. It only contains the [API Reference](api.md) for the time being. + +The built-in OAuth server supports the RFC 8693 `token_exchange` grant type, +allowing clients to exchange user tokens from external providers for MCP +access tokens. diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index aa813b542..ccd01afb8 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -238,6 +238,52 @@ async def exchange_authorization_code( scope=" ".join(authorization_code.scopes), ) + async def exchange_client_credentials(self, client: OAuthClientInformationFull, scopes: list[str]) -> OAuthToken: + """Exchange client credentials for an MCP access token.""" + mcp_token = f"mcp_{secrets.token_hex(32)}" + self.tokens[mcp_token] = AccessToken( + token=mcp_token, + client_id=client.client_id, + scopes=scopes, + expires_at=int(time.time()) + 3600, + ) + return OAuthToken( + access_token=mcp_token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(scopes), + ) + + async def exchange_token( + self, + client: OAuthClientInformationFull, + subject_token: str, + subject_token_type: str, + actor_token: str | None, + actor_token_type: str | None, + scope: list[str] | None, + audience: str | None, + resource: str | None, + ) -> OAuthToken: + """Exchange an external token for an MCP access token.""" + if not subject_token: + raise ValueError("Invalid subject token") + + mcp_token = f"mcp_{secrets.token_hex(32)}" + self.tokens[mcp_token] = AccessToken( + token=mcp_token, + client_id=client.client_id, + scopes=scope or [self.settings.mcp_scope], + expires_at=int(time.time()) + 3600, + resource=resource, + ) + return OAuthToken( + access_token=mcp_token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(scope or [self.settings.mcp_scope]), + ) + async def load_access_token(self, token: str) -> AccessToken | None: """Load and validate an access token.""" access_token = self.tokens.get(token) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 8bafe18eb..d08daf7be 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -124,7 +124,7 @@ def update_token_expiry(self, token: OAuthToken) -> None: self.token_expiry_time = None def is_token_valid(self) -> bool: - """Check if current token is valid.""" + """Check if the current token is valid.""" return bool( self.current_tokens and self.current_tokens.access_token @@ -132,7 +132,7 @@ def is_token_valid(self) -> bool: ) def can_refresh_token(self) -> bool: - """Check if token can be refreshed.""" + """Check if the token can be refreshed.""" return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info) def clear_tokens(self) -> None: @@ -579,3 +579,283 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Retry with new tokens self._add_auth_header(request) yield request + + +class ClientCredentialsProvider(httpx.Auth): + """HTTPX auth using the OAuth2 client credentials grant.""" + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + resource: str | None = None, + timeout: float = 300.0, + ): + self.server_url = server_url + self.client_metadata = client_metadata + self.storage = storage + self.timeout = timeout + self.resource = resource or resource_url_from_server_url(server_url) + + self._current_tokens: OAuthToken | None = None + self._metadata: OAuthMetadata | None = None + self._client_info: OAuthClientInformationFull | None = None + self._token_expiry_time: float | None = None + + self._token_lock = anyio.Lock() + + def _get_authorization_base_url(self, server_url: str) -> str: + """Return base authorization server URL without path.""" + parsed = urlparse(server_url) + return f"{parsed.scheme}://{parsed.netloc}" + + async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None: + """Discover OAuth server metadata for client credentials.""" + auth_base_url = self._get_authorization_base_url(server_url) + url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server") + headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION} + + async with httpx.AsyncClient() as client: + try: + response = await client.get(url, headers=headers) + if response.status_code == 404: + return None + response.raise_for_status() + return OAuthMetadata.model_validate(response.json()) + except Exception: + try: + response = await client.get(url) + if response.status_code == 404: + return None + response.raise_for_status() + return OAuthMetadata.model_validate(response.json()) + except Exception: + logger.exception("Failed to discover OAuth metadata") + return None + + async def _register_oauth_client( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + metadata: OAuthMetadata | None = None, + ) -> OAuthClientInformationFull: + if not metadata: + metadata = await self._discover_oauth_metadata(server_url) + + if metadata and metadata.registration_endpoint: + registration_url = str(metadata.registration_endpoint) + else: + auth_base_url = self._get_authorization_base_url(server_url) + registration_url = urljoin(auth_base_url, "/register") + + if client_metadata.scope is None and metadata and metadata.scopes_supported is not None: + client_metadata.scope = " ".join(metadata.scopes_supported) + + registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) + + async with httpx.AsyncClient() as client: + response = await client.post( + registration_url, + json=registration_data, + headers={"Content-Type": "application/json"}, + ) + + if response.status_code not in (200, 201): + raise httpx.HTTPStatusError( + f"Registration failed: {response.status_code}", + request=response.request, + response=response, + ) + + return OAuthClientInformationFull.model_validate(response.json()) + + def _has_valid_token(self) -> bool: + if not self._current_tokens or not self._current_tokens.access_token: + return False + + if self._token_expiry_time and time.time() > self._token_expiry_time: + return False + return True + + async def _validate_token_scopes(self, token_response: OAuthToken) -> None: + if not token_response.scope: + return + + requested_scopes: set[str] = set() + if self.client_metadata.scope: + requested_scopes = set(self.client_metadata.scope.split()) + returned_scopes = set(token_response.scope.split()) + unauthorized_scopes = returned_scopes - requested_scopes + if unauthorized_scopes: + raise Exception(f"Server granted unauthorized scopes: {unauthorized_scopes}.") + else: + granted = set(token_response.scope.split()) + logger.debug( + "No explicit scopes requested, accepting server-granted scopes: %s", + granted, + ) + + async def initialize(self) -> None: + self._current_tokens = await self.storage.get_tokens() + self._client_info = await self.storage.get_client_info() + + async def _get_or_register_client(self) -> OAuthClientInformationFull: + if not self._client_info: + self._client_info = await self._register_oauth_client(self.server_url, self.client_metadata, self._metadata) + await self.storage.set_client_info(self._client_info) + return self._client_info + + async def _request_token(self) -> None: + if not self._metadata: + self._metadata = await self._discover_oauth_metadata(self.server_url) + + client_info = await self._get_or_register_client() + + if self._metadata and self._metadata.token_endpoint: + token_url = str(self._metadata.token_endpoint) + else: + auth_base_url = self._get_authorization_base_url(self.server_url) + token_url = urljoin(auth_base_url, "/token") + + token_data = { + "grant_type": "client_credentials", + "client_id": client_info.client_id, + "resource": self.resource, + } + + if client_info.client_secret: + token_data["client_secret"] = client_info.client_secret + + if self.client_metadata.scope: + token_data["scope"] = self.client_metadata.scope + + async with httpx.AsyncClient() as client: + response = await client.post( + token_url, + data=token_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, + ) + + if response.status_code != 200: + raise Exception(f"Token request failed: {response.status_code} {response.text}") + + token_response = OAuthToken.model_validate(response.json()) + await self._validate_token_scopes(token_response) + + if token_response.expires_in: + self._token_expiry_time = time.time() + token_response.expires_in + else: + self._token_expiry_time = None + + await self.storage.set_tokens(token_response) + self._current_tokens = token_response + + async def ensure_token(self) -> None: + async with self._token_lock: + if self._has_valid_token(): + return + await self._request_token() + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: + if not self._has_valid_token(): + await self.initialize() + await self.ensure_token() + + if self._current_tokens and self._current_tokens.access_token: + request.headers["Authorization"] = f"Bearer {self._current_tokens.access_token}" + + response = yield request + + if response.status_code == 401: + self._current_tokens = None + + +class TokenExchangeProvider(ClientCredentialsProvider): + """OAuth2 token exchange based on RFC 8693.""" + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + subject_token_supplier: Callable[[], Awaitable[str]], + subject_token_type: str = "access_token", + actor_token_supplier: Callable[[], Awaitable[str]] | None = None, + actor_token_type: str | None = None, + audience: str | None = None, + resource: str | None = None, + timeout: float = 300.0, + ): + """Create a new token exchange provider. + + Parameters are forwarded to ClientCredentialsProvider for + client authentication. The resource parameter binds issued tokens to + the target resource, as defined by RFC 8707. + """ + + super().__init__(server_url, client_metadata, storage, resource, timeout) + self.subject_token_supplier = subject_token_supplier + self.subject_token_type = subject_token_type + self.actor_token_supplier = actor_token_supplier + self.actor_token_type = actor_token_type + self.audience = audience + + async def _request_token(self) -> None: + if not self._metadata: + self._metadata = await self._discover_oauth_metadata(self.server_url) + + client_info = await self._get_or_register_client() + + if self._metadata and self._metadata.token_endpoint: + token_url = str(self._metadata.token_endpoint) + else: + auth_base_url = self._get_authorization_base_url(self.server_url) + token_url = urljoin(auth_base_url, "/token") + + subject_token = await self.subject_token_supplier() + actor_token = await self.actor_token_supplier() if self.actor_token_supplier else None + + token_data = { + "grant_type": "token_exchange", + "client_id": client_info.client_id, + "subject_token": subject_token, + "subject_token_type": self.subject_token_type, + } + + if client_info.client_secret: + token_data["client_secret"] = client_info.client_secret + + if actor_token: + token_data["actor_token"] = actor_token + if self.actor_token_type: + token_data["actor_token_type"] = self.actor_token_type + if self.audience: + token_data["audience"] = self.audience + if self.resource: + token_data["resource"] = self.resource + if self.client_metadata.scope: + token_data["scope"] = self.client_metadata.scope + + async with httpx.AsyncClient() as client: + response = await client.post( + token_url, + data=token_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, + ) + + if response.status_code != 200: + raise Exception(f"Token request failed: {response.status_code} {response.text}") + + token_response = OAuthToken.model_validate(response.json()) + await self._validate_token_scopes(token_response) + + if token_response.expires_in: + self._token_expiry_time = time.time() + token_response.expires_in + else: + self._token_expiry_time = None + + await self.storage.set_tokens(token_response) + self._current_tokens = token_response diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 63b09133f..167e78c6e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -173,9 +173,11 @@ async def _handle_sse_event( session_message = SessionMessage(message) await read_stream_writer.send(session_message) - # Call resumption token callback if we have an ID - if sse.id and resumption_callback: - await resumption_callback(sse.id) + # Call resumption token callback if we have an ID. Only update + # the resumption token on notifications to avoid overwriting it + # with the token from the final response. + if sse.id and resumption_callback and not isinstance(message.root, JSONRPCResponse | JSONRPCError): + await resumption_callback(sse.id.strip()) # If this is a response or error return True indicating completion # Otherwise, return False to continue listening @@ -221,7 +223,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" headers = self._prepare_request_headers(ctx.headers) if ctx.metadata and ctx.metadata.resumption_token: - headers[LAST_EVENT_ID] = ctx.metadata.resumption_token + headers[LAST_EVENT_ID] = ctx.metadata.resumption_token.strip() else: raise ResumptionError("Resumption request requires a resumption token") diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index e6d99e66d..b211e238f 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -68,11 +68,22 @@ async def handle(self, request: Request) -> Response: ), status_code=400, ) - if set(client_metadata.grant_types) != {"authorization_code", "refresh_token"}: + grant_types_set: set[str] = set(client_metadata.grant_types) + valid_sets = [ + {"authorization_code", "refresh_token"}, + {"client_credentials"}, + {"token_exchange"}, + {"client_credentials", "token_exchange"}, + ] + + if grant_types_set not in valid_sets: return PydanticJSONResponse( content=RegistrationErrorResponse( error="invalid_client_metadata", - error_description="grant_types must be authorization_code and refresh_token", + error_description=( + "grant_types must be authorization_code and refresh_token " + "or client_credentials or token exchange or client_credentials and token_exchange" + ), ), status_code=400, ) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 4e15e6265..e39b4ef1e 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -40,16 +40,39 @@ class RefreshTokenRequest(BaseModel): resource: str | None = Field(None, description="Resource indicator for the token") +class ClientCredentialsRequest(BaseModel): + # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 + grant_type: Literal["client_credentials"] + scope: str | None = Field(None, description="Optional scope parameter") + client_id: str + client_secret: str | None = None + + +class TokenExchangeRequest(BaseModel): + """RFC 8693 token exchange request.""" + + grant_type: Literal["token_exchange"] + subject_token: str = Field(..., description="Token to exchange") + subject_token_type: str = Field(..., description="Type of the subject token") + actor_token: str | None = Field(None, description="Optional actor token") + actor_token_type: str | None = Field(None, description="Type of the actor token if provided") + resource: str | None = None + audience: str | None = None + scope: str | None = None + client_id: str + client_secret: str | None = None + + class TokenRequest( RootModel[ Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, + AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest | TokenExchangeRequest, Field(discriminator="grant_type"), ] ] ): root: Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, + AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest | TokenExchangeRequest, Field(discriminator="grant_type"), ] @@ -192,10 +215,49 @@ async def handle(self, request: Request): ) ) + case ClientCredentialsRequest(): + scopes = ( + token_request.scope.split(" ") + if token_request.scope + else client_info.scope.split(" ") + if client_info.scope + else [] + ) + try: + tokens = await self.provider.exchange_client_credentials(client_info, scopes) + except TokenError as e: + return self.response( + TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + ) + + case TokenExchangeRequest(): + scopes = token_request.scope.split(" ") if token_request.scope else [] + try: + tokens = await self.provider.exchange_token( + client_info, + token_request.subject_token, + token_request.subject_token_type, + token_request.actor_token, + token_request.actor_token_type, + scopes, + token_request.audience, + token_request.resource, + ) + except TokenError as e: + return self.response( + TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + ) + case RefreshTokenRequest(): refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) if refresh_token is None or refresh_token.client_id != token_request.client_id: - # if token belongs to different client, pretend it doesn't exist + # if token belongs to a different client, pretend it doesn't exist return self.response( TokenErrorResponse( error="invalid_grant", diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index b84db89a2..e4de4ecf8 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -80,6 +80,7 @@ class AuthorizeError(Exception): "unauthorized_client", "unsupported_grant_type", "invalid_scope", + "invalid_target", ] @@ -248,6 +249,24 @@ async def exchange_refresh_token( """ ... + async def exchange_client_credentials(self, client: OAuthClientInformationFull, scopes: list[str]) -> OAuthToken: + """Exchange client credentials for an MCP access token.""" + ... + + async def exchange_token( + self, + client: OAuthClientInformationFull, + subject_token: str, + subject_token_type: str, + actor_token: str | None, + actor_token_type: str | None, + scope: list[str] | None, + audience: str | None, + resource: str | None, + ) -> OAuthToken: + """Exchange an external token for an MCP access token.""" + ... + async def load_access_token(self, token: str) -> AccessTokenT | None: """ Loads an access token by its token. diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index e4db806e7..c08113128 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -163,7 +163,12 @@ def build_metadata( scopes_supported=client_registration_options.valid_scopes, response_types_supported=["code"], response_modes_supported=None, - grant_types_supported=["authorization_code", "refresh_token"], + grant_types_supported=[ + "authorization_code", + "refresh_token", + "client_credentials", + "token_exchange", + ], token_endpoint_auth_methods_supported=["client_secret_post"], token_endpoint_auth_signing_alg_values_supported=None, service_documentation=service_documentation_url, diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 33878ee15..6ee886ad8 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -13,6 +13,7 @@ class OAuthToken(BaseModel): expires_in: int | None = None scope: str | None = None refresh_token: str | None = None + issued_token_type: str | None = None @field_validator("token_type", mode="before") @classmethod @@ -46,8 +47,15 @@ class OAuthClientMetadata(BaseModel): # client_secret_post; # ie: we do not support client_secret_basic token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post" - # grant_types: this implementation only supports authorization_code & refresh_token - grant_types: list[Literal["authorization_code", "refresh_token"]] = [ + # grant_types: this implementation supports authorization_code, refresh_token, client_credentials, & token_exchange + grant_types: list[ + Literal[ + "authorization_code", + "refresh_token", + "client_credentials", + "token_exchange", + ] + ] = [ "authorization_code", "refresh_token", ] @@ -115,8 +123,18 @@ class OAuthMetadata(BaseModel): scopes_supported: list[str] | None = None response_types_supported: list[str] = ["code"] response_modes_supported: list[Literal["query", "fragment", "form_post"]] | None = None - grant_types_supported: list[str] | None = None - token_endpoint_auth_methods_supported: list[str] | None = None + grant_types_supported: ( + list[ + Literal[ + "authorization_code", + "refresh_token", + "client_credentials", + "token_exchange", + ] + ] + | None + ) = None + token_endpoint_auth_methods_supported: list[Literal["none", "client_secret_post"]] | None = None token_endpoint_auth_signing_alg_values_supported: None = None service_documentation: AnyHttpUrl | None = None ui_locales_supported: list[str] | None = None diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b2f49fc8b..865f9c973 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -312,7 +312,10 @@ async def send_notification( message=JSONRPCMessage(jsonrpc_notification), metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, ) - await self._write_stream.send(session_message) + try: + await self._write_stream.send(session_message) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logging.debug("Discarding notification due to closed stream") async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: if isinstance(response, ErrorData): diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index affcaa276..fdfa9a6e6 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -2,16 +2,24 @@ Tests for refactored OAuth client authentication implementation. """ +import asyncio import time +from unittest.mock import AsyncMock, Mock, patch import httpx import pytest from pydantic import AnyHttpUrl, AnyUrl -from mcp.client.auth import OAuthClientProvider, PKCEParameters +from mcp.client.auth import ( + ClientCredentialsProvider, + OAuthClientProvider, + PKCEParameters, + TokenExchangeProvider, +) from mcp.shared.auth import ( OAuthClientInformationFull, OAuthClientMetadata, + OAuthMetadata, OAuthToken, ProtectedResourceMetadata, ) @@ -82,6 +90,75 @@ async def callback_handler() -> tuple[str, str | None]: ) +@pytest.fixture +def client_credentials_metadata(): + return OAuthClientMetadata( + redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")], + client_name="CC Client", + grant_types=["client_credentials"], + response_types=["code"], + scope="read write", + token_endpoint_auth_method="client_secret_post", + ) + + +@pytest.fixture +def oauth_metadata(): + return OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + registration_endpoint=AnyHttpUrl("https://auth.example.com/register"), + scopes_supported=["read", "write", "admin"], + response_types_supported=["code"], + grant_types_supported=["authorization_code", "refresh_token", "client_credentials"], + code_challenge_methods_supported=["S256"], + ) + + +@pytest.fixture +def oauth_client_info(): + return OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + client_name="Test Client", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scope="read write", + ) + + +@pytest.fixture +def oauth_token(): + return OAuthToken( + access_token="test_access_token", + token_type="Bearer", + expires_in=3600, + refresh_token="test_refresh_token", + scope="read write", + ) + + +@pytest.fixture +async def client_credentials_provider(client_credentials_metadata, mock_storage): + return ClientCredentialsProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_credentials_metadata, + storage=mock_storage, + ) + + +@pytest.fixture +async def token_exchange_provider(client_credentials_metadata, mock_storage): + return TokenExchangeProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_credentials_metadata, + storage=mock_storage, + subject_token_supplier=lambda: asyncio.sleep(0, result="user_token"), + ) + + class TestPKCEParameters: """Test PKCE parameter generation.""" @@ -580,3 +657,84 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v await auth_flow.asend(response) except StopAsyncIteration: pass # Expected + + +class TestClientCredentialsProvider: + @pytest.mark.anyio + async def test_request_token_success( + self, + client_credentials_provider, + oauth_metadata, + oauth_client_info, + oauth_token, + ): + client_credentials_provider._metadata = oauth_metadata + client_credentials_provider._client_info = oauth_client_info + + token_json = oauth_token.model_dump(by_alias=True, mode="json") + token_json.pop("refresh_token", None) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = token_json + mock_client.post.return_value = mock_response + + await client_credentials_provider.ensure_token() + + mock_client.post.assert_called_once() + args, kwargs = mock_client.post.call_args + assert kwargs["data"]["resource"] == "https://api.example.com/v1/mcp" + assert client_credentials_provider._current_tokens.access_token == oauth_token.access_token + + @pytest.mark.anyio + async def test_async_auth_flow(self, client_credentials_provider, oauth_token): + client_credentials_provider._current_tokens = oauth_token + client_credentials_provider._token_expiry_time = time.time() + 3600 + + request = httpx.Request("GET", "https://api.example.com/data") + mock_response = Mock() + mock_response.status_code = 200 + + auth_flow = client_credentials_provider.async_auth_flow(request) + updated_request = await auth_flow.__anext__() + assert updated_request.headers["Authorization"] == f"Bearer {oauth_token.access_token}" + try: + await auth_flow.asend(mock_response) + except StopAsyncIteration: + pass + + +class TestTokenExchangeProvider: + @pytest.mark.anyio + async def test_request_token_success( + self, + token_exchange_provider, + oauth_metadata, + oauth_client_info, + oauth_token, + ): + token_exchange_provider._metadata = oauth_metadata + token_exchange_provider._client_info = oauth_client_info + + token_json = oauth_token.model_dump(by_alias=True, mode="json") + token_json.pop("refresh_token", None) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = token_json + mock_client.post.return_value = mock_response + + await token_exchange_provider.ensure_token() + + mock_client.post.assert_called_once() + args, kwargs = mock_client.post.call_args + assert kwargs["data"]["resource"] == "https://api.example.com/v1/mcp" + assert token_exchange_provider._current_tokens.access_token == oauth_token.access_token diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index d595ed022..6bdd6c7cf 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -84,7 +84,12 @@ async def client(read_stream, write_stream, scope): # - Long enough for fast operations (>10ms) # - Short enough for slow operations (<200ms) # - Not too short to avoid flakiness - async with ClientSession(read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50)) as session: + async with ClientSession( + read_stream, + write_stream, + # Increased to 150ms to avoid flakiness on slower platforms + read_timeout_seconds=timedelta(milliseconds=150), + ) as session: await session.initialize() # First call should work (fast operation) diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py index 97edb651e..209bafd99 100644 --- a/tests/issues/test_malformed_input.py +++ b/tests/issues/test_malformed_input.py @@ -74,6 +74,7 @@ async def test_malformed_initialize_request_does_not_crash_server(): method="tools/call", # params=None # Missing required params ) + another_request_message = SessionMessage(message=JSONRPCMessage(another_malformed_request)) await read_send_stream.send(another_request_message) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index e4a8f3f4c..afd1866dc 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -20,6 +20,7 @@ AuthorizationParams, OAuthAuthorizationServerProvider, RefreshToken, + TokenError, construct_redirect_uri, ) from mcp.server.auth.routes import ( @@ -160,6 +161,49 @@ async def exchange_refresh_token( refresh_token=new_refresh_token, ) + async def exchange_client_credentials(self, client: OAuthClientInformationFull, scopes: list[str]) -> OAuthToken: + access_token = f"access_{secrets.token_hex(32)}" + self.tokens[access_token] = AccessToken( + token=access_token, + client_id=client.client_id, + scopes=scopes, + expires_at=int(time.time()) + 3600, + ) + return OAuthToken( + access_token=access_token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(scopes), + ) + + async def exchange_token( + self, + client: OAuthClientInformationFull, + subject_token: str, + subject_token_type: str, + actor_token: str | None, + actor_token_type: str | None, + scope: list[str] | None, + audience: str | None, + resource: str | None, + ) -> OAuthToken: + if subject_token == "bad_token": + raise TokenError("invalid_grant", "invalid subject token") + + access_token = f"exchanged_{secrets.token_hex(32)}" + self.tokens[access_token] = AccessToken( + token=access_token, + client_id=client.client_id, + scopes=scope or ["read"], + expires_at=int(time.time()) + 3600, + ) + return OAuthToken( + access_token=access_token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(scope or ["read"]), + ) + async def load_access_token(self, token: str) -> AccessToken | None: token_info = self.tokens.get(token) @@ -354,6 +398,8 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): assert metadata["grant_types_supported"] == [ "authorization_code", "refresh_token", + "client_credentials", + "token_exchange", ] assert metadata["service_documentation"] == "https://docs.example.com/" @@ -928,7 +974,28 @@ async def test_client_registration_invalid_grant_type(self, test_client: httpx.A error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" - assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token" + assert error_data["error_description"] == ( + "grant_types must be authorization_code and refresh_token " + "or client_credentials or token exchange or " + "client_credentials and token_exchange" + ) + + @pytest.mark.anyio + async def test_client_registration_client_credentials(self, test_client: httpx.AsyncClient): + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "CC Client", + "grant_types": ["client_credentials"], + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + + assert response.status_code == 201, response.content + client_info = response.json() + assert client_info["grant_types"] == ["client_credentials"] class TestAuthorizeEndpointErrors: @@ -1201,3 +1268,102 @@ async def test_authorize_invalid_scope(self, test_client: httpx.AsyncClient, reg # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["client_credentials"]}], + indirect=True, + ) + async def test_client_credentials_token(self, test_client: httpx.AsyncClient, registered_client): + response = await test_client.post( + "/token", + data={ + "grant_type": "client_credentials", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "scope": "read write", + }, + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + + @pytest.mark.anyio + async def test_metadata_includes_token_exchange(self, test_client: httpx.AsyncClient): + response = await test_client.get("/.well-known/oauth-authorization-server") + assert response.status_code == 200 + metadata = response.json() + assert "token_exchange" in metadata["grant_types_supported"] + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["token_exchange"]}], + indirect=True, + ) + async def test_token_exchange_success(self, test_client: httpx.AsyncClient, registered_client): + response = await test_client.post( + "/token", + data={ + "grant_type": "token_exchange", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "subject_token": "good_token", + "subject_token_type": "access_token", + }, + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["token_exchange"]}], + indirect=True, + ) + async def test_token_exchange_invalid_subject(self, test_client: httpx.AsyncClient, registered_client): + response = await test_client.post( + "/token", + data={ + "grant_type": "token_exchange", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "subject_token": "bad_token", + "subject_token_type": "access_token", + }, + ) + assert response.status_code == 400 + data = response.json() + assert data["error"] == "invalid_grant" + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["client_credentials", "token_exchange"]}], + indirect=True, + ) + async def test_client_credentials_and_token_exchange(self, test_client: httpx.AsyncClient, registered_client): + cc_response = await test_client.post( + "/token", + data={ + "grant_type": "client_credentials", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "scope": "read write", + }, + ) + assert cc_response.status_code == 200 + + te_response = await test_client.post( + "/token", + data={ + "grant_type": "token_exchange", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "subject_token": "good_token", + "subject_token_type": "access_token", + }, + ) + assert te_response.status_code == 200 diff --git a/tests/server/fastmcp/resources/test_file_resources.py b/tests/server/fastmcp/resources/test_file_resources.py index ec3c85d8d..1ff9a3cb5 100644 --- a/tests/server/fastmcp/resources/test_file_resources.py +++ b/tests/server/fastmcp/resources/test_file_resources.py @@ -100,18 +100,21 @@ async def test_missing_file_error(self, temp_file: Path): with pytest.raises(ValueError, match="Error reading file"): await resource.read() - @pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows") - @pytest.mark.anyio - async def test_permission_error(self, temp_file: Path): - """Test reading a file without permissions.""" - temp_file.chmod(0o000) # Remove all permissions - try: - resource = FileResource( - uri=FileUrl(temp_file.as_uri()), - name="test", - path=temp_file, - ) - with pytest.raises(ValueError, match="Error reading file"): - await resource.read() - finally: - temp_file.chmod(0o644) # Restore permissions + +@pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows") +@pytest.mark.anyio +async def test_permission_error(temp_file: Path): + """Test reading a file without permissions.""" + if os.geteuid() == 0: + pytest.skip("Permission test not reliable when running as root") + temp_file.chmod(0o000) # Remove all permissions + try: + resource = FileResource( + uri=FileUrl(temp_file.as_uri()), + name="test", + path=temp_file, + ) + with pytest.raises(ValueError, match="Error reading file"): + await resource.read() + finally: + temp_file.chmod(0o644) # Restore permissions diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 1ffcc13b0..f1ec929c1 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -7,6 +7,7 @@ import json import multiprocessing import socket +import sys import time from collections.abc import Generator from typing import Any @@ -1047,6 +1048,7 @@ async def mock_delete(self, *args, **kwargs): @pytest.mark.anyio +@pytest.mark.skipif(sys.platform == "win32", reason="Resumption unstable on Windows") async def test_streamablehttp_client_resumption(event_server): """Test client session to resume a long running tool.""" _, server_url = event_server @@ -1156,6 +1158,12 @@ async def run_tool(): assert result.content[0].type == "text" assert "Completed" in result.content[0].text + # Allow any pending notifications to be processed + for _ in range(50): + if captured_notifications: + break + await anyio.sleep(0.1) + # We should have received the remaining notifications assert len(captured_notifications) > 0