|
| 1 | +# pyright: reportMissingImports=false |
| 2 | +import base64 |
| 3 | +import json |
| 4 | +import logging |
| 5 | +import os |
| 6 | +import time |
| 7 | +from typing import Any |
| 8 | + |
| 9 | +from dotenv import load_dotenv # type: ignore |
| 10 | +from mcp.server.auth.providers.transparent_proxy import ProxySettings # type: ignore |
| 11 | +from mcp.server.auth.proxy.server import build_proxy_server # noqa: E402 |
| 12 | +from mcp.server.fastmcp.server import Context |
| 13 | +from starlette.requests import Request # type: ignore |
| 14 | + |
| 15 | +# Load environment variables from .env if present |
| 16 | +load_dotenv() |
| 17 | + |
| 18 | +# Configure logging after .env so LOG_LEVEL can come from environment |
| 19 | +LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() |
| 20 | + |
| 21 | +logging.basicConfig( |
| 22 | + level=LOG_LEVEL, |
| 23 | + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", |
| 24 | + datefmt="%Y-%m-%d %H:%M:%S", |
| 25 | +) |
| 26 | + |
| 27 | +# Dedicated logger for this server module |
| 28 | +logger = logging.getLogger("proxy_oauth.server") |
| 29 | + |
| 30 | +# Suppress noisy INFO messages from the FastMCP low-level server unless we are |
| 31 | +# explicitly running in DEBUG mode. These logs (e.g. "Processing request of type |
| 32 | +# ListToolsRequest") are helpful for debugging but clutter normal output. |
| 33 | + |
| 34 | +_mcp_lowlevel_logger = logging.getLogger("mcp.server.lowlevel.server") |
| 35 | +if LOG_LEVEL == "DEBUG": |
| 36 | + # In full debug mode, allow the library to emit its detailed logs |
| 37 | + _mcp_lowlevel_logger.setLevel(logging.DEBUG) |
| 38 | +else: |
| 39 | + # Otherwise, only warnings and above |
| 40 | + _mcp_lowlevel_logger.setLevel(logging.WARNING) |
| 41 | + |
| 42 | +# ---------------------------------------------------------------------------- |
| 43 | +# Environment configuration |
| 44 | +# ---------------------------------------------------------------------------- |
| 45 | +# Load and validate settings from the environment (uses .env automatically) |
| 46 | +settings = ProxySettings.load() |
| 47 | + |
| 48 | +# Upstream endpoints (fully-qualified URLs) |
| 49 | +UPSTREAM_AUTHORIZE: str = str(settings.upstream_authorize) |
| 50 | +UPSTREAM_TOKEN: str = str(settings.upstream_token) |
| 51 | +UPSTREAM_JWKS_URI = settings.jwks_uri |
| 52 | +# Derive base URL from the authorize endpoint for convenience / tests |
| 53 | +UPSTREAM_BASE: str = UPSTREAM_AUTHORIZE.rsplit("/", 1)[0] |
| 54 | + |
| 55 | +# Client credentials & defaults |
| 56 | +CLIENT_ID: str = settings.client_id or "demo-client-id" |
| 57 | +CLIENT_SECRET = settings.client_secret |
| 58 | +DEFAULT_SCOPE: str = settings.default_scope |
| 59 | + |
| 60 | +# Optional audience passthrough (not part of ProxySettings yet) |
| 61 | +AUDIENCE = os.getenv("PROXY_AUDIENCE") |
| 62 | + |
| 63 | +# Metadata URL (only used if we need to fetch from upstream) |
| 64 | +UPSTREAM_METADATA = f"{UPSTREAM_BASE}/.well-known/oauth-authorization-server" |
| 65 | + |
| 66 | +# --------------------------------------------------------------------------- |
| 67 | +# Logging helpers |
| 68 | +# --------------------------------------------------------------------------- |
| 69 | + |
| 70 | + |
| 71 | +def _mask_secret(secret: str | None) -> str | None: # noqa: D401 |
| 72 | + """Return a masked version of the given secret. |
| 73 | +
|
| 74 | + The first and last four characters are preserved (if available) and the |
| 75 | + middle section is replaced by asterisks. If the secret is shorter than |
| 76 | + eight characters, the entire value is replaced by ``*``. |
| 77 | + """ |
| 78 | + |
| 79 | + if not secret: |
| 80 | + return None |
| 81 | + |
| 82 | + if len(secret) <= 8: |
| 83 | + return "*" * len(secret) |
| 84 | + |
| 85 | + return f"{secret[:4]}{'*' * (len(secret) - 8)}{secret[-4:]}" |
| 86 | + |
| 87 | + |
| 88 | +# Consolidated configuration (with sensitive data redacted) |
| 89 | +_masked_settings = settings.model_dump(exclude_none=True).copy() |
| 90 | + |
| 91 | +if "client_secret" in _masked_settings: |
| 92 | + _masked_settings["client_secret"] = _mask_secret(_masked_settings["client_secret"]) |
| 93 | + |
| 94 | +# Log configuration at *debug* level only so it can be enabled when needed |
| 95 | +logger.debug("[Proxy Config] %s", _masked_settings) |
| 96 | + |
| 97 | +# Server host/port |
| 98 | +PROXY_PORT = int(os.getenv("PROXY_PORT", "8000")) |
| 99 | + |
| 100 | +# ---------------------------------------------------------------------------- |
| 101 | +# FastMCP server (now created via library helper) |
| 102 | +# ---------------------------------------------------------------------------- |
| 103 | + |
| 104 | +ISSUER_URL = os.getenv("PROXY_ISSUER_URL", "http://localhost:8000") |
| 105 | + |
| 106 | +# Create FastMCP instance using the reusable proxy builder |
| 107 | +mcp = build_proxy_server(port=PROXY_PORT, issuer_url=ISSUER_URL) |
| 108 | + |
| 109 | +# --------------------------------------------------------------------------- |
| 110 | +# Minimal demo tool |
| 111 | +# --------------------------------------------------------------------------- |
| 112 | + |
| 113 | + |
| 114 | +@mcp.tool() |
| 115 | +def echo(message: str) -> str: |
| 116 | + return f"Echo: {message}" |
| 117 | + |
| 118 | + |
| 119 | +@mcp.tool() |
| 120 | +async def user_info(ctx: Context[Any, Any, Request]) -> dict[str, Any]: |
| 121 | + """ |
| 122 | + Get information about the authenticated user. |
| 123 | +
|
| 124 | + This tool demonstrates accessing user information from the OAuth access token. |
| 125 | + The user must be authenticated via OAuth to access this tool. |
| 126 | +
|
| 127 | + Returns: |
| 128 | + Dictionary containing user information from the access token |
| 129 | + """ |
| 130 | + from mcp.server.auth.middleware.auth_context import get_access_token |
| 131 | + |
| 132 | + # Get the access token from the authentication context |
| 133 | + access_token = get_access_token() |
| 134 | + |
| 135 | + if not access_token: |
| 136 | + return { |
| 137 | + "error": "No access token found - user not authenticated", |
| 138 | + "authenticated": False, |
| 139 | + } |
| 140 | + |
| 141 | + # Attempt to decode the access token as JWT to extract useful user claims. |
| 142 | + # Many OAuth providers issue JWT access tokens (or ID tokens) that contain |
| 143 | + # the user's subject (sub) and preferred username. We parse the token |
| 144 | + # *without* signature verification – we only need the public claims for |
| 145 | + # display purposes. If the token is opaque or the decode fails, we simply |
| 146 | + # skip this step. |
| 147 | + |
| 148 | + def _try_decode_jwt(token_str: str) -> dict[str, Any] | None: # noqa: D401 |
| 149 | + """Best-effort JWT decode without verification. |
| 150 | +
|
| 151 | + Returns the payload dictionary if the token *looks* like a JWT and can |
| 152 | + be base64-decoded. If anything fails we return None. |
| 153 | + """ |
| 154 | + |
| 155 | + try: |
| 156 | + parts = token_str.split(".") |
| 157 | + if len(parts) != 3: |
| 158 | + return None # Not a JWT |
| 159 | + |
| 160 | + # JWT parts are URL-safe base64 without padding |
| 161 | + def _b64decode(segment: str) -> bytes: |
| 162 | + padding = "=" * (-len(segment) % 4) |
| 163 | + return base64.urlsafe_b64decode(segment + padding) |
| 164 | + |
| 165 | + payload_bytes = _b64decode(parts[1]) |
| 166 | + return json.loads(payload_bytes) |
| 167 | + except Exception: # noqa: BLE001 |
| 168 | + return None |
| 169 | + |
| 170 | + jwt_claims = _try_decode_jwt(access_token.token) |
| 171 | + |
| 172 | + # Build response with token information plus any extracted claims |
| 173 | + response: dict[str, Any] = { |
| 174 | + "authenticated": True, |
| 175 | + "client_id": access_token.client_id, |
| 176 | + "scopes": access_token.scopes, |
| 177 | + "token_type": "Bearer", |
| 178 | + "expires_at": access_token.expires_at, |
| 179 | + "resource": access_token.resource, |
| 180 | + } |
| 181 | + |
| 182 | + if jwt_claims: |
| 183 | + # Prefer the `userid` claim used in FastMCP examples; fall back to `sub` if |
| 184 | + # absent. |
| 185 | + uid = jwt_claims.get("userid") or jwt_claims.get("sub") |
| 186 | + if uid is not None: |
| 187 | + response["userid"] = uid # camelCase variant used in FastMCP reference |
| 188 | + response["user_id"] = uid # snake_case variant |
| 189 | + response["username"] = ( |
| 190 | + jwt_claims.get("preferred_username") |
| 191 | + or jwt_claims.get("nickname") |
| 192 | + or jwt_claims.get("name") |
| 193 | + ) |
| 194 | + response["issuer"] = jwt_claims.get("iss") |
| 195 | + response["audience"] = jwt_claims.get("aud") |
| 196 | + response["issued_at"] = jwt_claims.get("iat") |
| 197 | + |
| 198 | + # Calculate expiration helpers |
| 199 | + if access_token.expires_at: |
| 200 | + response["expires_at_iso"] = time.strftime( |
| 201 | + "%Y-%m-%dT%H:%M:%S", time.localtime(access_token.expires_at) |
| 202 | + ) |
| 203 | + response["expires_in_seconds"] = max( |
| 204 | + 0, access_token.expires_at - int(time.time()) |
| 205 | + ) |
| 206 | + |
| 207 | + return response |
| 208 | + |
| 209 | + |
| 210 | +@mcp.tool() |
| 211 | +async def test_endpoint(message: str = "Hello from proxy server!") -> dict[str, Any]: |
| 212 | + """ |
| 213 | + Test endpoint for debugging OAuth proxy functionality. |
| 214 | +
|
| 215 | + Args: |
| 216 | + message: Optional message to echo back |
| 217 | +
|
| 218 | + Returns: |
| 219 | + Test response with server information |
| 220 | + """ |
| 221 | + return { |
| 222 | + "message": message, |
| 223 | + "server": "Transparent OAuth Proxy Server", |
| 224 | + "status": "active", |
| 225 | + "oauth_configured": True, |
| 226 | + } |
| 227 | + |
| 228 | + |
| 229 | +if __name__ == "__main__": |
| 230 | + mcp.run(transport="streamable-http") |
0 commit comments