Skip to content

Add in-memory transport #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion mcp_python/client/session.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import timedelta

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl

Expand Down Expand Up @@ -36,8 +38,15 @@ def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
) -> None:
super().__init__(read_stream, write_stream, ServerRequest, ServerNotification)
super().__init__(
read_stream,
write_stream,
ServerRequest,
ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)

async def initialize(self) -> InitializeResult:
from mcp_python.types import (
Expand Down
34 changes: 28 additions & 6 deletions mcp_python/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ClientNotification,
ClientRequest,
CompleteRequest,
EmptyResult,
ErrorData,
JSONRPCMessage,
ListPromptsRequest,
Expand All @@ -27,6 +28,7 @@
ListToolsRequest,
ListToolsResult,
LoggingLevel,
PingRequest,
ProgressNotification,
Prompt,
PromptReference,
Expand All @@ -52,9 +54,11 @@
class Server:
def __init__(self, name: str):
self.name = name
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {}
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {
PingRequest: _ping_handler,
}
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
logger.info(f"Initializing server '{name}'")
logger.debug(f"Initializing server '{name}'")

def create_initialization_options(self) -> types.InitializationOptions:
"""Create initialization options from this server instance."""
Expand All @@ -63,9 +67,13 @@ def pkg_version(package: str) -> str:
try:
from importlib.metadata import version

return version(package)
v = version(package)
if v is not None:
return v
except Exception:
return "unknown"
pass

return "unknown"

return types.InitializationOptions(
server_name=self.name,
Expand Down Expand Up @@ -330,6 +338,11 @@ async def run(
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions,
# When True, exceptions are returned as messages to the client.
# When False, exceptions are raised, which will cause the server to shut down
# but also make tracing exceptions much easier during testing and when using
# in-process servers.
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(
Expand All @@ -349,6 +362,7 @@ async def run(
f"Dispatching request of type {type(req).__name__}"
)

token = None
try:
# Set our global state that can be retrieved via
# app.get_request_context()
Expand All @@ -360,12 +374,16 @@ async def run(
)
)
response = await handler(req)
# Reset the global state after we are done
request_ctx.reset(token)
except Exception as err:
if raise_exceptions:
raise err
response = ErrorData(
code=0, message=str(err), data=None
)
finally:
# Reset the global state after we are done
if token is not None:
request_ctx.reset(token)

await message.respond(response)
else:
Expand Down Expand Up @@ -399,3 +417,7 @@ async def run(
logger.info(
f"Warning: {warning.category.__name__}: {warning.message}"
)


async def _ping_handler(request: PingRequest) -> ServerResult:
return ServerResult(EmptyResult())
87 changes: 87 additions & 0 deletions mcp_python/shared/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
In-memory transports
"""

from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncGenerator

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp_python.client.session import ClientSession
from mcp_python.server import Server
from mcp_python.types import JSONRPCMessage

MessageStream = tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage]
]

@asynccontextmanager
async def create_client_server_memory_streams() -> AsyncGenerator[
tuple[MessageStream, MessageStream],
None
]:
"""
Creates a pair of bidirectional memory streams for client-server communication.

Returns:
A tuple of (client_streams, server_streams) where each is a tuple of
(read_stream, write_stream)
"""
# Create streams for both directions
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
](1)

client_streams = (server_to_client_receive, client_to_server_send)
server_streams = (client_to_server_receive, server_to_client_send)

async with (
server_to_client_receive,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
):
yield client_streams, server_streams


@asynccontextmanager
async def create_connected_server_and_client_session(
server: Server,
read_timeout_seconds: timedelta | None = None,
raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
async with create_client_server_memory_streams() as (
client_streams,
server_streams,
):
client_read, client_write = client_streams
server_read, server_write = server_streams

# Create a cancel scope for the server task
async with anyio.create_task_group() as tg:
tg.start_soon(
lambda: server.run(
server_read,
server_write,
server.create_initialization_options(),
raise_exceptions=raise_exceptions,
)
)

try:
async with ClientSession(
read_stream=client_read,
write_stream=client_write,
read_timeout_seconds=read_timeout_seconds,
) as client_session:
await client_session.initialize()
yield client_session
finally:
tg.cancel_scope.cancel()
25 changes: 24 additions & 1 deletion mcp_python/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import Generic, TypeVar

import anyio
import anyio.lowlevel
import httpx
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel

Expand Down Expand Up @@ -87,13 +89,16 @@ def __init__(
write_stream: MemoryObjectSendStream[JSONRPCMessage],
receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
self._response_streams = {}
self._request_id = 0
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._read_timeout_seconds = read_timeout_seconds

self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
Expand Down Expand Up @@ -147,7 +152,25 @@ async def send_request(

await self._write_stream.send(JSONRPCMessage(jsonrpc_request))

response_or_error = await response_stream_reader.receive()
try:
with anyio.fail_after(
None if self._read_timeout_seconds is None
else self._read_timeout_seconds.total_seconds()
):
response_or_error = await response_stream_reader.receive()
except TimeoutError:
raise McpError(
ErrorData(
code=httpx.codes.REQUEST_TIMEOUT,
message=(
f"Timed out while waiting for response to "
f"{request.__class__.__name__}. Waited "
f"{self._read_timeout_seconds} seconds."
),
)

)

if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error)
else:
Expand Down
3 changes: 3 additions & 0 deletions mcp_python/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,19 @@ class ErrorData(BaseModel):

code: int
"""The error type that occurred."""

message: str
"""
A short description of the error. The message SHOULD be limited to a concise single
sentence.
"""

data: Any | None = None
"""
Additional information about the error. The value of this member is defined by the
sender (e.g. detailed error information, nested errors etc.).
"""

model_config = ConfigDict(extra="allow")


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "mcp-python"
version = "0.4.0.dev"
version = "0.5.0dev"
description = "Model Context Protocol implementation for Python"
readme = "README.md"
requires-python = ">=3.10"
Expand Down
28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
from pydantic import AnyUrl

from mcp_python.server import Server
from mcp_python.server.types import InitializationOptions
from mcp_python.types import Resource, ServerCapabilities

TEST_INITIALIZATION_OPTIONS = InitializationOptions(
server_name="my_mcp_server",
server_version="0.1.0",
capabilities=ServerCapabilities(),
)

@pytest.fixture
def mcp_server() -> Server:
server = Server(name="test_server")

@server.list_resources()
async def handle_list_resources():
return [
Resource(
uri=AnyUrl("memory://test"),
name="Test Resource",
description="A test resource"
)
]

return server
28 changes: 28 additions & 0 deletions tests/shared/test_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
from typing_extensions import AsyncGenerator

from mcp_python.client.session import ClientSession
from mcp_python.server import Server
from mcp_python.shared.memory import (
create_connected_server_and_client_session,
)
from mcp_python.types import (
EmptyResult,
)


@pytest.fixture
async def client_connected_to_server(
mcp_server: Server,
) -> AsyncGenerator[ClientSession, None]:
async with create_connected_server_and_client_session(mcp_server) as client_session:
yield client_session


@pytest.mark.anyio
async def test_memory_server_and_client_connection(
client_connected_to_server: ClientSession,
):
"""Shows how a client and server can communicate over memory streams."""
response = await client_connected_to_server.send_ping()
assert isinstance(response, EmptyResult)
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.