|
| 1 | +# pyright: reportMissingImports=false |
| 2 | +# pytest test suite for examples/servers/proxy_oauth/server.py |
| 3 | +# These tests spin up the FastMCP Starlette application in-process and |
| 4 | +# exercise the custom HTTP routes as well as the `user_info` tool. |
| 5 | + |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +import base64 |
| 9 | +import json |
| 10 | +import urllib.parse |
| 11 | +from collections.abc import AsyncGenerator |
| 12 | +from typing import Any |
| 13 | + |
| 14 | +import httpx # type: ignore |
| 15 | +import pytest # type: ignore |
| 16 | + |
| 17 | + |
| 18 | +@pytest.fixture |
| 19 | +def proxy_server(monkeypatch): |
| 20 | + """Import the proxy OAuth demo server with safe environment + stubs.""" |
| 21 | + import os |
| 22 | + |
| 23 | + # Avoid real outbound calls by pretending the upstream endpoints were |
| 24 | + # supplied explicitly via env vars – this makes `fetch_upstream_metadata` |
| 25 | + # construct metadata locally instead of performing an HTTP GET. |
| 26 | + os.environ.setdefault("UPSTREAM_AUTHORIZATION_ENDPOINT", "https://upstream.example.com/authorize") |
| 27 | + os.environ.setdefault("UPSTREAM_TOKEN_ENDPOINT", "https://upstream.example.com/token") |
| 28 | + os.environ.setdefault("UPSTREAM_JWKS_URI", "https://upstream.example.com/jwks") |
| 29 | + os.environ.setdefault("UPSTREAM_CLIENT_ID", "client123") |
| 30 | + os.environ.setdefault("UPSTREAM_CLIENT_SECRET", "secret123") |
| 31 | + |
| 32 | + # Deferred import so the env vars above are in effect. |
| 33 | + from examples.servers.proxy_oauth import server as proxy_server_module |
| 34 | + |
| 35 | + # Stub library-level fetch_upstream_metadata to avoid network I/O. |
| 36 | + from mcp.server.auth.proxy import routes as proxy_routes |
| 37 | + |
| 38 | + async def _fake_metadata() -> dict[str, Any]: # noqa: D401 |
| 39 | + return { |
| 40 | + "issuer": proxy_server_module.UPSTREAM_BASE, |
| 41 | + "authorization_endpoint": proxy_server_module.UPSTREAM_AUTHORIZE, |
| 42 | + "token_endpoint": proxy_server_module.UPSTREAM_TOKEN, |
| 43 | + "registration_endpoint": "/register", |
| 44 | + "jwks_uri": "", |
| 45 | + } |
| 46 | + |
| 47 | + monkeypatch.setattr(proxy_routes, "fetch_upstream_metadata", _fake_metadata, raising=True) |
| 48 | + return proxy_server_module |
| 49 | + |
| 50 | + |
| 51 | +@pytest.fixture |
| 52 | +def app(proxy_server): |
| 53 | + """Return the Starlette ASGI app for tests.""" |
| 54 | + return proxy_server.mcp.streamable_http_app() |
| 55 | + |
| 56 | + |
| 57 | +@pytest.fixture |
| 58 | +async def client(app) -> AsyncGenerator[httpx.AsyncClient, None]: |
| 59 | + """Async HTTP client bound to the in-memory ASGI application.""" |
| 60 | + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as c: |
| 61 | + yield c |
| 62 | + |
| 63 | + |
| 64 | +# --------------------------------------------------------------------------- |
| 65 | +# HTTP endpoint tests |
| 66 | +# --------------------------------------------------------------------------- |
| 67 | + |
| 68 | + |
| 69 | +@pytest.mark.anyio |
| 70 | +async def test_metadata_endpoint(client): |
| 71 | + r = await client.get("/.well-known/oauth-authorization-server") |
| 72 | + assert r.status_code == 200 |
| 73 | + data = r.json() |
| 74 | + assert "issuer" in data |
| 75 | + assert data["authorization_endpoint"].endswith("/authorize") |
| 76 | + assert data["token_endpoint"].endswith("/token") |
| 77 | + assert data["registration_endpoint"].endswith("/register") |
| 78 | + |
| 79 | + |
| 80 | +@pytest.mark.anyio |
| 81 | +async def test_registration_endpoint(client, proxy_server): |
| 82 | + payload = {"redirect_uris": ["https://client.example.com/callback"]} |
| 83 | + r = await client.post("/register", json=payload) |
| 84 | + assert r.status_code == 201 |
| 85 | + body = r.json() |
| 86 | + assert body["client_id"] == proxy_server.CLIENT_ID |
| 87 | + assert body["redirect_uris"] == payload["redirect_uris"] |
| 88 | + # client_secret may be None, but the field should exist (masked or real) |
| 89 | + assert "client_secret" in body |
| 90 | + |
| 91 | + |
| 92 | +@pytest.mark.anyio |
| 93 | +async def test_authorize_redirect(client, proxy_server): |
| 94 | + params = { |
| 95 | + "response_type": "code", |
| 96 | + "state": "xyz", |
| 97 | + "redirect_uri": "https://client.example.com/callback", |
| 98 | + "client_id": proxy_server.CLIENT_ID, |
| 99 | + "code_challenge": "testchallenge", |
| 100 | + "code_challenge_method": "S256", |
| 101 | + } |
| 102 | + r = await client.get("/authorize", params=params, follow_redirects=False) |
| 103 | + assert r.status_code in {302, 307} |
| 104 | + |
| 105 | + location = r.headers["location"] |
| 106 | + parsed = urllib.parse.urlparse(location) |
| 107 | + assert parsed.scheme.startswith("http") |
| 108 | + assert parsed.netloc == urllib.parse.urlparse(proxy_server.UPSTREAM_AUTHORIZE).netloc |
| 109 | + |
| 110 | + qs = urllib.parse.parse_qs(parsed.query) |
| 111 | + # Proxy should inject client_id & default scope |
| 112 | + assert qs["client_id"][0] == proxy_server.CLIENT_ID |
| 113 | + assert "scope" in qs |
| 114 | + # Original params preserved |
| 115 | + assert qs["state"][0] == "xyz" |
| 116 | + |
| 117 | + |
| 118 | +@pytest.mark.anyio |
| 119 | +async def test_revoke_proxy(client, monkeypatch, proxy_server): |
| 120 | + original_post = httpx.AsyncClient.post |
| 121 | + |
| 122 | + async def _mock_post(self, url, data=None, timeout=10, **kwargs): # noqa: D401 |
| 123 | + if url.endswith("/revoke"): |
| 124 | + return httpx.Response(200, json={"revoked": True}) |
| 125 | + # For the test client's own request to /revoke, delegate to original implementation |
| 126 | + return await original_post(self, url, data=data, timeout=timeout, **kwargs) |
| 127 | + |
| 128 | + monkeypatch.setattr(httpx.AsyncClient, "post", _mock_post, raising=True) |
| 129 | + |
| 130 | + r = await client.post("/revoke", data={"token": "dummy"}) |
| 131 | + assert r.status_code == 200 |
| 132 | + assert r.json() == {"revoked": True} |
| 133 | + |
| 134 | + |
| 135 | +@pytest.mark.anyio |
| 136 | +async def test_token_passthrough(client, monkeypatch, proxy_server): |
| 137 | + """Ensure /token is proxied unchanged and response is returned verbatim.""" |
| 138 | + |
| 139 | + # Capture outgoing POSTs made by ProxyTokenHandler |
| 140 | + captured: dict[str, Any] = {} |
| 141 | + |
| 142 | + original_post = httpx.AsyncClient.post |
| 143 | + |
| 144 | + async def _mock_post(self, url, *args, **kwargs): # noqa: D401 |
| 145 | + if str(url).startswith(proxy_server.UPSTREAM_TOKEN): |
| 146 | + # Record exactly what was sent upstream |
| 147 | + captured["url"] = str(url) |
| 148 | + captured["data"] = kwargs.get("data") |
| 149 | + # Return a dummy upstream response |
| 150 | + return httpx.Response( |
| 151 | + 200, |
| 152 | + json={ |
| 153 | + "access_token": "xyz", |
| 154 | + "token_type": "bearer", |
| 155 | + "expires_in": 3600, |
| 156 | + }, |
| 157 | + ) |
| 158 | + # Delegate any other POSTs to the real implementation |
| 159 | + return await original_post(self, url, *args, **kwargs) |
| 160 | + |
| 161 | + monkeypatch.setattr(httpx.AsyncClient, "post", _mock_post, raising=True) |
| 162 | + |
| 163 | + # ---------------- Act ---------------- |
| 164 | + form = { |
| 165 | + "grant_type": "authorization_code", |
| 166 | + "code": "dummy-code", |
| 167 | + "client_id": proxy_server.CLIENT_ID, |
| 168 | + } |
| 169 | + r = await client.post("/token", data=form) |
| 170 | + |
| 171 | + # ---------------- Assert ------------- |
| 172 | + assert r.status_code == 200 |
| 173 | + assert r.json()["access_token"] == "xyz" |
| 174 | + |
| 175 | + # Verify the request payload was forwarded without modification |
| 176 | + assert captured["data"] == form |
| 177 | + |
| 178 | + |
| 179 | +# --------------------------------------------------------------------------- |
| 180 | +# Tool invocation – user_info |
| 181 | +# --------------------------------------------------------------------------- |
| 182 | + |
| 183 | + |
| 184 | +@pytest.mark.anyio |
| 185 | +async def test_user_info_tool(monkeypatch, proxy_server): |
| 186 | + """Call the `user_info` tool directly with a mocked access token.""" |
| 187 | + # Craft a dummy JWT with useful claims (header/payload/signature parts) |
| 188 | + payload = ( |
| 189 | + base64.urlsafe_b64encode( |
| 190 | + json.dumps( |
| 191 | + { |
| 192 | + "sub": "test-user", |
| 193 | + "preferred_username": "tester", |
| 194 | + } |
| 195 | + ).encode() |
| 196 | + ) |
| 197 | + .decode() |
| 198 | + .rstrip("=") |
| 199 | + ) |
| 200 | + dummy_token = f"header.{payload}.signature" |
| 201 | + |
| 202 | + from mcp.server.auth.middleware import auth_context |
| 203 | + from mcp.server.auth.provider import AccessToken # local import to avoid cycles |
| 204 | + |
| 205 | + def _fake_get_access_token(): # noqa: D401 |
| 206 | + return AccessToken(token=dummy_token, client_id="client123", scopes=["openid"], expires_at=None) |
| 207 | + |
| 208 | + monkeypatch.setattr(auth_context, "get_access_token", _fake_get_access_token, raising=True) |
| 209 | + |
| 210 | + result = await proxy_server.mcp.call_tool("user_info", {}) |
| 211 | + |
| 212 | + # call_tool returns (content_blocks, raw_result) |
| 213 | + if isinstance(result, tuple): |
| 214 | + _, raw = result |
| 215 | + else: |
| 216 | + raw = result # fallback |
| 217 | + |
| 218 | + assert raw["authenticated"] is True |
| 219 | + assert ("userid" in raw and raw["userid"] == "test-user") or ("user_id" in raw and raw["user_id"] == "test-user") |
| 220 | + assert raw["username"] == "tester" |
0 commit comments