Skip to content

Commit d5ea38c

Browse files
committed
Add proxy routes implementation for OAuth endpoints
Signed-off-by: Jesse Sanford <108698+jessesanford@users.noreply.github.com>
1 parent fb54150 commit d5ea38c

File tree

2 files changed

+191
-0
lines changed

2 files changed

+191
-0
lines changed

src/mcp/server/auth/proxy/__init__.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Transparent OAuth proxy helpers (library form).
2+
3+
This sub-package turns the demo-level transparent OAuth proxy into a reusable
4+
component:
5+
6+
* create_proxy_routes(provider) – returns the Starlette routes that expose the
7+
proxy endpoints (/authorize, /revoke …).
8+
* build_proxy_server() – convenience helper that wires everything into a
9+
FastMCP instance.
10+
11+
The functions are re-exported here so users can simply::
12+
13+
from mcp.server.auth.proxy import build_proxy_server
14+
15+
"""
16+
17+
from __future__ import annotations
18+
19+
# Public re-exports
20+
from .routes import create_proxy_routes, fetch_upstream_metadata
21+
22+
__all__: list[str] = [
23+
"create_proxy_routes",
24+
"fetch_upstream_metadata",
25+
]
26+
27+
# build_proxy_server intentionally *not* imported here to avoid circular
28+
# imports with TransparentOAuthProxyProvider. Import from
29+
# `mcp.server.auth.proxy.server` when needed:
30+
# from mcp.server.auth.proxy.server import build_proxy_server

src/mcp/server/auth/proxy/routes.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# pyright: reportGeneralTypeIssues=false
2+
"""Starlette routes that implement the transparent OAuth proxy endpoints."""
3+
4+
from __future__ import annotations
5+
6+
import logging
7+
import urllib.parse
8+
from typing import Any
9+
10+
import httpx # type: ignore
11+
from starlette.requests import Request
12+
from starlette.responses import JSONResponse, RedirectResponse, Response
13+
from starlette.routing import Route
14+
15+
from mcp.server.fastmcp.utilities.logging import configure_logging
16+
17+
__all__ = ["fetch_upstream_metadata", "create_proxy_routes"]
18+
19+
logger = logging.getLogger("transparent_oauth_proxy.routes")
20+
21+
22+
# ---------------------------------------------------------------------------
23+
# Helper – fetch (or synthesise) upstream AS metadata
24+
# ---------------------------------------------------------------------------
25+
26+
27+
async def fetch_upstream_metadata( # noqa: D401
28+
upstream_base: str,
29+
upstream_authorize: str,
30+
upstream_token: str,
31+
upstream_jwks_uri: str | None = None,
32+
) -> dict[str, Any]:
33+
"""Return upstream metadata, mirroring logic from old server.py."""
34+
35+
# If explicit endpoints provided, craft a synthetic metadata object.
36+
if upstream_authorize and upstream_token:
37+
return {
38+
"issuer": upstream_base,
39+
"authorization_endpoint": upstream_authorize,
40+
"token_endpoint": upstream_token,
41+
"registration_endpoint": "/register",
42+
"jwks_uri": upstream_jwks_uri or "",
43+
}
44+
45+
# Otherwise attempt remote fetch.
46+
metadata_url = f"{upstream_base}/.well-known/oauth-authorization-server"
47+
try:
48+
async with httpx.AsyncClient() as client:
49+
r = await client.get(metadata_url, timeout=10)
50+
r.raise_for_status()
51+
return r.json()
52+
except Exception as exc: # noqa: BLE001
53+
logger.warning("Could not fetch upstream metadata (%s); using fallback.", exc)
54+
return {
55+
"issuer": "fallback",
56+
"authorization_endpoint": "/authorize",
57+
"token_endpoint": "/token",
58+
"registration_endpoint": "/register",
59+
}
60+
61+
62+
# ---------------------------------------------------------------------------
63+
# Route factory – returns Starlette Route objects
64+
# ---------------------------------------------------------------------------
65+
66+
67+
def create_proxy_routes(provider: Any) -> list[Route]: # type: ignore[valid-type]
68+
"""Create all additional proxy-specific routes.
69+
70+
The *provider* must be an instance of
71+
`TransparentOAuthProxyProvider` (duck-typed here to avoid circular imports).
72+
"""
73+
74+
configure_logging() # ensure log format if not already set
75+
76+
s = provider._s # access its validated settings (_Settings)
77+
78+
# Introduce a dedicated handler class to avoid nested closures while still
79+
# retaining the convenience of accessing validated settings via
80+
# ``self.s``. This improves introspection, simplifies debugging and makes
81+
# future extensibility (e.g. dependency injection) easier.
82+
83+
class _ProxyHandlers: # noqa: D401,E501
84+
"""Collection of async endpoints implementing the proxy logic."""
85+
86+
def __init__(self, settings: Any): # type: ignore[valid-type]
87+
self.s = settings
88+
89+
# ------------------------------------------------------------------
90+
# /.well-known/oauth-authorization-server
91+
# ------------------------------------------------------------------
92+
async def metadata(self, request: Request) -> Response: # noqa: D401
93+
logger.info("🔍 /.well-known/oauth-authorization-server endpoint accessed")
94+
95+
data = await fetch_upstream_metadata(
96+
self.s.upstream_authorize.rsplit("/", 1)[0], # base
97+
str(self.s.upstream_authorize),
98+
str(self.s.upstream_token),
99+
self.s.jwks_uri,
100+
)
101+
102+
host = request.headers.get("host", "localhost")
103+
scheme = "https" if request.url.scheme == "https" else "http"
104+
issuer = f"{scheme}://{host}"
105+
data.update(
106+
{
107+
"issuer": issuer,
108+
"authorization_endpoint": f"{issuer}/authorize",
109+
"token_endpoint": f"{issuer}/token",
110+
"registration_endpoint": f"{issuer}/register",
111+
}
112+
)
113+
return JSONResponse(data)
114+
115+
# ------------------------------------------------------------------
116+
# /register – Dynamic Client Registration stub
117+
# ------------------------------------------------------------------
118+
async def register(self, request: Request) -> Response: # noqa: D401
119+
body = await request.json()
120+
client_metadata = {
121+
"client_id": self.s.client_id,
122+
"client_secret": self.s.client_secret,
123+
"token_endpoint_auth_method": "client_secret_post" if self.s.client_secret else "none",
124+
**body,
125+
}
126+
return JSONResponse(client_metadata, status_code=201)
127+
128+
# ------------------------------------------------------------------
129+
# /authorize – Redirect to upstream with injections
130+
# ------------------------------------------------------------------
131+
async def authorize(self, request: Request) -> Response: # noqa: D401
132+
params = dict(request.query_params)
133+
params["client_id"] = self.s.client_id
134+
if "scope" not in params:
135+
params["scope"] = self.s.default_scope
136+
137+
redirect_url = f"{self.s.upstream_authorize}?{urllib.parse.urlencode(params)}"
138+
return RedirectResponse(redirect_url)
139+
140+
# ------------------------------------------------------------------
141+
# /revoke – Pass-through
142+
# ------------------------------------------------------------------
143+
async def revoke(self, request: Request) -> Response: # noqa: D401
144+
form = await request.form()
145+
data = dict(form)
146+
data.setdefault("client_id", self.s.client_id)
147+
if self.s.client_secret:
148+
data.setdefault("client_secret", self.s.client_secret)
149+
150+
async with httpx.AsyncClient() as client:
151+
r = await client.post(str(self.s.upstream_token).rsplit("/", 1)[0] + "/revoke", data=data, timeout=10)
152+
return JSONResponse(r.json(), status_code=r.status_code)
153+
154+
handlers = _ProxyHandlers(s)
155+
156+
return [
157+
Route("/.well-known/oauth-authorization-server", handlers.metadata, methods=["GET"]),
158+
Route("/register", handlers.register, methods=["POST"]),
159+
Route("/authorize", handlers.authorize, methods=["GET"]),
160+
Route("/revoke", handlers.revoke, methods=["POST"]),
161+
]

0 commit comments

Comments
 (0)