diff --git a/fastapi_mcp/server.py b/fastapi_mcp/server.py index f5c4fc6..eea94ac 100644 --- a/fastapi_mcp/server.py +++ b/fastapi_mcp/server.py @@ -299,7 +299,9 @@ def mount( str, Doc( """ - Path where the MCP server will be mounted. Defaults to '/mcp'. + Path where the MCP server will be mounted. + Mount path is appended to the root path of FastAPI router, or to the prefix of APIRouter. + Defaults to '/mcp'. """ ), ] = "/mcp", @@ -328,14 +330,9 @@ def mount( router = self.fastapi # Build the base path correctly for the SSE transport - if isinstance(router, FastAPI): - base_path = router.root_path - elif isinstance(router, APIRouter): - base_path = self.fastapi.root_path + router.prefix - else: - raise ValueError(f"Invalid router type: {type(router)}") - - messages_path = f"{base_path}{mount_path}/messages/" + assert isinstance(router, (FastAPI, APIRouter)), f"Invalid router type: {type(router)}" + base_path = mount_path if isinstance(router, FastAPI) else router.prefix + mount_path + messages_path = f"{base_path}/messages/" sse_transport = FastApiSseTransport(messages_path) diff --git a/tests/fixtures/complex_app.py b/tests/fixtures/complex_app.py index 72ba14d..d248308 100644 --- a/tests/fixtures/complex_app.py +++ b/tests/fixtures/complex_app.py @@ -4,6 +4,8 @@ from fastapi import FastAPI, Query, Path, Body, Header, Cookie import pytest +from tests.fixtures.conftest import make_fastapi_app_base + from .types import ( Product, Customer, @@ -19,12 +21,9 @@ def make_complex_fastapi_app( example_product: Product, example_customer: Customer, example_order_response: OrderResponse, + parametrized_config: dict[str, Any] | None = None, ) -> FastAPI: - app = FastAPI( - title="Complex E-Commerce API", - description="A more complex API with nested models and various schemas", - version="1.0.0", - ) + app = make_fastapi_app_base(parametrized_config=parametrized_config) @app.get( "/products", diff --git a/tests/fixtures/conftest.py b/tests/fixtures/conftest.py new file mode 100644 index 0000000..ca0635a --- /dev/null +++ b/tests/fixtures/conftest.py @@ -0,0 +1,12 @@ +from typing import Any +from fastapi import FastAPI + + +def make_fastapi_app_base(parametrized_config: dict[str, Any] | None = None) -> FastAPI: + fastapi_config: dict[str, Any] = { + "title": "Test API", + "description": "A test API app for unit testing", + "version": "0.1.0", + } + app = FastAPI(**fastapi_config | parametrized_config if parametrized_config is not None else {}) + return app diff --git a/tests/fixtures/simple_app.py b/tests/fixtures/simple_app.py index 2b21872..5d8298a 100644 --- a/tests/fixtures/simple_app.py +++ b/tests/fixtures/simple_app.py @@ -1,18 +1,15 @@ -from typing import Optional, List +from typing import Optional, List, Any from fastapi import FastAPI, Query, Path, Body, HTTPException import pytest -from .types import Item +from tests.fixtures.conftest import make_fastapi_app_base +from .types import Item -def make_simple_fastapi_app() -> FastAPI: - app = FastAPI( - title="Test API", - description="A test API app for unit testing", - version="0.1.0", - ) +def make_simple_fastapi_app(parametrized_config: dict[str, Any] | None = None) -> FastAPI: + app = make_fastapi_app_base(parametrized_config=parametrized_config) items = [ Item(id=1, name="Item 1", price=10.0, tags=["tag1", "tag2"], description="Item 1 description"), Item(id=2, name="Item 2", price=20.0, tags=["tag2", "tag3"]), @@ -70,3 +67,8 @@ async def raise_error() -> None: @pytest.fixture def simple_fastapi_app() -> FastAPI: return make_simple_fastapi_app() + + +@pytest.fixture +def simple_fastapi_app_with_root_path() -> FastAPI: + return make_simple_fastapi_app(parametrized_config={"root_path": "/api/v1"}) diff --git a/tests/test_sse_real_transport.py b/tests/test_sse_real_transport.py index 1ac307c..408e117 100644 --- a/tests/test_sse_real_transport.py +++ b/tests/test_sse_real_transport.py @@ -9,6 +9,7 @@ import threading import coverage from typing import AsyncGenerator, Generator +from fastapi import FastAPI from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp import InitializeResult @@ -18,26 +19,12 @@ import uvicorn from fastapi_mcp import FastApiMCP -from .fixtures.simple_app import make_simple_fastapi_app - HOST = "127.0.0.1" SERVER_NAME = "Test MCP Server" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind((HOST, 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://{HOST}:{server_port}" - - -def run_server(server_port: int) -> None: +def run_server(server_port: int, fastapi_app: FastAPI) -> None: # Initialize coverage for subprocesses cov = None if "COVERAGE_PROCESS_START" in os.environ: @@ -72,16 +59,15 @@ def periodic_save(): save_thread.start() # Configure the server - fastapi = make_simple_fastapi_app() mcp = FastApiMCP( - fastapi, + fastapi_app, name=SERVER_NAME, description="Test description", ) mcp.mount() # Start the server - server = uvicorn.Server(config=uvicorn.Config(app=fastapi, host=HOST, port=server_port, log_level="error")) + server = uvicorn.Server(config=uvicorn.Config(app=fastapi_app, host=HOST, port=server_port, log_level="error")) server.run() # Give server time to start @@ -94,13 +80,24 @@ def periodic_save(): cov.save() -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: +@pytest.fixture(params=["simple_fastapi_app", "simple_fastapi_app_with_root_path"]) +def server(request: pytest.FixtureRequest) -> Generator[str, None, None]: # Ensure COVERAGE_PROCESS_START is set in the environment for subprocesses coverage_rc = os.path.abspath(".coveragerc") os.environ["COVERAGE_PROCESS_START"] = coverage_rc - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) + # Get a free port + with socket.socket() as s: + s.bind((HOST, 0)) + server_port = s.getsockname()[1] + + # Run the server in a subprocess + fastapi_app = request.getfixturevalue(request.param) + proc = multiprocessing.Process( + target=run_server, + kwargs={"server_port": server_port, "fastapi_app": fastapi_app}, + daemon=True, + ) proc.start() # Wait for server to be running @@ -117,7 +114,8 @@ def server(server_port: int) -> Generator[None, None, None]: else: raise RuntimeError(f"Server failed to start after {max_attempts} attempts") - yield + # Return the server URL + yield f"http://{HOST}:{server_port}{fastapi_app.root_path}" # Signal the server to stop - added graceful shutdown before kill try: @@ -134,8 +132,8 @@ def server(server_port: int) -> Generator[None, None, None]: @pytest.fixture() -async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: - async with httpx.AsyncClient(base_url=server_url) as client: +async def http_client(server: str) -> AsyncGenerator[httpx.AsyncClient, None]: + async with httpx.AsyncClient(base_url=server) as client: yield client @@ -165,8 +163,8 @@ async def connection_test() -> None: @pytest.mark.anyio -async def test_sse_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/mcp") as streams: +async def test_sse_basic_connection(server: str) -> None: + async with sse_client(server + "/mcp") as streams: async with ClientSession(*streams) as session: # Test initialization result = await session.initialize() @@ -179,8 +177,8 @@ async def test_sse_basic_connection(server: None, server_url: str) -> None: @pytest.mark.anyio -async def test_sse_tool_call(server: None, server_url: str) -> None: - async with sse_client(server_url + "/mcp") as streams: +async def test_sse_tool_call(server: str) -> None: + async with sse_client(server + "/mcp") as streams: async with ClientSession(*streams) as session: await session.initialize()