Skip to content

Commit 75426ff

Browse files
committed
Add example proxy OAuth server implementation
Signed-off-by: Jesse Sanford <108698+jessesanford@users.noreply.github.com>
1 parent bb911bb commit 75426ff

File tree

5 files changed

+371
-0
lines changed

5 files changed

+371
-0
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# OAuth Proxy Server
2+
3+
This is a minimal OAuth proxy server example for the MCP Python SDK.
4+
5+
## Installation
6+
7+
```bash
8+
uv add proxy_oauth
9+
```
10+
11+
## Usage
12+
13+
This is a placeholder for the OAuth proxy server implementation.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""OAuth Proxy Server for MCP."""
2+
3+
__version__ = "0.1.0"
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
[project]
2+
name = "proxy_oauth"
3+
version = "0.1.0"
4+
description = "OAuth Proxy Server"
5+
authors = [{ name = "Your Name" }]
6+
readme = "README.md"
7+
requires-python = ">=3.10"
8+
dependencies = [
9+
"mcp",
10+
]
11+
12+
[project.optional-dependencies]
13+
dev = [
14+
"pytest>=6.0",
15+
]
16+
17+
[build-system]
18+
requires = ["hatchling"]
19+
build-backend = "hatchling.build"
20+
21+
[tool.hatch.build.targets.wheel]
22+
packages = ["proxy_oauth"]
23+
24+
[tool.pyright]
25+
include = ["proxy_oauth"]
26+
venvPath = "."
27+
venv = ".venv"
28+
29+
[tool.ruff.lint]
30+
select = ["E", "F", "I"]
31+
ignore = []
32+
33+
[tool.ruff]
34+
line-length = 88
35+
target-version = "py311"
36+
37+
[tool.uv]
38+
dev-dependencies = ["pyright>=1.1.391", "pytest>=8.3.4", "ruff>=0.8.5"]
39+
extras = ["dev"]
40+
41+
[[tool.uv.index]]
42+
url = "https://pypi.org/simple"
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# pyright: reportMissingImports=false
2+
import base64
3+
import json
4+
import logging
5+
import os
6+
import time
7+
from typing import Any
8+
9+
from dotenv import load_dotenv # type: ignore
10+
from mcp.server.auth.providers.transparent_proxy import ProxySettings # type: ignore
11+
from mcp.server.auth.proxy.server import build_proxy_server # noqa: E402
12+
from mcp.server.fastmcp.server import Context
13+
from starlette.requests import Request # type: ignore
14+
15+
# Load environment variables from .env if present
16+
load_dotenv()
17+
18+
# Configure logging after .env so LOG_LEVEL can come from environment
19+
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
20+
21+
logging.basicConfig(
22+
level=LOG_LEVEL,
23+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
24+
datefmt="%Y-%m-%d %H:%M:%S",
25+
)
26+
27+
# Dedicated logger for this server module
28+
logger = logging.getLogger("proxy_oauth.server")
29+
30+
# Suppress noisy INFO messages from the FastMCP low-level server unless we are
31+
# explicitly running in DEBUG mode. These logs (e.g. "Processing request of type
32+
# ListToolsRequest") are helpful for debugging but clutter normal output.
33+
34+
_mcp_lowlevel_logger = logging.getLogger("mcp.server.lowlevel.server")
35+
if LOG_LEVEL == "DEBUG":
36+
# In full debug mode, allow the library to emit its detailed logs
37+
_mcp_lowlevel_logger.setLevel(logging.DEBUG)
38+
else:
39+
# Otherwise, only warnings and above
40+
_mcp_lowlevel_logger.setLevel(logging.WARNING)
41+
42+
# ----------------------------------------------------------------------------
43+
# Environment configuration
44+
# ----------------------------------------------------------------------------
45+
# Load and validate settings from the environment (uses .env automatically)
46+
settings = ProxySettings.load()
47+
48+
# Upstream endpoints (fully-qualified URLs)
49+
UPSTREAM_AUTHORIZE: str = str(settings.upstream_authorize)
50+
UPSTREAM_TOKEN: str = str(settings.upstream_token)
51+
UPSTREAM_JWKS_URI = settings.jwks_uri
52+
# Derive base URL from the authorize endpoint for convenience / tests
53+
UPSTREAM_BASE: str = UPSTREAM_AUTHORIZE.rsplit("/", 1)[0]
54+
55+
# Client credentials & defaults
56+
CLIENT_ID: str = settings.client_id or "demo-client-id"
57+
CLIENT_SECRET = settings.client_secret
58+
DEFAULT_SCOPE: str = settings.default_scope
59+
60+
# Optional audience passthrough (not part of ProxySettings yet)
61+
AUDIENCE = os.getenv("PROXY_AUDIENCE")
62+
63+
# Metadata URL (only used if we need to fetch from upstream)
64+
UPSTREAM_METADATA = f"{UPSTREAM_BASE}/.well-known/oauth-authorization-server"
65+
66+
# ---------------------------------------------------------------------------
67+
# Logging helpers
68+
# ---------------------------------------------------------------------------
69+
70+
71+
def _mask_secret(secret: str | None) -> str | None: # noqa: D401
72+
"""Return a masked version of the given secret.
73+
74+
The first and last four characters are preserved (if available) and the
75+
middle section is replaced by asterisks. If the secret is shorter than
76+
eight characters, the entire value is replaced by ``*``.
77+
"""
78+
79+
if not secret:
80+
return None
81+
82+
if len(secret) <= 8:
83+
return "*" * len(secret)
84+
85+
return f"{secret[:4]}{'*' * (len(secret) - 8)}{secret[-4:]}"
86+
87+
88+
# Consolidated configuration (with sensitive data redacted)
89+
_masked_settings = settings.model_dump(exclude_none=True).copy()
90+
91+
if "client_secret" in _masked_settings:
92+
_masked_settings["client_secret"] = _mask_secret(_masked_settings["client_secret"])
93+
94+
# Log configuration at *debug* level only so it can be enabled when needed
95+
logger.debug("[Proxy Config] %s", _masked_settings)
96+
97+
# Server host/port
98+
PROXY_PORT = int(os.getenv("PROXY_PORT", "8000"))
99+
100+
# ----------------------------------------------------------------------------
101+
# FastMCP server (now created via library helper)
102+
# ----------------------------------------------------------------------------
103+
104+
ISSUER_URL = os.getenv("PROXY_ISSUER_URL", "http://localhost:8000")
105+
106+
# Create FastMCP instance using the reusable proxy builder
107+
mcp = build_proxy_server(port=PROXY_PORT, issuer_url=ISSUER_URL)
108+
109+
# ---------------------------------------------------------------------------
110+
# Minimal demo tool
111+
# ---------------------------------------------------------------------------
112+
113+
114+
@mcp.tool()
115+
def echo(message: str) -> str:
116+
return f"Echo: {message}"
117+
118+
119+
@mcp.tool()
120+
async def user_info(ctx: Context[Any, Any, Request]) -> dict[str, Any]:
121+
"""
122+
Get information about the authenticated user.
123+
124+
This tool demonstrates accessing user information from the OAuth access token.
125+
The user must be authenticated via OAuth to access this tool.
126+
127+
Returns:
128+
Dictionary containing user information from the access token
129+
"""
130+
from mcp.server.auth.middleware.auth_context import get_access_token
131+
132+
# Get the access token from the authentication context
133+
access_token = get_access_token()
134+
135+
if not access_token:
136+
return {
137+
"error": "No access token found - user not authenticated",
138+
"authenticated": False,
139+
}
140+
141+
# Attempt to decode the access token as JWT to extract useful user claims.
142+
# Many OAuth providers issue JWT access tokens (or ID tokens) that contain
143+
# the user's subject (sub) and preferred username. We parse the token
144+
# *without* signature verification – we only need the public claims for
145+
# display purposes. If the token is opaque or the decode fails, we simply
146+
# skip this step.
147+
148+
def _try_decode_jwt(token_str: str) -> dict[str, Any] | None: # noqa: D401
149+
"""Best-effort JWT decode without verification.
150+
151+
Returns the payload dictionary if the token *looks* like a JWT and can
152+
be base64-decoded. If anything fails we return None.
153+
"""
154+
155+
try:
156+
parts = token_str.split(".")
157+
if len(parts) != 3:
158+
return None # Not a JWT
159+
160+
# JWT parts are URL-safe base64 without padding
161+
def _b64decode(segment: str) -> bytes:
162+
padding = "=" * (-len(segment) % 4)
163+
return base64.urlsafe_b64decode(segment + padding)
164+
165+
payload_bytes = _b64decode(parts[1])
166+
return json.loads(payload_bytes)
167+
except Exception: # noqa: BLE001
168+
return None
169+
170+
jwt_claims = _try_decode_jwt(access_token.token)
171+
172+
# Build response with token information plus any extracted claims
173+
response: dict[str, Any] = {
174+
"authenticated": True,
175+
"client_id": access_token.client_id,
176+
"scopes": access_token.scopes,
177+
"token_type": "Bearer",
178+
"expires_at": access_token.expires_at,
179+
"resource": access_token.resource,
180+
}
181+
182+
if jwt_claims:
183+
# Prefer the `userid` claim used in FastMCP examples; fall back to `sub` if
184+
# absent.
185+
uid = jwt_claims.get("userid") or jwt_claims.get("sub")
186+
if uid is not None:
187+
response["userid"] = uid # camelCase variant used in FastMCP reference
188+
response["user_id"] = uid # snake_case variant
189+
response["username"] = (
190+
jwt_claims.get("preferred_username")
191+
or jwt_claims.get("nickname")
192+
or jwt_claims.get("name")
193+
)
194+
response["issuer"] = jwt_claims.get("iss")
195+
response["audience"] = jwt_claims.get("aud")
196+
response["issued_at"] = jwt_claims.get("iat")
197+
198+
# Calculate expiration helpers
199+
if access_token.expires_at:
200+
response["expires_at_iso"] = time.strftime(
201+
"%Y-%m-%dT%H:%M:%S", time.localtime(access_token.expires_at)
202+
)
203+
response["expires_in_seconds"] = max(
204+
0, access_token.expires_at - int(time.time())
205+
)
206+
207+
return response
208+
209+
210+
@mcp.tool()
211+
async def test_endpoint(message: str = "Hello from proxy server!") -> dict[str, Any]:
212+
"""
213+
Test endpoint for debugging OAuth proxy functionality.
214+
215+
Args:
216+
message: Optional message to echo back
217+
218+
Returns:
219+
Test response with server information
220+
"""
221+
return {
222+
"message": message,
223+
"server": "Transparent OAuth Proxy Server",
224+
"status": "active",
225+
"oauth_configured": True,
226+
}
227+
228+
229+
if __name__ == "__main__":
230+
mcp.run(transport="streamable-http")

examples/servers/proxy_oauth/uv.lock

Lines changed: 83 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)