Skip to content

Commit df9ad3f

Browse files
committed
enables oauth proxy capability
Signed-off-by: Jesse Sanford <108698+jessesanford@users.noreply.github.com>
1 parent 6f43d1f commit df9ad3f

File tree

12 files changed

+1388
-23
lines changed

12 files changed

+1388
-23
lines changed
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# pyright: reportMissingImports=false
2+
import os
3+
import logging
4+
from dotenv import load_dotenv # type: ignore
5+
from typing import Any, cast
6+
import base64, json, time
7+
from starlette.requests import Request # type: ignore
8+
9+
from mcp.server.fastmcp.server import Context
10+
from mcp.server.auth.proxy.server import build_proxy_server # noqa: E402
11+
from mcp.server.auth.providers.transparent_proxy import ProxySettings # type: ignore
12+
13+
# Load environment variables from .env if present
14+
load_dotenv()
15+
16+
# Configure logging after .env so LOG_LEVEL can come from environment
17+
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
18+
19+
logging.basicConfig(
20+
level=LOG_LEVEL,
21+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
22+
datefmt="%Y-%m-%d %H:%M:%S",
23+
)
24+
25+
# Dedicated logger for this server module
26+
logger = logging.getLogger("proxy_oauth.server")
27+
28+
# Suppress noisy INFO messages from the FastMCP low-level server unless we are
29+
# explicitly running in DEBUG mode. These logs (e.g. "Processing request of type
30+
# ListToolsRequest") are helpful for debugging but clutter normal output.
31+
32+
_mcp_lowlevel_logger = logging.getLogger("mcp.server.lowlevel.server")
33+
if LOG_LEVEL == "DEBUG":
34+
# In full debug mode, allow the library to emit its detailed logs
35+
_mcp_lowlevel_logger.setLevel(logging.DEBUG)
36+
else:
37+
# Otherwise, only warnings and above
38+
_mcp_lowlevel_logger.setLevel(logging.WARNING)
39+
40+
# ----------------------------------------------------------------------------
41+
# Environment configuration
42+
# ----------------------------------------------------------------------------
43+
# Load and validate settings from the environment (uses .env automatically)
44+
settings = ProxySettings.load()
45+
46+
# Upstream endpoints (fully-qualified URLs)
47+
UPSTREAM_AUTHORIZE: str = str(settings.upstream_authorize)
48+
UPSTREAM_TOKEN: str = str(settings.upstream_token)
49+
UPSTREAM_JWKS_URI = settings.jwks_uri
50+
# Derive base URL from the authorize endpoint for convenience / tests
51+
UPSTREAM_BASE: str = UPSTREAM_AUTHORIZE.rsplit("/", 1)[0]
52+
53+
# Client credentials & defaults
54+
CLIENT_ID: str = settings.client_id or "demo-client-id"
55+
CLIENT_SECRET = settings.client_secret
56+
DEFAULT_SCOPE: str = settings.default_scope
57+
58+
# Optional audience passthrough (not part of ProxySettings yet)
59+
AUDIENCE = os.getenv("PROXY_AUDIENCE")
60+
61+
# Metadata URL (only used if we need to fetch from upstream)
62+
UPSTREAM_METADATA = f"{UPSTREAM_BASE}/.well-known/oauth-authorization-server"
63+
64+
# ---------------------------------------------------------------------------
65+
# Logging helpers
66+
# ---------------------------------------------------------------------------
67+
68+
def _mask_secret(secret: str | None) -> str | None: # noqa: D401
69+
"""Return a masked version of the given secret.
70+
71+
The first and last four characters are preserved (if available) and the
72+
middle section is replaced by asterisks. If the secret is shorter than
73+
eight characters, the entire value is replaced by ``*``.
74+
"""
75+
76+
if not secret:
77+
return None
78+
79+
if len(secret) <= 8:
80+
return "*" * len(secret)
81+
82+
return f"{secret[:4]}{'*' * (len(secret) - 8)}{secret[-4:]}"
83+
84+
# Consolidated configuration (with sensitive data redacted)
85+
_masked_settings = settings.model_dump(exclude_none=True).copy()
86+
87+
if "client_secret" in _masked_settings:
88+
_masked_settings["client_secret"] = _mask_secret(_masked_settings["client_secret"])
89+
90+
# Log configuration at *debug* level only so it can be enabled when needed
91+
logger.debug("[Proxy Config] %s", _masked_settings)
92+
93+
# Server host/port
94+
PROXY_PORT = int(os.getenv("PROXY_PORT", "8000"))
95+
96+
# ----------------------------------------------------------------------------
97+
# FastMCP server (now created via library helper)
98+
# ----------------------------------------------------------------------------
99+
100+
ISSUER_URL = os.getenv("PROXY_ISSUER_URL", "http://localhost:8000")
101+
102+
# Create FastMCP instance using the reusable proxy builder
103+
mcp = build_proxy_server(port=PROXY_PORT, issuer_url=ISSUER_URL)
104+
105+
# ---------------------------------------------------------------------------
106+
# Minimal demo tool
107+
# ---------------------------------------------------------------------------
108+
109+
@mcp.tool()
110+
def echo(message: str) -> str:
111+
return f"Echo: {message}"
112+
113+
114+
@mcp.tool()
115+
async def user_info(ctx: Context[Any, Any, Request]) -> dict[str, Any]:
116+
"""
117+
Get information about the authenticated user.
118+
119+
This tool demonstrates accessing user information from the OAuth access token.
120+
The user must be authenticated via OAuth to access this tool.
121+
122+
Returns:
123+
Dictionary containing user information from the access token
124+
"""
125+
from mcp.server.auth.middleware.auth_context import get_access_token
126+
127+
# Get the access token from the authentication context
128+
access_token = get_access_token()
129+
130+
if not access_token:
131+
return {
132+
"error": "No access token found - user not authenticated",
133+
"authenticated": False
134+
}
135+
136+
# Attempt to decode the access token as JWT to extract useful user claims.
137+
# Many OAuth providers issue JWT access tokens (or ID tokens) that contain
138+
# the user's subject (sub) and preferred username. We parse the token
139+
# *without* signature verification – we only need the public claims for
140+
# display purposes. If the token is opaque or the decode fails, we simply
141+
# skip this step.
142+
143+
def _try_decode_jwt(token_str: str) -> dict[str, Any] | None: # noqa: D401
144+
"""Best-effort JWT decode without verification.
145+
146+
Returns the payload dictionary if the token *looks* like a JWT and can
147+
be base64-decoded. If anything fails we return None.
148+
"""
149+
150+
try:
151+
parts = token_str.split(".")
152+
if len(parts) != 3:
153+
return None # Not a JWT
154+
155+
# JWT parts are URL-safe base64 without padding
156+
def _b64decode(segment: str) -> bytes:
157+
padding = "=" * (-len(segment) % 4)
158+
return base64.urlsafe_b64decode(segment + padding)
159+
160+
payload_bytes = _b64decode(parts[1])
161+
return json.loads(payload_bytes)
162+
except Exception: # noqa: BLE001
163+
return None
164+
165+
jwt_claims = _try_decode_jwt(access_token.token)
166+
167+
# Build response with token information plus any extracted claims
168+
response: dict[str, Any] = {
169+
"authenticated": True,
170+
"client_id": access_token.client_id,
171+
"scopes": access_token.scopes,
172+
"token_type": "Bearer",
173+
"expires_at": access_token.expires_at,
174+
"resource": access_token.resource,
175+
}
176+
177+
if jwt_claims:
178+
# Prefer the `userid` claim used in FastMCP examples; fall back to `sub` if absent.
179+
uid = jwt_claims.get("userid") or jwt_claims.get("sub")
180+
if uid is not None:
181+
response["userid"] = uid # camelCase variant used in FastMCP reference
182+
response["user_id"] = uid # snake_case variant
183+
response["username"] = (
184+
jwt_claims.get("preferred_username")
185+
or jwt_claims.get("nickname")
186+
or jwt_claims.get("name")
187+
)
188+
response["issuer"] = jwt_claims.get("iss")
189+
response["audience"] = jwt_claims.get("aud")
190+
response["issued_at"] = jwt_claims.get("iat")
191+
192+
# Calculate expiration helpers
193+
if access_token.expires_at:
194+
response["expires_at_iso"] = time.strftime('%Y-%m-%dT%H:%M:%S', time.localtime(access_token.expires_at))
195+
response["expires_in_seconds"] = max(0, access_token.expires_at - int(time.time()))
196+
197+
return response
198+
199+
200+
@mcp.tool()
201+
async def test_endpoint(message: str = "Hello from proxy server!") -> dict[str, Any]:
202+
"""
203+
Test endpoint for debugging OAuth proxy functionality.
204+
205+
Args:
206+
message: Optional message to echo back
207+
208+
Returns:
209+
Test response with server information
210+
"""
211+
return {
212+
"message": message,
213+
"server": "Transparent OAuth Proxy Server",
214+
"status": "active",
215+
"oauth_configured": True
216+
}
217+
218+
219+
if __name__ == "__main__":
220+
mcp.run(transport="streamable-http")
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Test script to demonstrate the enhanced logging features of the transparent OAuth proxy.
4+
This script makes requests to various endpoints to show the comprehensive logging.
5+
"""
6+
7+
import time
8+
import requests
9+
import json
10+
import sys
11+
from threading import Thread
12+
import subprocess
13+
import os
14+
import signal
15+
16+
def test_endpoints():
17+
"""Test various endpoints to demonstrate logging."""
18+
base_url = "http://localhost:8000"
19+
20+
print("\n" + "="*60)
21+
print("🧪 TESTING ENHANCED LOGGING FEATURES")
22+
print("="*60)
23+
24+
# Wait for server to be ready
25+
print("⏳ Waiting for server to be ready...")
26+
time.sleep(3)
27+
28+
try:
29+
# Test 1: Metadata discovery
30+
print("\n🔍 Testing OAuth metadata discovery...")
31+
response = requests.get(f"{base_url}/.well-known/oauth-authorization-server",
32+
headers={"Host": "localhost:8000"})
33+
print(f" Status: {response.status_code}")
34+
35+
# Test 2: Client registration (DCR)
36+
print("\n📝 Testing Dynamic Client Registration...")
37+
registration_data = {
38+
"redirect_uris": ["http://localhost:3000/callback"],
39+
"grant_types": ["authorization_code"],
40+
"response_types": ["code"],
41+
"client_name": "Test MCP Client"
42+
}
43+
response = requests.post(f"{base_url}/register",
44+
json=registration_data,
45+
headers={"Content-Type": "application/json"})
46+
print(f" Status: {response.status_code}")
47+
48+
# Test 3: Authorization endpoint
49+
print("\n🔐 Testing authorization endpoint...")
50+
auth_params = {
51+
"response_type": "code",
52+
"client_id": "test-client-id",
53+
"redirect_uri": "http://localhost:3000/callback",
54+
"scope": "openid profile email",
55+
"state": "test-state-12345",
56+
"code_challenge": "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
57+
"code_challenge_method": "S256"
58+
}
59+
response = requests.get(f"{base_url}/authorize",
60+
params=auth_params,
61+
allow_redirects=False)
62+
print(f" Status: {response.status_code}")
63+
64+
# Test 4: Token endpoint (this will show the most comprehensive logging)
65+
print("\n🎫 Testing token endpoint (will fail but show logging)...")
66+
token_data = {
67+
"grant_type": "authorization_code",
68+
"client_id": "test-client-id",
69+
"client_secret": "test-client-secret",
70+
"code": "test_auth_code_abcdef123456",
71+
"redirect_uri": "http://localhost:3000/callback",
72+
"code_verifier": "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
73+
}
74+
response = requests.post(f"{base_url}/token",
75+
data=token_data,
76+
headers={"Content-Type": "application/x-www-form-urlencoded"})
77+
print(f" Status: {response.status_code}")
78+
79+
# Test 5: Another token request with different parameters
80+
print("\n🎫 Testing token endpoint with different parameters...")
81+
token_data2 = {
82+
"grant_type": "authorization_code",
83+
"client_id": "test-client-id",
84+
"code": "different_test_code_xyz789",
85+
"redirect_uri": "http://localhost:3000/callback"
86+
}
87+
response = requests.post(f"{base_url}/token",
88+
data=token_data2,
89+
headers={"Content-Type": "application/x-www-form-urlencoded"})
90+
print(f" Status: {response.status_code}")
91+
92+
print("\n✅ Test requests completed!")
93+
print("🔍 Check the server logs above to see the enhanced logging with:")
94+
print(" • Correlation IDs for request tracing")
95+
print(" • Detailed request/response headers and data")
96+
print(" • Timing information")
97+
print(" • Emoji indicators for easy scanning")
98+
print(" • Sensitive data redaction")
99+
100+
except requests.exceptions.ConnectionError:
101+
print("❌ Could not connect to server. Make sure it's running on port 8000.")
102+
except Exception as e:
103+
print(f"❌ Error during testing: {e}")
104+
105+
if __name__ == "__main__":
106+
test_endpoints()

src/mcp/server/auth/__init__.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,45 @@
1+
# pyright: reportGeneralTypeIssues=false
12
"""
23
MCP OAuth server authorization components.
34
"""
5+
6+
# Convenience re-exports so users can simply::
7+
#
8+
# from mcp.server.auth import build_proxy_server
9+
#
10+
# instead of digging into the sub-package path.
11+
12+
from typing import TYPE_CHECKING
13+
14+
from mcp.server.auth.proxy import (
15+
configure_colored_logging,
16+
create_proxy_routes,
17+
fetch_upstream_metadata,
18+
)
19+
20+
# For *build_proxy_server* we need a lazy import to avoid a circular reference
21+
# during the initial package import sequence (FastMCP -> auth -> proxy ->
22+
# FastMCP ...). PEP 562 allows us to implement module-level `__getattr__` for
23+
# this purpose.
24+
25+
def __getattr__(name: str): # noqa: D401
26+
if name == "build_proxy_server":
27+
from mcp.server.auth.proxy.server import build_proxy_server as _bps # noqa: WPS433
28+
29+
globals()["build_proxy_server"] = _bps
30+
return _bps
31+
raise AttributeError(name)
32+
33+
# ---------------------------------------------------------------------------
34+
# Public API specification
35+
# ---------------------------------------------------------------------------
36+
37+
__all__: list[str] = [
38+
"configure_colored_logging",
39+
"create_proxy_routes",
40+
"fetch_upstream_metadata",
41+
"build_proxy_server",
42+
]
43+
44+
if TYPE_CHECKING: # pragma: no cover – make *build_proxy_server* visible to type checkers
45+
from mcp.server.auth.proxy.server import build_proxy_server # noqa: F401

0 commit comments

Comments
 (0)