Skip to content

Commit 73816d1

Browse files
committed
Add Transparent Oauth Proxy Provider
Signed-off-by: Jesse Sanford <108698+jessesanford@users.noreply.github.com>
1 parent 595d5ca commit 73816d1

File tree

1 file changed

+370
-0
lines changed

1 file changed

+370
-0
lines changed
Lines changed: 370 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,370 @@
1+
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportAttributeAccessIssue=false, reportUnknownArgumentType=false, reportCallIssue=false, reportUnnecessaryIsInstance=false
2+
from __future__ import annotations
3+
4+
import logging
5+
import os
6+
import time
7+
import uuid
8+
from collections.abc import Mapping
9+
from typing import Any, cast
10+
from urllib.parse import urlencode
11+
12+
import httpx # type: ignore
13+
from pydantic import AnyHttpUrl, AnyUrl, Field
14+
from pydantic_settings import BaseSettings, SettingsConfigDict
15+
from starlette.responses import Response
16+
from starlette.routing import Route
17+
18+
from mcp.server.auth.handlers.token import TokenHandler
19+
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
20+
from mcp.server.auth.provider import (
21+
AccessToken,
22+
AuthorizationCode,
23+
AuthorizationParams,
24+
OAuthAuthorizationServerProvider,
25+
)
26+
from mcp.server.auth.proxy.routes import create_proxy_routes
27+
from mcp.server.auth.routes import create_auth_routes
28+
from mcp.server.auth.settings import ClientRegistrationOptions
29+
from mcp.server.fastmcp.utilities.logging import redact_sensitive_data
30+
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
31+
32+
"""Transparent OAuth proxy provider for FastMCP (Anthropic SDK).
33+
34+
This provider mimics the behaviour of fastapi_mcp's `setup_proxies=True` and the
35+
`TransparentOAuthProxyProvider` from the `fastmcp` fork. It forwards all real
36+
OAuth traffic (authorize / token / jwks) to an upstream Authorization Server
37+
(AS) while *locally* implementing Dynamic Client Registration so that MCP
38+
clients such as Cursor can register even when the upstream AS disables RFC 7591
39+
registration.
40+
41+
Environment variables (all optional – if omitted fall back to sensible defaults
42+
or raise clearly):
43+
44+
UPSTREAM_AUTHORIZATION_ENDPOINT Full URL of the upstream `/authorize` endpoint
45+
UPSTREAM_TOKEN_ENDPOINT Full URL of the upstream `/token` endpoint
46+
UPSTREAM_JWKS_URI URL of the upstream JWKS (optional, not yet used)
47+
UPSTREAM_CLIENT_ID Fixed client_id registered with the upstream
48+
UPSTREAM_CLIENT_SECRET Fixed secret (omit for public client)
49+
50+
PROXY_DEFAULT_SCOPE Space-separated default scope (default: "openid")
51+
52+
A simple helper ``TransparentOAuthProxyProvider.from_env()`` reads these vars.
53+
"""
54+
55+
__all__ = ["TransparentOAuthProxyProvider"]
56+
57+
logger = logging.getLogger("transparent_oauth_proxy")
58+
59+
60+
class ProxyTokenHandler(TokenHandler):
61+
"""Token handler that simply proxies token requests to the upstream AS.
62+
63+
We intentionally bypass redirect_uri and PKCE checks that the normal
64+
``TokenHandler`` performs because in *transparent proxy* mode we do not
65+
have enough information locally. Instead of validating, we forward the
66+
form untouched to the upstream token endpoint and stream the response
67+
back to the caller.
68+
"""
69+
70+
def __init__(self, provider: TransparentOAuthProxyProvider):
71+
# We provide a dummy ClientAuthenticator that will accept any client –
72+
# we are not going to invoke the base-class logic anyway.
73+
super().__init__(provider=provider, client_authenticator=ClientAuthenticator(provider))
74+
self.provider = provider # keep for easy access
75+
76+
async def handle(self, request) -> Response: # type: ignore[override]
77+
correlation_id = str(uuid.uuid4())[:8]
78+
start_time = time.time()
79+
80+
logger.info(f"[{correlation_id}] 🔄 ProxyTokenHandler - passthrough")
81+
82+
try:
83+
form = await request.form()
84+
form_dict = dict(form)
85+
86+
redacted_form = redact_sensitive_data(form_dict)
87+
logger.info(f"[{correlation_id}] ➡︎ Incoming form: {redacted_form}")
88+
89+
headers = {
90+
"Content-Type": "application/x-www-form-urlencoded",
91+
"Accept": "application/json",
92+
"User-Agent": "MCP-TransparentProxy/1.0",
93+
}
94+
95+
http = self.provider.http_client
96+
logger.info(f"[{correlation_id}] ⮕ Forwarding to {self.provider._s.upstream_token}")
97+
upstream_resp = await http.post(str(self.provider._s.upstream_token), data=form_dict, headers=headers)
98+
99+
except httpx.HTTPError as exc:
100+
logger.error(f"[{correlation_id}] ✗ Upstream HTTP error: {exc}")
101+
return Response(
102+
content='{"error":"server_error","error_description":"Upstream server error"}',
103+
status_code=502,
104+
headers={"Content-Type": "application/json"},
105+
)
106+
except Exception as exc:
107+
logger.error(f"[{correlation_id}] ✗ Unexpected proxy error: {exc}")
108+
return Response(
109+
content='{"error":"server_error"}',
110+
status_code=500,
111+
headers={"Content-Type": "application/json"},
112+
)
113+
114+
finally:
115+
elapsed = time.time() - start_time
116+
logger.info(f"[{correlation_id}] ⏱ Finished in {elapsed:.2f}s")
117+
118+
# Log upstream response (redacted)
119+
try:
120+
if upstream_resp.headers.get("content-type", "").startswith("application/json"):
121+
body = upstream_resp.json()
122+
logger.info(
123+
f"[{correlation_id}] ⬅︎ Body: {redact_sensitive_data(body) if isinstance(body, dict) else body}"
124+
)
125+
except Exception:
126+
pass
127+
128+
return Response(
129+
content=upstream_resp.content,
130+
status_code=upstream_resp.status_code,
131+
headers=dict(upstream_resp.headers),
132+
)
133+
134+
135+
class ProxySettings(BaseSettings):
136+
"""Validated environment-driven settings for the transparent OAuth proxy."""
137+
138+
model_config = SettingsConfigDict(env_file=".env", populate_by_name=True, extra="ignore")
139+
140+
upstream_authorize: AnyHttpUrl = Field(..., alias="UPSTREAM_AUTHORIZATION_ENDPOINT")
141+
upstream_token: AnyHttpUrl = Field(..., alias="UPSTREAM_TOKEN_ENDPOINT")
142+
jwks_uri: str | None = Field(None, alias="UPSTREAM_JWKS_URI")
143+
144+
client_id: str | None = Field(None, alias="UPSTREAM_CLIENT_ID")
145+
client_secret: str | None = Field(None, alias="UPSTREAM_CLIENT_SECRET")
146+
147+
# Allow overriding via env var, but default to "openid" if not provided
148+
default_scope: str = Field("openid", alias="PROXY_DEFAULT_SCOPE")
149+
150+
@classmethod
151+
def load(cls) -> ProxySettings:
152+
"""Instantiate settings from environment variables (for backwards compatibility)."""
153+
return cls()
154+
155+
156+
# Backwards-compatibility alias – existing callers/tests import `_Settings`
157+
_Settings = ProxySettings # type: ignore
158+
159+
160+
class TransparentOAuthProxyProvider(OAuthAuthorizationServerProvider[AuthorizationCode, Any, AccessToken]):
161+
"""Minimal pass-through provider – only implements code flow, no refresh."""
162+
163+
def __init__(self, *, settings: ProxySettings):
164+
# Fill in client_id fallback if not provided via upstream var
165+
if settings.client_id is None:
166+
settings.client_id = os.getenv("PROXY_CLIENT_ID", "demo-client-id") # type: ignore[assignment]
167+
assert settings.client_id is not None, "client_id must be provided"
168+
self._s = settings
169+
# simple in-memory auth-code store (maps code→AuthorizationCode)
170+
self._codes: dict[str, AuthorizationCode] = {}
171+
# always the same client info returned by /register
172+
self._static_client = OAuthClientInformationFull(
173+
client_id=str(self._s.client_id),
174+
client_secret=self._s.client_secret,
175+
redirect_uris=[cast(AnyUrl, cast(object, "http://localhost"))],
176+
grant_types=["authorization_code"],
177+
token_endpoint_auth_method="none" if self._s.client_secret is None else "client_secret_post",
178+
)
179+
180+
# Single reusable HTTP client for communicating with the upstream AS
181+
self._http: httpx.AsyncClient = httpx.AsyncClient(timeout=15)
182+
183+
# Expose http client for handlers
184+
@property
185+
def http_client(self) -> httpx.AsyncClient: # noqa: D401
186+
return self._http
187+
188+
async def aclose(self) -> None:
189+
"""Close the underlying HTTP client."""
190+
await self._http.aclose()
191+
192+
# ---------------------------------------------------------------------
193+
# Dynamic Client Registration – always enabled
194+
# ---------------------------------------------------------------------
195+
196+
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: # noqa: D401
197+
return self._static_client if client_id == self._s.client_id else None
198+
199+
async def register_client(self, client_info: OAuthClientInformationFull) -> None: # noqa: D401
200+
"""Spoof DCR: overwrite the incoming info with fixed credentials."""
201+
202+
client_info.client_id = str(self._s.client_id)
203+
client_info.client_secret = self._s.client_secret
204+
# Ensure token_endpoint_auth_method reflects whether secret exists
205+
client_info.token_endpoint_auth_method = "none" if self._s.client_secret is None else "client_secret_post"
206+
# Replace stored static client redirect URIs with provided ones so later validation passes
207+
self._static_client.redirect_uris = client_info.redirect_uris
208+
return None
209+
210+
# ------------------------------------------------------------------
211+
# Authorization endpoint – redirect to upstream
212+
# ------------------------------------------------------------------
213+
214+
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: # noqa: D401
215+
query: dict[str, str | None] = {
216+
"response_type": "code",
217+
"client_id": str(self._s.client_id),
218+
"redirect_uri": str(params.redirect_uri),
219+
"code_challenge": params.code_challenge,
220+
"code_challenge_method": "S256",
221+
"scope": " ".join(params.scopes or [self._s.default_scope]),
222+
"state": params.state,
223+
}
224+
return f"{self._s.upstream_authorize}?{urlencode({k: v for k, v in query.items() if v})}"
225+
226+
# ------------------------------------------------------------------
227+
# Auth-code tracking / exchange
228+
# ------------------------------------------------------------------
229+
230+
async def load_authorization_code(
231+
self, client: OAuthClientInformationFull, authorization_code: str
232+
) -> AuthorizationCode | None: # noqa: D401,E501
233+
# create lightweight object; we cannot verify with upstream at this stage
234+
return AuthorizationCode(
235+
code=authorization_code,
236+
scopes=[self._s.default_scope],
237+
expires_at=int(time.time() + 300),
238+
client_id=str(self._s.client_id),
239+
redirect_uri=cast(AnyUrl, cast(object, "http://localhost")), # type: ignore[arg-type]
240+
redirect_uri_provided_explicitly=False,
241+
code_challenge="", # not validated here
242+
)
243+
244+
async def exchange_authorization_code(
245+
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
246+
) -> OAuthToken: # noqa: D401,E501
247+
# Generate correlation ID for this request
248+
correlation_id = str(uuid.uuid4())[:8]
249+
start_time = time.time()
250+
251+
logger.info(f"[{correlation_id}] Starting token exchange for client_id={client.client_id}")
252+
253+
data: dict[str, str] = {
254+
"grant_type": "authorization_code",
255+
"client_id": str(self._s.client_id),
256+
"code": authorization_code.code,
257+
"redirect_uri": str(authorization_code.redirect_uri),
258+
}
259+
if self._s.client_secret:
260+
data["client_secret"] = self._s.client_secret
261+
262+
# Log outgoing request with full details
263+
redacted_data = redact_sensitive_data(data)
264+
logger.info(f"[{correlation_id}] ⮕ Preparing upstream token request")
265+
logger.info(f"[{correlation_id}] ⮕ Target URL: {self._s.upstream_token}")
266+
logger.info(f"[{correlation_id}] ⮕ Request data: {redacted_data}")
267+
268+
headers = {
269+
"Content-Type": "application/x-www-form-urlencoded",
270+
"Accept": "application/json",
271+
"User-Agent": "MCP-TransparentProxy/1.0",
272+
}
273+
logger.info(f"[{correlation_id}] ⮕ Request headers: {headers}")
274+
275+
http = self.http_client
276+
try:
277+
logger.info(f"[{correlation_id}] ⮕ Sending POST request to upstream")
278+
resp = await http.post(str(self._s.upstream_token), data=data, headers=headers)
279+
280+
elapsed_time = time.time() - start_time
281+
logger.info(f"[{correlation_id}] ⬅︎ Upstream response received in {elapsed_time:.2f}s")
282+
logger.info(f"[{correlation_id}] ⬅︎ Status: {resp.status_code}")
283+
logger.info(f"[{correlation_id}] ⬅︎ Headers: {dict(resp.headers)}")
284+
285+
# Log response body (redacted)
286+
try:
287+
body = resp.json()
288+
redacted_body = redact_sensitive_data(body) if isinstance(body, dict) else body
289+
logger.info(f"[{correlation_id}] ⬅︎ Response body: {redacted_body}")
290+
except Exception as e:
291+
logger.warning(f"[{correlation_id}] ⬅︎ Could not parse response as JSON: {e}")
292+
logger.info(f"[{correlation_id}] ⬅︎ Raw response: {resp.text[:500]}...")
293+
294+
resp.raise_for_status()
295+
296+
except httpx.HTTPError as e:
297+
logger.error(f"[{correlation_id}] ⬅︎ HTTP error occurred: {e}")
298+
raise
299+
except Exception as e:
300+
logger.error(f"[{correlation_id}] ⬅︎ Unexpected error: {e}")
301+
raise
302+
303+
body: Mapping[str, Any] = resp.json()
304+
logger.info(f"[{correlation_id}] ✓ Token exchange completed successfully")
305+
return OAuthToken(**body) # type: ignore[arg-type]
306+
307+
# ------------------------------------------------------------------
308+
# Unused grant types
309+
# ------------------------------------------------------------------
310+
311+
async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str): # noqa: D401
312+
return None
313+
314+
async def exchange_refresh_token(
315+
self,
316+
client: OAuthClientInformationFull,
317+
refresh_token: str,
318+
scopes: list[str],
319+
) -> OAuthToken: # noqa: D401
320+
raise NotImplementedError
321+
322+
async def load_access_token(self, token: str) -> AccessToken | None: # noqa: D401
323+
# For now we cannot validate JWT; return a dummy AccessToken so BearerAuth passes.
324+
return AccessToken(
325+
token=token, client_id=str(self._s.client_id), scopes=[self._s.default_scope], expires_at=None
326+
)
327+
328+
async def revoke_token(self, token: object) -> None: # noqa: D401
329+
return None
330+
331+
# ------------------------------------------------------------------
332+
# Helpers
333+
# ------------------------------------------------------------------
334+
335+
@classmethod
336+
def from_env(cls) -> TransparentOAuthProxyProvider:
337+
"""Construct provider using :class:`ProxySettings` populated from the environment."""
338+
return cls(settings=ProxySettings.load())
339+
340+
# FastMCP will read `client_registration_options` to decide whether to expose /register
341+
@property
342+
def client_registration_options(self) -> ClientRegistrationOptions: # type: ignore[override]
343+
return ClientRegistrationOptions(enabled=True)
344+
345+
# ------------------------------------------------------------------
346+
# Provide custom auth routes so that our proxy /token endpoint overrides the default one
347+
# ------------------------------------------------------------------
348+
349+
def get_auth_routes(self): # type: ignore[override]
350+
"""Return full auth+proxy route list for FastMCP."""
351+
352+
routes = create_auth_routes(
353+
provider=self,
354+
issuer_url=AnyHttpUrl("http://localhost:8000"), # placeholder; FastMCP rewrites host
355+
client_registration_options=self.client_registration_options,
356+
revocation_options=None,
357+
service_documentation_url=None,
358+
)
359+
360+
# Drop default /token and /authorize handlers – we provide custom ones.
361+
routes = [r for r in routes if not (isinstance(r, Route) and r.path in {"/token", "/authorize"})]
362+
363+
# Insert proxy /token handler first for high precedence
364+
proxy_handler = ProxyTokenHandler(self)
365+
routes.insert(0, Route("/token", endpoint=proxy_handler.handle, methods=["POST"]))
366+
367+
# Append additional proxy endpoints (metadata, register, authorize, revoke…)
368+
routes.extend(create_proxy_routes(self))
369+
370+
return routes

0 commit comments

Comments
 (0)