Skip to content

Commit 38cd9e9

Browse files
committed
Add tests for proxy builder and OAuth endpoints
Signed-off-by: Jesse Sanford <108698+jessesanford@users.noreply.github.com>
1 parent 75426ff commit 38cd9e9

File tree

2 files changed

+269
-0
lines changed

2 files changed

+269
-0
lines changed

tests/test_proxy_builder.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# pyright: reportMissingImports=false, reportGeneralTypeIssues=false
2+
"""Tests for the build_proxy_server convenience helper."""
3+
4+
from __future__ import annotations
5+
6+
from typing import cast
7+
8+
import httpx # type: ignore
9+
import pytest # type: ignore
10+
from pydantic import AnyHttpUrl
11+
12+
from mcp.server.auth.providers.transparent_proxy import _Settings as ProxySettings
13+
from mcp.server.auth.proxy import routes as proxy_routes
14+
from mcp.server.auth.proxy.server import build_proxy_server
15+
16+
17+
@pytest.mark.anyio
18+
async def test_build_proxy_server_metadata(monkeypatch):
19+
"""Ensure the server starts and serves metadata without touching network."""
20+
21+
# Patch metadata fetcher so no real HTTP traffic occurs
22+
async def _fake_metadata(): # noqa: D401
23+
return {
24+
"issuer": "https://proxy.test",
25+
"authorization_endpoint": "https://proxy.test/authorize",
26+
"token_endpoint": "https://proxy.test/token",
27+
"registration_endpoint": "/register",
28+
}
29+
30+
monkeypatch.setattr(proxy_routes, "fetch_upstream_metadata", _fake_metadata, raising=True)
31+
32+
# Provide required upstream endpoints via settings object
33+
settings = ProxySettings( # type: ignore[call-arg]
34+
UPSTREAM_AUTHORIZATION_ENDPOINT=cast(AnyHttpUrl, "https://upstream.example.com/authorize"),
35+
UPSTREAM_TOKEN_ENDPOINT=cast(AnyHttpUrl, "https://upstream.example.com/token"),
36+
UPSTREAM_CLIENT_ID="demo-client-id",
37+
UPSTREAM_CLIENT_SECRET=None,
38+
UPSTREAM_JWKS_URI=None,
39+
)
40+
41+
mcp = build_proxy_server(port=0, settings=settings)
42+
43+
app = mcp.streamable_http_app()
44+
45+
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as c:
46+
r = await c.get("/.well-known/oauth-authorization-server")
47+
assert r.status_code == 200
48+
data = r.json()
49+
assert data["authorization_endpoint"].endswith("/authorize")

