|
| 1 | +# pyright: reportGeneralTypeIssues=false |
| 2 | +"""Starlette routes that implement the transparent OAuth proxy endpoints.""" |
| 3 | + |
| 4 | +from __future__ import annotations |
| 5 | + |
| 6 | +import logging |
| 7 | +import urllib.parse |
| 8 | +from typing import Any |
| 9 | + |
| 10 | +import httpx # type: ignore |
| 11 | +from starlette.requests import Request |
| 12 | +from starlette.responses import JSONResponse, RedirectResponse, Response |
| 13 | +from starlette.routing import Route |
| 14 | + |
| 15 | +from mcp.server.fastmcp.utilities.logging import configure_logging |
| 16 | + |
| 17 | +__all__ = ["fetch_upstream_metadata", "create_proxy_routes"] |
| 18 | + |
| 19 | +logger = logging.getLogger("transparent_oauth_proxy.routes") |
| 20 | + |
| 21 | + |
| 22 | +# --------------------------------------------------------------------------- |
| 23 | +# Helper – fetch (or synthesise) upstream AS metadata |
| 24 | +# --------------------------------------------------------------------------- |
| 25 | + |
| 26 | + |
| 27 | +async def fetch_upstream_metadata( # noqa: D401 |
| 28 | + upstream_base: str, |
| 29 | + upstream_authorize: str, |
| 30 | + upstream_token: str, |
| 31 | + upstream_jwks_uri: str | None = None, |
| 32 | +) -> dict[str, Any]: |
| 33 | + """Return upstream metadata, mirroring logic from old server.py.""" |
| 34 | + |
| 35 | + # If explicit endpoints provided, craft a synthetic metadata object. |
| 36 | + if upstream_authorize and upstream_token: |
| 37 | + return { |
| 38 | + "issuer": upstream_base, |
| 39 | + "authorization_endpoint": upstream_authorize, |
| 40 | + "token_endpoint": upstream_token, |
| 41 | + "registration_endpoint": "/register", |
| 42 | + "jwks_uri": upstream_jwks_uri or "", |
| 43 | + } |
| 44 | + |
| 45 | + # Otherwise attempt remote fetch. |
| 46 | + metadata_url = f"{upstream_base}/.well-known/oauth-authorization-server" |
| 47 | + try: |
| 48 | + async with httpx.AsyncClient() as client: |
| 49 | + r = await client.get(metadata_url, timeout=10) |
| 50 | + r.raise_for_status() |
| 51 | + return r.json() |
| 52 | + except Exception as exc: # noqa: BLE001 |
| 53 | + logger.warning("Could not fetch upstream metadata (%s); using fallback.", exc) |
| 54 | + return { |
| 55 | + "issuer": "fallback", |
| 56 | + "authorization_endpoint": "/authorize", |
| 57 | + "token_endpoint": "/token", |
| 58 | + "registration_endpoint": "/register", |
| 59 | + } |
| 60 | + |
| 61 | + |
| 62 | +# --------------------------------------------------------------------------- |
| 63 | +# Route factory – returns Starlette Route objects |
| 64 | +# --------------------------------------------------------------------------- |
| 65 | + |
| 66 | + |
| 67 | +def create_proxy_routes(provider: Any) -> list[Route]: # type: ignore[valid-type] |
| 68 | + """Create all additional proxy-specific routes. |
| 69 | +
|
| 70 | + The *provider* must be an instance of |
| 71 | + `TransparentOAuthProxyProvider` (duck-typed here to avoid circular imports). |
| 72 | + """ |
| 73 | + |
| 74 | + configure_logging() # ensure log format if not already set |
| 75 | + |
| 76 | + s = provider._s # access its validated settings (_Settings) |
| 77 | + |
| 78 | + # Introduce a dedicated handler class to avoid nested closures while still |
| 79 | + # retaining the convenience of accessing validated settings via |
| 80 | + # ``self.s``. This improves introspection, simplifies debugging and makes |
| 81 | + # future extensibility (e.g. dependency injection) easier. |
| 82 | + |
| 83 | + class _ProxyHandlers: # noqa: D401,E501 |
| 84 | + """Collection of async endpoints implementing the proxy logic.""" |
| 85 | + |
| 86 | + def __init__(self, settings: Any): # type: ignore[valid-type] |
| 87 | + self.s = settings |
| 88 | + |
| 89 | + # ------------------------------------------------------------------ |
| 90 | + # /.well-known/oauth-authorization-server |
| 91 | + # ------------------------------------------------------------------ |
| 92 | + async def metadata(self, request: Request) -> Response: # noqa: D401 |
| 93 | + logger.info("🔍 /.well-known/oauth-authorization-server endpoint accessed") |
| 94 | + |
| 95 | + data = await fetch_upstream_metadata( |
| 96 | + self.s.upstream_authorize.rsplit("/", 1)[0], # base |
| 97 | + str(self.s.upstream_authorize), |
| 98 | + str(self.s.upstream_token), |
| 99 | + self.s.jwks_uri, |
| 100 | + ) |
| 101 | + |
| 102 | + host = request.headers.get("host", "localhost") |
| 103 | + scheme = "https" if request.url.scheme == "https" else "http" |
| 104 | + issuer = f"{scheme}://{host}" |
| 105 | + data.update( |
| 106 | + { |
| 107 | + "issuer": issuer, |
| 108 | + "authorization_endpoint": f"{issuer}/authorize", |
| 109 | + "token_endpoint": f"{issuer}/token", |
| 110 | + "registration_endpoint": f"{issuer}/register", |
| 111 | + } |
| 112 | + ) |
| 113 | + return JSONResponse(data) |
| 114 | + |
| 115 | + # ------------------------------------------------------------------ |
| 116 | + # /register – Dynamic Client Registration stub |
| 117 | + # ------------------------------------------------------------------ |
| 118 | + async def register(self, request: Request) -> Response: # noqa: D401 |
| 119 | + body = await request.json() |
| 120 | + client_metadata = { |
| 121 | + "client_id": self.s.client_id, |
| 122 | + "client_secret": self.s.client_secret, |
| 123 | + "token_endpoint_auth_method": "client_secret_post" if self.s.client_secret else "none", |
| 124 | + **body, |
| 125 | + } |
| 126 | + return JSONResponse(client_metadata, status_code=201) |
| 127 | + |
| 128 | + # ------------------------------------------------------------------ |
| 129 | + # /authorize – Redirect to upstream with injections |
| 130 | + # ------------------------------------------------------------------ |
| 131 | + async def authorize(self, request: Request) -> Response: # noqa: D401 |
| 132 | + params = dict(request.query_params) |
| 133 | + params["client_id"] = self.s.client_id |
| 134 | + if "scope" not in params: |
| 135 | + params["scope"] = self.s.default_scope |
| 136 | + |
| 137 | + redirect_url = f"{self.s.upstream_authorize}?{urllib.parse.urlencode(params)}" |
| 138 | + return RedirectResponse(redirect_url) |
| 139 | + |
| 140 | + # ------------------------------------------------------------------ |
| 141 | + # /revoke – Pass-through |
| 142 | + # ------------------------------------------------------------------ |
| 143 | + async def revoke(self, request: Request) -> Response: # noqa: D401 |
| 144 | + form = await request.form() |
| 145 | + data = dict(form) |
| 146 | + data.setdefault("client_id", self.s.client_id) |
| 147 | + if self.s.client_secret: |
| 148 | + data.setdefault("client_secret", self.s.client_secret) |
| 149 | + |
| 150 | + async with httpx.AsyncClient() as client: |
| 151 | + r = await client.post(str(self.s.upstream_token).rsplit("/", 1)[0] + "/revoke", data=data, timeout=10) |
| 152 | + return JSONResponse(r.json(), status_code=r.status_code) |
| 153 | + |
| 154 | + handlers = _ProxyHandlers(s) |
| 155 | + |
| 156 | + return [ |
| 157 | + Route("/.well-known/oauth-authorization-server", handlers.metadata, methods=["GET"]), |
| 158 | + Route("/register", handlers.register, methods=["POST"]), |
| 159 | + Route("/authorize", handlers.authorize, methods=["GET"]), |
| 160 | + Route("/revoke", handlers.revoke, methods=["POST"]), |
| 161 | + ] |
0 commit comments