Skip to content

feat: Add support of WebSocket transport #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions examples/10_websocket_transport_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Example demonstrating WebSocket transport for FastAPI-MCP.

This example shows how to mount an MCP server with WebSocket transport,
which enables bidirectional real-time communication between MCP clients and your FastAPI application.
"""

from fastapi import FastAPI
from fastapi_mcp import FastApiMCP

# Create a simple FastAPI app with some endpoints
app = FastAPI(title="WebSocket MCP Example", description="Example using WebSocket transport for MCP")


# Define some sample endpoints that will become MCP tools
@app.get("/hello")
async def get_hello(name: str = "World"):
"""Say hello to someone."""
return {"message": f"Hello, {name}!"}


@app.post("/messages")
async def create_message(content: str, author: str = "Anonymous"):
"""Create a new message."""
return {"id": 1, "content": content, "author": author, "timestamp": "2025-01-01T12:00:00Z"}


@app.get("/status")
async def get_status():
"""Get the application status."""
return {"status": "healthy", "transport": "websocket"}


# Create MCP server with WebSocket transport
mcp = FastApiMCP(
app,
name="WebSocket MCP Example",
description="An example MCP server using WebSocket transport for real-time communication",
)

# Mount the MCP server with WebSocket transport
mcp.mount(transport="websocket")

if __name__ == "__main__":
import uvicorn

print("Starting FastAPI server with WebSocket MCP transport...")
print("WebSocket MCP endpoint available at: ws://127.0.0.1:8000/mcp")
print("You can connect MCP clients to this WebSocket endpoint.")

uvicorn.run(app, host="127.0.0.1", port=8000)
78 changes: 78 additions & 0 deletions examples/11_multi_transport_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
Example demonstrating both SSE and WebSocket transports for FastAPI-MCP.

This example shows how to mount MCP servers with different transports on the same FastAPI application.
SSE is good for simple request-response patterns, while WebSocket provides real-time bidirectional communication.
"""

from fastapi import FastAPI
from fastapi_mcp import FastApiMCP

# Create a FastAPI app
app = FastAPI(
title="Multi-Transport MCP Example", description="Example showing both SSE and WebSocket transports for MCP"
)


# Define some sample endpoints
@app.get("/echo")
async def echo_message(message: str):
"""Echo back a message."""
return {"echo": message, "length": len(message)}


@app.get("/calculate")
async def calculate(operation: str, a: float, b: float):
"""Perform basic arithmetic operations."""
if operation == "add":
result = a + b
elif operation == "subtract":
result = a - b
elif operation == "multiply":
result = a * b
elif operation == "divide":
if b == 0:
raise ValueError("Division by zero is not allowed")
result = a / b
else:
raise ValueError(f"Unsupported operation: {operation}")

return {"operation": operation, "a": a, "b": b, "result": result}


@app.get("/time")
async def get_current_time():
"""Get current timestamp."""
from datetime import datetime

return {"timestamp": datetime.now().isoformat()}


# Create MCP server instances for different transports
sse_mcp = FastApiMCP(
app,
name="SSE MCP Server",
description="MCP server using Server-Sent Events transport",
)

websocket_mcp = FastApiMCP(
app,
name="WebSocket MCP Server",
description="MCP server using WebSocket transport",
)

# Mount both transports on different paths
sse_mcp.mount(mount_path="/mcp-sse", transport="sse")
websocket_mcp.mount(mount_path="/mcp-ws", transport="websocket")

if __name__ == "__main__":
import uvicorn

print("Starting FastAPI server with both SSE and WebSocket MCP transports...")
print("SSE MCP endpoint available at: http://127.0.0.1:8000/mcp-sse")
print("WebSocket MCP endpoint available at: ws://127.0.0.1:8000/mcp-ws")
print()
print("Use SSE for simple request-response patterns.")
print("Use WebSocket for real-time bidirectional communication.")

uvicorn.run(app, host="127.0.0.1", port=8000)
45 changes: 37 additions & 8 deletions fastapi_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from typing import Dict, Optional, Any, List, Union, Callable, Awaitable, Iterable, Literal, Sequence
from typing_extensions import Annotated, Doc

from fastapi import FastAPI, Request, APIRouter, params
from fastapi import FastAPI, Request, APIRouter, params, WebSocket
from fastapi.openapi.utils import get_openapi
from mcp.server.lowlevel.server import Server
import mcp.types as types

from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools
from fastapi_mcp.transport.sse import FastApiSseTransport
from fastapi_mcp.transport.websocket import FastApiWebSocketTransport
from fastapi_mcp.types import HTTPRequestInfo, AuthConfig