tests/test_proxy_oauth_endpoints.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# pyright: reportMissingImports=false
2+
# pytest test suite for examples/servers/proxy_oauth/server.py
3+
# These tests spin up the FastMCP Starlette application in-process and
4+
# exercise the custom HTTP routes as well as the `user_info` tool.
5+
6+
from __future__ import annotations
7+
8+
import base64
9+
import json
10+
import urllib.parse
11+
from collections.abc import AsyncGenerator
12+
from typing import Any
13+
14+
import httpx # type: ignore
15+
import pytest # type: ignore
16+
17+
18+
@pytest.fixture
19+
def proxy_server(monkeypatch):
20+
"""Import the proxy OAuth demo server with safe environment + stubs."""
21+
import os
22+
23+
# Avoid real outbound calls by pretending the upstream endpoints were
24+
# supplied explicitly via env vars – this makes `fetch_upstream_metadata`
25+
# construct metadata locally instead of performing an HTTP GET.
26+
os.environ.setdefault("UPSTREAM_AUTHORIZATION_ENDPOINT", "https://upstream.example.com/authorize")
27+
os.environ.setdefault("UPSTREAM_TOKEN_ENDPOINT", "https://upstream.example.com/token")
28+
os.environ.setdefault("UPSTREAM_JWKS_URI", "https://upstream.example.com/jwks")
29+
os.environ.setdefault("UPSTREAM_CLIENT_ID", "client123")
30+
os.environ.setdefault("UPSTREAM_CLIENT_SECRET", "secret123")
31+
32+
# Deferred import so the env vars above are in effect.
33+
from examples.servers.proxy_oauth import server as proxy_server_module
34+
35+
# Stub library-level fetch_upstream_metadata to avoid network I/O.
36+
from mcp.server.auth.proxy import routes as proxy_routes
37+
38+
async def _fake_metadata() -> dict[str, Any]: # noqa: D401
39+
return {
40+
"issuer": proxy_server_module.UPSTREAM_BASE,
41+
"authorization_endpoint": proxy_server_module.UPSTREAM_AUTHORIZE,
42+
"token_endpoint": proxy_server_module.UPSTREAM_TOKEN,
43+
"registration_endpoint": "/register",
44+
"jwks_uri": "",
45+
}
46+
47+
monkeypatch.setattr(proxy_routes, "fetch_upstream_metadata", _fake_metadata, raising=True)
48+
return proxy_server_module
49+
50+
51+
@pytest.fixture
52+
def app(proxy_server):
53+
"""Return the Starlette ASGI app for tests."""
54+
return proxy_server.mcp.streamable_http_app()
55+
56+
57+
@pytest.fixture
58+
async def client(app) -> AsyncGenerator[httpx.AsyncClient, None]:
59+
"""Async HTTP client bound to the in-memory ASGI application."""
60+
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as c:
61+
yield c
62+
63+
64+
# ---------------------------------------------------------------------------
65+
# HTTP endpoint tests
66+
# ---------------------------------------------------------------------------
67+
68+
69+
@pytest.mark.anyio
70+
async def test_metadata_endpoint(client):
71+
r = await client.get("/.well-known/oauth-authorization-server")
72+
assert r.status_code == 200
73+
data = r.json()
74+
assert "issuer" in data
75+
assert data["authorization_endpoint"].endswith("/authorize")
76+
assert data["token_endpoint"].endswith("/token")
77+
assert data["registration_endpoint"].endswith("/register")
78+
79+
80+
@pytest.mark.anyio
81+
async def test_registration_endpoint(client, proxy_server):
82+
payload = {"redirect_uris": ["https://client.example.com/callback"]}
83+
r = await client.post("/register", json=payload)
84+
assert r.status_code == 201
85+
body = r.json()
86+
assert body["client_id"] == proxy_server.CLIENT_ID
87+
assert body["redirect_uris"] == payload["redirect_uris"]
88+
# client_secret may be None, but the field should exist (masked or real)
89+
assert "client_secret" in body
90+
91+
92+
@pytest.mark.anyio
93+
async def test_authorize_redirect(client, proxy_server):
94+
params = {
95+
"response_type": "code",
96+
"state": "xyz",
97+
"redirect_uri": "https://client.example.com/callback",
98+
"client_id": proxy_server.CLIENT_ID,
99+
"code_challenge": "testchallenge",
100+
"code_challenge_method": "S256",
101+
}
102+
r = await client.get("/authorize", params=params, follow_redirects=False)
103+
assert r.status_code in {302, 307}
104+
105+
location = r.headers["location"]
106+
parsed = urllib.parse.urlparse(location)
107+
assert parsed.scheme.startswith("http")
108+
assert parsed.netloc == urllib.parse.urlparse(proxy_server.UPSTREAM_AUTHORIZE).netloc
109+
110+
qs = urllib.parse.parse_qs(parsed.query)
111+
# Proxy should inject client_id & default scope
112+
assert qs["client_id"][0] == proxy_server.CLIENT_ID
113+
assert "scope" in qs
114+
# Original params preserved
115+
assert qs["state"][0] == "xyz"
116+
117+
118+
@pytest.mark.anyio
119+
async def test_revoke_proxy(client, monkeypatch, proxy_server):
120+
original_post = httpx.AsyncClient.post
121+
122+
async def _mock_post(self, url, data=None, timeout=10, **kwargs): # noqa: D401
123+
if url.endswith("/revoke"):
124+
return httpx.Response(200, json={"revoked": True})
125+
# For the test client's own request to /revoke, delegate to original implementation
126+
return await original_post(self, url, data=data, timeout=timeout, **kwargs)
127+
128+
monkeypatch.setattr(httpx.AsyncClient, "post", _mock_post, raising=True)
129+
130+
r = await client.post("/revoke", data={"token": "dummy"})
131+
assert r.status_code == 200
132+
assert r.json() == {"revoked": True}
133+
134+
135+
@pytest.mark.anyio
136+
async def test_token_passthrough(client, monkeypatch, proxy_server):
137+
"""Ensure /token is proxied unchanged and response is returned verbatim."""
138+
139+
# Capture outgoing POSTs made by ProxyTokenHandler
140+
captured: dict[str, Any] = {}
141+
142+
original_post = httpx.AsyncClient.post
143+
144+
async def _mock_post(self, url, *args, **kwargs): # noqa: D401
145+
if str(url).startswith(proxy_server.UPSTREAM_TOKEN):
146+
# Record exactly what was sent upstream
147+
captured["url"] = str(url)
148+
captured["data"] = kwargs.get("data")
149+
# Return a dummy upstream response
150+
return httpx.Response(
151+
200,
152+
json={
153+
"access_token": "xyz",
154+
"token_type": "bearer",
155+
"expires_in": 3600,
156+
},
157+
)
158+
# Delegate any other POSTs to the real implementation
159+
return await original_post(self, url, *args, **kwargs)
160+
161+
monkeypatch.setattr(httpx.AsyncClient, "post", _mock_post, raising=True)
162+
163+
# ---------------- Act ----------------
164+
form = {
165+
"grant_type": "authorization_code",
166+
"code": "dummy-code",
167+
"client_id": proxy_server.CLIENT_ID,
168+
}
169+
r = await client.post("/token", data=form)
170+
171+
# ---------------- Assert -------------
172+
assert r.status_code == 200
173+
assert r.json()["access_token"] == "xyz"
174+
175+
# Verify the request payload was forwarded without modification
176+
assert captured["data"] == form
177+
178+
179+
# ---------------------------------------------------------------------------
180+
# Tool invocation – user_info
181+
# ---------------------------------------------------------------------------
182+
183+
184+
@pytest.mark.anyio
185+
async def test_user_info_tool(monkeypatch, proxy_server):
186+
"""Call the `user_info` tool directly with a mocked access token."""
187+
# Craft a dummy JWT with useful claims (header/payload/signature parts)
188+
payload = (
189+
base64.urlsafe_b64encode(
190+
json.dumps(
191+
{
192+
"sub": "test-user",
193+
"preferred_username": "tester",
194+
}
195+
).encode()
196+
)
197+
.decode()
198+
.rstrip("=")
199+
)
200+
dummy_token = f"header.{payload}.signature"
201+
202+
from mcp.server.auth.middleware import auth_context
203+
from mcp.server.auth.provider import AccessToken # local import to avoid cycles
204+
205+
def _fake_get_access_token(): # noqa: D401
206+
return AccessToken(token=dummy_token, client_id="client123", scopes=["openid"], expires_at=None)
207+
208+
monkeypatch.setattr(auth_context, "get_access_token", _fake_get_access_token, raising=True)
209+
210+
result = await proxy_server.mcp.call_tool("user_info", {})
211+
212+
# call_tool returns (content_blocks, raw_result)
213+
if isinstance(result, tuple):
214+
_, raw = result
215+
else:
216+
raw = result # fallback
217+
218+
assert raw["authenticated"] is True
219+
assert ("userid" in raw and raw["userid"] == "test-user") or ("user_id" in raw and raw["user_id"] == "test-user")
220+
assert raw["username"] == "tester"

0 commit comments

Comments
 (0)