1
+ # pyright: reportMissingImports=false
2
+ import os
3
+ import logging
4
+ from dotenv import load_dotenv # type: ignore
5
+ from typing import Any , cast
6
+ import base64 , json , time
7
+ from starlette .requests import Request # type: ignore
8
+
9
+ from mcp .server .fastmcp .server import Context
10
+ from mcp .server .auth .proxy .server import build_proxy_server # noqa: E402
11
+ from mcp .server .auth .providers .transparent_proxy import ProxySettings # type: ignore
12
+
13
+ # Load environment variables from .env if present
14
+ load_dotenv ()
15
+
16
+ # Configure logging after .env so LOG_LEVEL can come from environment
17
+ LOG_LEVEL = os .getenv ("LOG_LEVEL" , "INFO" ).upper ()
18
+
19
+ logging .basicConfig (
20
+ level = LOG_LEVEL ,
21
+ format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s" ,
22
+ datefmt = "%Y-%m-%d %H:%M:%S" ,
23
+ )
24
+
25
+ # Dedicated logger for this server module
26
+ logger = logging .getLogger ("proxy_oauth.server" )
27
+
28
+ # Suppress noisy INFO messages from the FastMCP low-level server unless we are
29
+ # explicitly running in DEBUG mode. These logs (e.g. "Processing request of type
30
+ # ListToolsRequest") are helpful for debugging but clutter normal output.
31
+
32
+ _mcp_lowlevel_logger = logging .getLogger ("mcp.server.lowlevel.server" )
33
+ if LOG_LEVEL == "DEBUG" :
34
+ # In full debug mode, allow the library to emit its detailed logs
35
+ _mcp_lowlevel_logger .setLevel (logging .DEBUG )
36
+ else :
37
+ # Otherwise, only warnings and above
38
+ _mcp_lowlevel_logger .setLevel (logging .WARNING )
39
+
40
+ # ----------------------------------------------------------------------------
41
+ # Environment configuration
42
+ # ----------------------------------------------------------------------------
43
+ # Load and validate settings from the environment (uses .env automatically)
44
+ settings = ProxySettings .load ()
45
+
46
+ # Upstream endpoints (fully-qualified URLs)
47
+ UPSTREAM_AUTHORIZE : str = str (settings .upstream_authorize )
48
+ UPSTREAM_TOKEN : str = str (settings .upstream_token )
49
+ UPSTREAM_JWKS_URI = settings .jwks_uri
50
+ # Derive base URL from the authorize endpoint for convenience / tests
51
+ UPSTREAM_BASE : str = UPSTREAM_AUTHORIZE .rsplit ("/" , 1 )[0 ]
52
+
53
+ # Client credentials & defaults
54
+ CLIENT_ID : str = settings .client_id or "demo-client-id"
55
+ CLIENT_SECRET = settings .client_secret
56
+ DEFAULT_SCOPE : str = settings .default_scope
57
+
58
+ # Optional audience passthrough (not part of ProxySettings yet)
59
+ AUDIENCE = os .getenv ("PROXY_AUDIENCE" )
60
+
61
+ # Metadata URL (only used if we need to fetch from upstream)
62
+ UPSTREAM_METADATA = f"{ UPSTREAM_BASE } /.well-known/oauth-authorization-server"
63
+
64
+ # ---------------------------------------------------------------------------
65
+ # Logging helpers
66
+ # ---------------------------------------------------------------------------
67
+
68
+ def _mask_secret (secret : str | None ) -> str | None : # noqa: D401
69
+ """Return a masked version of the given secret.
70
+
71
+ The first and last four characters are preserved (if available) and the
72
+ middle section is replaced by asterisks. If the secret is shorter than
73
+ eight characters, the entire value is replaced by ``*``.
74
+ """
75
+
76
+ if not secret :
77
+ return None
78
+
79
+ if len (secret ) <= 8 :
80
+ return "*" * len (secret )
81
+
82
+ return f"{ secret [:4 ]} { '*' * (len (secret ) - 8 )} { secret [- 4 :]} "
83
+
84
+ # Consolidated configuration (with sensitive data redacted)
85
+ _masked_settings = settings .model_dump (exclude_none = True ).copy ()
86
+
87
+ if "client_secret" in _masked_settings :
88
+ _masked_settings ["client_secret" ] = _mask_secret (_masked_settings ["client_secret" ])
89
+
90
+ # Log configuration at *debug* level only so it can be enabled when needed
91
+ logger .debug ("[Proxy Config] %s" , _masked_settings )
92
+
93
+ # Server host/port
94
+ PROXY_PORT = int (os .getenv ("PROXY_PORT" , "8000" ))
95
+
96
+ # ----------------------------------------------------------------------------
97
+ # FastMCP server (now created via library helper)
98
+ # ----------------------------------------------------------------------------
99
+
100
+ ISSUER_URL = os .getenv ("PROXY_ISSUER_URL" , "http://localhost:8000" )
101
+
102
+ # Create FastMCP instance using the reusable proxy builder
103
+ mcp = build_proxy_server (port = PROXY_PORT , issuer_url = ISSUER_URL )
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # Minimal demo tool
107
+ # ---------------------------------------------------------------------------
108
+
109
+ @mcp .tool ()
110
+ def echo (message : str ) -> str :
111
+ return f"Echo: { message } "
112
+
113
+
114
+ @mcp .tool ()
115
+ async def user_info (ctx : Context [Any , Any , Request ]) -> dict [str , Any ]:
116
+ """
117
+ Get information about the authenticated user.
118
+
119
+ This tool demonstrates accessing user information from the OAuth access token.
120
+ The user must be authenticated via OAuth to access this tool.
121
+
122
+ Returns:
123
+ Dictionary containing user information from the access token
124
+ """
125
+ from mcp .server .auth .middleware .auth_context import get_access_token
126
+
127
+ # Get the access token from the authentication context
128
+ access_token = get_access_token ()
129
+
130
+ if not access_token :
131
+ return {
132
+ "error" : "No access token found - user not authenticated" ,
133
+ "authenticated" : False
134
+ }
135
+
136
+ # Attempt to decode the access token as JWT to extract useful user claims.
137
+ # Many OAuth providers issue JWT access tokens (or ID tokens) that contain
138
+ # the user's subject (sub) and preferred username. We parse the token
139
+ # *without* signature verification – we only need the public claims for
140
+ # display purposes. If the token is opaque or the decode fails, we simply
141
+ # skip this step.
142
+
143
+ def _try_decode_jwt (token_str : str ) -> dict [str , Any ] | None : # noqa: D401
144
+ """Best-effort JWT decode without verification.
145
+
146
+ Returns the payload dictionary if the token *looks* like a JWT and can
147
+ be base64-decoded. If anything fails we return None.
148
+ """
149
+
150
+ try :
151
+ parts = token_str .split ("." )
152
+ if len (parts ) != 3 :
153
+ return None # Not a JWT
154
+
155
+ # JWT parts are URL-safe base64 without padding
156
+ def _b64decode (segment : str ) -> bytes :
157
+ padding = "=" * (- len (segment ) % 4 )
158
+ return base64 .urlsafe_b64decode (segment + padding )
159
+
160
+ payload_bytes = _b64decode (parts [1 ])
161
+ return json .loads (payload_bytes )
162
+ except Exception : # noqa: BLE001
163
+ return None
164
+
165
+ jwt_claims = _try_decode_jwt (access_token .token )
166
+
167
+ # Build response with token information plus any extracted claims
168
+ response : dict [str , Any ] = {
169
+ "authenticated" : True ,
170
+ "client_id" : access_token .client_id ,
171
+ "scopes" : access_token .scopes ,
172
+ "token_type" : "Bearer" ,
173
+ "expires_at" : access_token .expires_at ,
174
+ "resource" : access_token .resource ,
175
+ }
176
+
177
+ if jwt_claims :
178
+ # Prefer the `userid` claim used in FastMCP examples; fall back to `sub` if absent.
179
+ uid = jwt_claims .get ("userid" ) or jwt_claims .get ("sub" )
180
+ if uid is not None :
181
+ response ["userid" ] = uid # camelCase variant used in FastMCP reference
182
+ response ["user_id" ] = uid # snake_case variant
183
+ response ["username" ] = (
184
+ jwt_claims .get ("preferred_username" )
185
+ or jwt_claims .get ("nickname" )
186
+ or jwt_claims .get ("name" )
187
+ )
188
+ response ["issuer" ] = jwt_claims .get ("iss" )
189
+ response ["audience" ] = jwt_claims .get ("aud" )
190
+ response ["issued_at" ] = jwt_claims .get ("iat" )
191
+
192
+ # Calculate expiration helpers
193
+ if access_token .expires_at :
194
+ response ["expires_at_iso" ] = time .strftime ('%Y-%m-%dT%H:%M:%S' , time .localtime (access_token .expires_at ))
195
+ response ["expires_in_seconds" ] = max (0 , access_token .expires_at - int (time .time ()))
196
+
197
+ return response
198
+
199
+
200
+ @mcp .tool ()
201
+ async def test_endpoint (message : str = "Hello from proxy server!" ) -> dict [str , Any ]:
202
+ """
203
+ Test endpoint for debugging OAuth proxy functionality.
204
+
205
+ Args:
206
+ message: Optional message to echo back
207
+
208
+ Returns:
209
+ Test response with server information
210
+ """
211
+ return {
212
+ "message" : message ,
213
+ "server" : "Transparent OAuth Proxy Server" ,
214
+ "status" : "active" ,
215
+ "oauth_configured" : True
216
+ }
217
+
218
+
219
+ if __name__ == "__main__" :
220
+ mcp .run (transport = "streamable-http" )
0 commit comments