import logging
Expand Down Expand Up @@ -230,6 +231,32 @@ def _register_mcp_endpoints_sse(
self._register_mcp_connection_endpoint_sse(router, transport, mount_path, dependencies)
self._register_mcp_messages_endpoint_sse(router, transport, mount_path, dependencies)

def _register_mcp_connection_endpoint_websocket(
self,
router: FastAPI | APIRouter,
transport: FastApiWebSocketTransport,
mount_path: str,
dependencies: Optional[Sequence[params.Depends]],
):
@router.websocket(mount_path, dependencies=dependencies)
async def handle_mcp_websocket(websocket: WebSocket):
async with transport.connect_websocket(websocket) as (reader, writer):
await self.server.run(
reader,
writer,
self.server.create_initialization_options(notification_options=None, experimental_capabilities={}),
raise_exceptions=False,
)

def _register_mcp_endpoints_websocket(
self,
router: FastAPI | APIRouter,
transport: FastApiWebSocketTransport,
mount_path: str,
dependencies: Optional[Sequence[params.Depends]],
):
self._register_mcp_connection_endpoint_websocket(router, transport, mount_path, dependencies)

def _setup_auth_2025_03_26(self):
from fastapi_mcp.auth.proxy import (
setup_oauth_custom_metadata,
Expand Down Expand Up @@ -304,10 +331,10 @@ def mount(
),
] = "/mcp",
transport: Annotated[
Literal["sse"],
Literal["sse", "websocket"],
Doc(
"""
The transport type for the MCP server. Currently only 'sse' is supported.
The transport type for the MCP server. Supports 'sse' and 'websocket'.
"""
),
] = "sse",
Expand Down Expand Up @@ -335,15 +362,17 @@ def mount(
else:
raise ValueError(f"Invalid router type: {type(router)}")

messages_path = f"{base_path}{mount_path}/messages/"

sse_transport = FastApiSseTransport(messages_path)

# Create transport based on the specified type
dependencies = self._auth_config.dependencies if self._auth_config else None

if transport == "sse":
messages_path = f"{base_path}{mount_path}/messages/"
sse_transport = FastApiSseTransport(messages_path)
self._register_mcp_endpoints_sse(router, sse_transport, mount_path, dependencies)
else: # pragma: no cover
elif transport == "websocket":
websocket_transport = FastApiWebSocketTransport()
self._register_mcp_endpoints_websocket(router, websocket_transport, mount_path, dependencies)
else:
raise ValueError(f"Invalid transport: {transport}") # pragma: no cover

self._setup_auth()
Expand Down
4 changes: 4 additions & 0 deletions fastapi_mcp/transport/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .sse import FastApiSseTransport
from .websocket import FastApiWebSocketTransport

__all__ = ["FastApiSseTransport", "FastApiWebSocketTransport"]
132 changes: 132 additions & 0 deletions fastapi_mcp/transport/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import logging
from contextlib import asynccontextmanager

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from fastapi import WebSocket, WebSocketDisconnect
from pydantic import ValidationError

import mcp.types as types
from mcp.shared.message import SessionMessage
from fastapi_mcp.types import HTTPRequestInfo

logger = logging.getLogger(__name__)


class FastApiWebSocketTransport:
"""
WebSocket transport for FastAPI MCP that integrates with FastAPI's WebSocket support.

This transport provides similar functionality to the SSE transport but uses WebSockets
for bidirectional communication, which can be more efficient for interactive applications.
"""

@asynccontextmanager
async def connect_websocket(self, websocket: WebSocket):
"""
Connect a WebSocket and return read/write streams for MCP communication.

Args:
websocket: FastAPI WebSocket instance

Yields:
tuple: (read_stream, write_stream) for MCP communication
"""
await websocket.accept(subprotocol="mcp")

read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

async def ws_reader():
"""Read messages from WebSocket and send to read stream."""
try:
async with read_stream_writer:
while True:
try:
# Receive text message from WebSocket
raw_message = await websocket.receive_text()
logger.debug(f"Received WebSocket message: {raw_message}")

# Parse and validate JSON-RPC message
try:
client_message = types.JSONRPCMessage.model_validate_json(raw_message)

# HACK to inject HTTP request info into the MCP message,
# similar to what we do in SSE transport
if hasattr(client_message.root, "params") and client_message.root.params is not None:
# For WebSocket, we have less HTTP context, but we can still provide some info
client_message.root.params["_http_request_info"] = HTTPRequestInfo(
method="WEBSOCKET",
path=websocket.url.path,
headers=dict(websocket.headers) if websocket.headers else {},
cookies={}, # WebSocket doesn't have cookies in the same way
query_params=dict(websocket.query_params) if websocket.query_params else {},
body="", # WebSocket doesn't have a body
).model_dump(mode="json")

session_message = SessionMessage(client_message)
await read_stream_writer.send(session_message)

except ValidationError as exc:
logger.error(f"Failed to parse WebSocket message: {exc}")
await read_stream_writer.send(exc)

except WebSocketDisconnect:
logger.debug("WebSocket disconnected")
break
except anyio.get_cancelled_exc_class():
logger.debug("WebSocket reader task cancelled")
break
except Exception as e:
logger.error(f"Error reading from WebSocket: {e}")
break

except anyio.ClosedResourceError:
logger.debug("Read stream closed")
except anyio.get_cancelled_exc_class():
logger.debug("WebSocket reader cancelled")
except Exception as e:
logger.error(f"Error in WebSocket reader: {e}")

async def ws_writer():
"""Read messages from write stream and send to WebSocket."""
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
try:
# Convert message to JSON and send via WebSocket
message_json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
logger.debug(f"Sending WebSocket message: {message_json}")
await websocket.send_text(message_json)

except WebSocketDisconnect:
logger.debug("WebSocket disconnected during send")
break
except anyio.get_cancelled_exc_class():
logger.debug("WebSocket writer task cancelled")
break
except Exception as e:
logger.error(f"Error sending WebSocket message: {e}")
break

except anyio.ClosedResourceError:
logger.debug("Write stream closed")
except anyio.get_cancelled_exc_class():
logger.debug("WebSocket writer cancelled")
except Exception as e:
logger.error(f"Error in WebSocket writer: {e}")

async with anyio.create_task_group() as tg:
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)
try:
yield (read_stream, write_stream)
finally:
# Cancel the task group when the context manager exits
tg.cancel_scope.cancel()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"httpx>=0.24.0",
"requests>=2.25.0",
"tomli>=2.2.1",
"websockets>=15.0.1",
]

[dependency-groups]
Expand Down
Loading