Skip to content

Support falling back to OIDC metadata for auth #1061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 101 additions & 39 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ def should_include_resource_param(self, protocol_version: str | None = None) ->
return protocol_version >= "2025-06-18"


OAuthDiscoveryStack = list[Callable[[], Awaitable[httpx.Request]]]


class OAuthClientProvider(httpx.Auth):
"""
OAuth2 authentication for httpx.
Expand Down Expand Up @@ -221,32 +224,60 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
except ValidationError:
pass

def _build_well_known_path(self, pathname: str) -> str:
def _build_well_known_path(self, pathname: str, well_known_endpoint: str) -> str:
"""Construct well-known path for OAuth metadata discovery."""
well_known_path = f"/.well-known/oauth-authorization-server{pathname}"
well_known_path = f"/.well-known/{well_known_endpoint}{pathname}"
if pathname.endswith("/"):
# Strip trailing slash from pathname to avoid double slashes
well_known_path = well_known_path[:-1]
return well_known_path

def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool:
"""Determine if fallback to root discovery should be attempted."""
return response_status == 404 and pathname != "/"
def _build_well_known_fallback_url(self, well_known_endpoint: str) -> str:
"""Construct fallback well-known URL for OAuth metadata discovery in legacy servers."""
base_url = getattr(self.context, "discovery_base_url", "")
if not base_url:
raise OAuthFlowError("No base URL available for fallback discovery")

# Fallback to root discovery for legacy servers
return urljoin(base_url, f"/.well-known/{well_known_endpoint}")

def _build_oidc_fallback_path(self, pathname: str, well_known_endpoint: str) -> str:
"""Construct fallback well-known path for OIDC metadata discovery in legacy servers."""
# Strip trailing slash from pathname to avoid double slashes
clean_pathname = pathname[:-1] if pathname.endswith("/") else pathname
# OIDC 1.0 appends the well-known path to the full AS URL
return f"{clean_pathname}/.well-known/{well_known_endpoint}"

def _build_oidc_fallback_url(self, well_known_endpoint: str) -> str:
"""Construct fallback well-known URL for OIDC metadata discovery in legacy servers."""
if self.context.auth_server_url:
auth_server_url = self.context.auth_server_url
else:
auth_server_url = self.context.server_url

parsed = urlparse(auth_server_url)
well_known_path = self._build_oidc_fallback_path(parsed.path, well_known_endpoint)
base_url = f"{parsed.scheme}://{parsed.netloc}"
return urljoin(base_url, well_known_path)

def _should_attempt_fallback(self, response_status: int, discovery_stack: OAuthDiscoveryStack) -> bool:
"""Determine if further fallback should be attempted."""
return response_status == 404 and len(discovery_stack) > 0

async def _try_metadata_discovery(self, url: str) -> httpx.Request:
"""Build metadata discovery request for a specific URL."""
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})

async def _discover_oauth_metadata(self) -> httpx.Request:
"""Build OAuth metadata discovery request with fallback support."""
async def _discover_well_known_metadata(self, well_known_endpoint: str) -> httpx.Request:
"""Build .well-known metadata discovery request with fallback support."""
if self.context.auth_server_url:
auth_server_url = self.context.auth_server_url
else:
auth_server_url = self.context.server_url

# Per RFC 8414, try path-aware discovery first
parsed = urlparse(auth_server_url)
well_known_path = self._build_well_known_path(parsed.path)
well_known_path = self._build_well_known_path(parsed.path, well_known_endpoint)
base_url = f"{parsed.scheme}://{parsed.netloc}"
url = urljoin(base_url, well_known_path)

Expand All @@ -256,17 +287,37 @@ async def _discover_oauth_metadata(self) -> httpx.Request:

return await self._try_metadata_discovery(url)

async def _discover_well_known_metadata_fallback(self, well_known_endpoint: str) -> httpx.Request:
"""Build fallback OAuth metadata discovery request for legacy servers."""
url = self._build_well_known_fallback_url(well_known_endpoint)
return await self._try_metadata_discovery(url)

async def _discover_oauth_metadata(self) -> httpx.Request:
"""Build OAuth metadata discovery request with fallback support."""
return await self._discover_well_known_metadata("oauth-authorization-server")

async def _discover_oauth_metadata_fallback(self) -> httpx.Request:
"""Build fallback OAuth metadata discovery request for legacy servers."""
base_url = getattr(self.context, "discovery_base_url", "")
if not base_url:
raise OAuthFlowError("No base URL available for fallback discovery")
return await self._discover_well_known_metadata_fallback("oauth-authorization-server")

# Fallback to root discovery for legacy servers
url = urljoin(base_url, "/.well-known/oauth-authorization-server")
async def _discover_oidc_metadata(self) -> httpx.Request:
"""
Build fallback OIDC metadata discovery request.
See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
"""
return await self._discover_well_known_metadata("openid-configuration")

async def _discover_oidc_metadata_fallback(self) -> httpx.Request:
"""
Build fallback OIDC metadata discovery request for legacy servers.
See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
"""
url = self._build_oidc_fallback_url("openid-configuration")
return await self._try_metadata_discovery(url)

async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fallback: bool = False) -> bool:
async def _handle_oauth_metadata_response(
self, response: httpx.Response, discovery_stack: OAuthDiscoveryStack
) -> bool:
"""Handle OAuth metadata response. Returns True if handled successfully."""
if response.status_code == 200:
try:
Expand All @@ -280,13 +331,10 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fal
except ValidationError:
pass

# Check if we should attempt fallback (404 on path-aware discovery)
if not is_fallback and self._should_attempt_fallback(
response.status_code, getattr(self.context, "discovery_pathname", "/")
):
return False # Signal that fallback should be attempted

return True # Signal no fallback needed (either success or non-404 error)
# Check if we should attempt fallback
# True: No fallback needed (either success or non-404 error)
# False: Signal that fallback should be attempted
return not self._should_attempt_fallback(response.status_code, discovery_stack)

async def _register_client(self) -> httpx.Request | None:
"""Build registration request or skip if already registered."""
Expand Down Expand Up @@ -480,6 +528,26 @@ def _add_auth_header(self, request: httpx.Request) -> None:
if self.context.current_tokens and self.context.current_tokens.access_token:
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"

def _create_oauth_discovery_stack(self) -> OAuthDiscoveryStack:
"""Create a stack of attempts to discover OAuth metadata."""
discovery_attempts: OAuthDiscoveryStack = [
# Start with path-aware OAuth discovery
self._discover_oauth_metadata,
# If path-aware discovery fails with 404, try fallback to root
self._discover_oauth_metadata_fallback,
# If root discovery fails with 404, fall back to OIDC 1.0 following
# RFC 8414 path-aware semantics (see RFC 8414 section 5)
self._discover_oidc_metadata,
# If path-aware OIDC discovery failed with 404, fall back to OIDC 1.0
# following OIDC 1.0 semantics (see RFC 8414 section 5)
self._discover_oidc_metadata_fallback,
]

# Reverse the list so we can call pop() without remembering we declared
# this stack backwards for readability
discovery_attempts.reverse()
return discovery_attempts

async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
async with self.context.lock:
Expand All @@ -499,15 +567,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
await self._handle_protected_resource_response(discovery_response)

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
oauth_request = await self._discover_oauth_metadata()
oauth_response = yield oauth_request
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)

# If path-aware discovery failed with 404, try fallback to root
if not handled:
fallback_request = await self._discover_oauth_metadata_fallback()
fallback_response = yield fallback_request
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
oauth_discovery_stack = self._create_oauth_discovery_stack()
while len(oauth_discovery_stack) > 0:
oauth_discovery = oauth_discovery_stack.pop()
oauth_request = await oauth_discovery()
oauth_response = yield oauth_request
await self._handle_oauth_metadata_response(oauth_response, oauth_discovery_stack)

# Step 3: Register client if needed
registration_request = await self._register_client()
Expand Down Expand Up @@ -551,15 +616,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
await self._handle_protected_resource_response(discovery_response)

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
oauth_request = await self._discover_oauth_metadata()
oauth_response = yield oauth_request
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)

# If path-aware discovery failed with 404, try fallback to root
if not handled:
fallback_request = await self._discover_oauth_metadata_fallback()
fallback_response = yield fallback_request
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
oauth_discovery_stack = self._create_oauth_discovery_stack()
while len(oauth_discovery_stack) > 0:
oauth_discovery = oauth_discovery_stack.pop()
oauth_request = await oauth_discovery()
oauth_response = yield oauth_request
await self._handle_oauth_metadata_response(oauth_response, oauth_discovery_stack)

# Step 3: Register client if needed
registration_request = await self._register_client()
Expand Down
Loading
Loading