From d5127dc78cf3e03f230274c255ffc364e7b1e226 Mon Sep 17 00:00:00 2001 From: Braxton Nunnally Date: Thu, 8 May 2025 14:48:46 -0400 Subject: [PATCH 1/2] Added sse support for ne04j-memory server --- servers/mcp-neo4j-memory/Dockerfile | 12 +- servers/mcp-neo4j-memory/pyproject.toml | 5 +- .../src/mcp_neo4j_memory/__init__.py | 6 +- .../src/mcp_neo4j_memory/app.py | 79 +++++ .../src/mcp_neo4j_memory/main.py | 55 +++ .../src/mcp_neo4j_memory/server.py | 324 +++++++++++------- .../src/mcp_neo4j_memory/transport.py | 28 ++ .../transports/sse_transport.py | 62 ++++ .../transports/stdio_transport.py | 29 ++ 9 files changed, 468 insertions(+), 132 deletions(-) create mode 100644 servers/mcp-neo4j-memory/src/mcp_neo4j_memory/app.py create mode 100644 servers/mcp-neo4j-memory/src/mcp_neo4j_memory/main.py create mode 100644 servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transport.py create mode 100644 servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/sse_transport.py create mode 100644 servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/stdio_transport.py diff --git a/servers/mcp-neo4j-memory/Dockerfile b/servers/mcp-neo4j-memory/Dockerfile index 29a9b4e..1d7a4bd 100644 --- a/servers/mcp-neo4j-memory/Dockerfile +++ b/servers/mcp-neo4j-memory/Dockerfile @@ -10,7 +10,7 @@ RUN pip install --no-cache-dir hatchling COPY pyproject.toml /app/ # Install runtime dependencies -RUN pip install --no-cache-dir mcp>=0.10.0 neo4j>=5.26.0 +RUN pip install --no-cache-dir mcp>=0.10.0 neo4j>=5.26.0 fastapi uvicorn # Copy the source code COPY src/ /app/src/ @@ -22,7 +22,13 @@ RUN pip install --no-cache-dir -e . # Environment variables for Neo4j connection ENV NEO4J_URL="bolt://host.docker.internal:7687" ENV NEO4J_USERNAME="neo4j" -ENV NEO4J_PASSWORD="password" +ENV NEO4J_PASSWORD="neo4j_password" + +# Set transport type (can be "sse" or "stdio") +ENV MCP_TRANSPORT="sse" + +# Expose port for SSE transport +EXPOSE 8090 # Command to run the server using the package entry point -CMD ["sh", "-c", "mcp-neo4j-memory --db-url ${NEO4J_URL} --username ${NEO4J_USERNAME} --password ${NEO4J_PASSWORD}"] \ No newline at end of file +CMD ["python", "-m", "src.mcp_neo4j_memory.main"] \ No newline at end of file diff --git a/servers/mcp-neo4j-memory/pyproject.toml b/servers/mcp-neo4j-memory/pyproject.toml index c5ed6f7..283d90f 100644 --- a/servers/mcp-neo4j-memory/pyproject.toml +++ b/servers/mcp-neo4j-memory/pyproject.toml @@ -1,12 +1,13 @@ [project] name = "mcp-neo4j-memory" -version = "0.1.3" +version = "0.1.1" description = "MCP Neo4j Knowledge Graph Memory Server" readme = "README.md" requires-python = ">=3.10" dependencies = [ "mcp>=0.10.0", "neo4j>=5.26.0", + "fastapi" ] [build-system] @@ -21,7 +22,7 @@ dev-dependencies = [ ] [project.scripts] -mcp-neo4j-memory = "mcp_neo4j_memory:main" +mcp-neo4j-memory = "mcp_neo4j_memory.main:main" [tool.pytest.ini_options] pythonpath = [ diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/__init__.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/__init__.py index 123507a..39d6ae4 100644 --- a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/__init__.py +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/__init__.py @@ -8,13 +8,13 @@ def main(): """Main entry point for the package.""" parser = argparse.ArgumentParser(description='Neo4j Cypher MCP Server') parser.add_argument('--db-url', - default=os.getenv("NEO4J_URL", "bolt://localhost:7687"), + default="bolt://localhost:7687", help='Neo4j connection URL') parser.add_argument('--username', - default=os.getenv("NEO4J_USERNAME", "neo4j"), + default="neo4j", help='Neo4j username') parser.add_argument('--password', - default=os.getenv("NEO4J_PASSWORD", "password"), + default="password", help='Neo4j password') args = parser.parse_args() diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/app.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/app.py new file mode 100644 index 0000000..ff2cf04 --- /dev/null +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/app.py @@ -0,0 +1,79 @@ +import logging +import os + +from fastapi import FastAPI + +from src.mcp_neo4j_memory.server import create_mcp_server +from src.mcp_neo4j_memory.transport import TransportLayer + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("mcp_neo4j_memory") + + +async def start_server(transport_type="sse"): + """Start the MCP server with the specified transport""" + # Get environment variables for Neo4j connection + neo4j_uri = os.environ.get( + "NEO4J_URI", os.environ.get("NEO4J_URL", "bolt://host.docker.internal:7687") + ) + neo4j_user = os.environ.get("NEO4J_USERNAME", os.environ.get("NEO4J_USER", "neo4j")) + neo4j_password = os.environ.get("NEO4J_PASSWORD", "neo4j_password") + + try: + # Create the MCP server + logger.info(f"Creating MCP server with transport type: {transport_type}") + mcp_server = await create_mcp_server(neo4j_uri, neo4j_user, neo4j_password) + + # Create transport and run server + transport = TransportLayer.create_transport(transport_type) + return await transport.run_server(mcp_server) + except Exception as e: + logger.error(f"Failed to start server: {e}") + raise e + + +# For FastAPI app +app = FastAPI() + + +@app.on_event("startup") +async def startup_event(): + global app + transport_type = os.environ.get("MCP_TRANSPORT", "sse") + if transport_type.lower() == "sse": + # If using SSE, mount the SSE app to this app + sse_app = await start_server("sse") + + # Add health check endpoint + @app.get("/health") + def health_check(): + return {"status": "ok"} + + # Add root endpoint + @app.get("/") + def read_root(): + return { + "message": "Neo4j MCP Memory Server", + "transport": "SSE", + "sse_endpoint": "/sse", + } + + # Mount all routes from the SSE app + for route in sse_app.routes: + app.routes.append(route) + else: + # For stdio, just inform that this app should not be used + @app.get("/") + def read_root(): + return { + "error": "This server is configured to use stdio transport, not HTTP/SSE", + "message": "Please run this server directly from the command line", + } + + @app.get("/health") + def health_check(): + return { + "status": "error", + "message": "Server is configured for stdio transport", + } diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/main.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/main.py new file mode 100644 index 0000000..2f56648 --- /dev/null +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/main.py @@ -0,0 +1,55 @@ +import argparse +import asyncio +import os + + +async def start_stdio_server(): + """Start the server with stdio transport""" + from src.mcp_neo4j_memory.app import start_server + + await start_server("stdio") + + +def main(): + """Command-line entry point for running the MCP server""" + parser = argparse.ArgumentParser(description="MCP Neo4j Memory Server") + parser.add_argument( + "--transport", + choices=["sse", "stdio"], + default="sse", + help="Transport type (sse or stdio)", + ) + parser.add_argument("--db-url", dest="neo4j_uri", help="Neo4j database URL") + parser.add_argument("--username", dest="neo4j_username", help="Neo4j username") + parser.add_argument("--password", dest="neo4j_password", help="Neo4j password") + parser.add_argument( + "--port", + type=int, + default=8090, + help="Port for HTTP server (when using SSE transport)", + ) + + args = parser.parse_args() + + # Set environment variables from arguments + if args.neo4j_uri: + os.environ["NEO4J_URI"] = args.neo4j_uri + if args.neo4j_username: + os.environ["NEO4J_USERNAME"] = args.neo4j_username + if args.neo4j_password: + os.environ["NEO4J_PASSWORD"] = args.neo4j_password + + os.environ["MCP_TRANSPORT"] = args.transport + + if args.transport == "stdio": + # Run with stdio transport + asyncio.run(start_stdio_server()) + else: + # Run with SSE transport via FastAPI/Uvicorn + import uvicorn + + uvicorn.run("src.mcp_neo4j_memory.app:app", host="0.0.0.0", port=args.port, reload=False) + + +if __name__ == "__main__": + main() diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py index 947a263..78dd483 100644 --- a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py @@ -1,45 +1,47 @@ -import os -import logging import json -from typing import Any, Dict, List, Optional -from contextlib import asynccontextmanager +import logging +from typing import Any, Dict, List import neo4j from neo4j import GraphDatabase from pydantic import BaseModel import mcp.types as types -from mcp.server import NotificationOptions, Server -from mcp.server.models import InitializationOptions -import mcp.server.stdio +from mcp.server import Server # Set up logging -logger = logging.getLogger('mcp_neo4j_memory') +logger = logging.getLogger("mcp_neo4j_memory") logger.setLevel(logging.INFO) + # Models for our knowledge graph class Entity(BaseModel): name: str type: str observations: List[str] + class Relation(BaseModel): source: str target: str relationType: str + class KnowledgeGraph(BaseModel): entities: List[Entity] relations: List[Relation] + class ObservationAddition(BaseModel): entityName: str contents: List[str] + class ObservationDeletion(BaseModel): entityName: str observations: List[str] + class Neo4jMemory: def __init__(self, neo4j_driver): self.neo4j_driver = neo4j_driver @@ -47,7 +49,7 @@ def __init__(self, neo4j_driver): def create_fulltext_index(self): try: - # TODO , + # TODO , query = """ CREATE FULLTEXT INDEX search IF NOT EXISTS FOR (m:Memory) ON EACH [m.name, m.type, m.observations]; """ @@ -64,47 +66,49 @@ async def load_graph(self, filter_query="*"): CALL db.index.fulltext.queryNodes('search', $filter) yield node as entity, score OPTIONAL MATCH (entity)-[r]-(other) RETURN collect(distinct { - name: entity.name, - type: entity.type, + name: entity.name, + type: entity.type, observations: entity.observations }) as nodes, collect(distinct { - source: startNode(r).name, - target: endNode(r).name, + source: startNode(r).name, + target: endNode(r).name, relationType: type(r) }) as relations """ - + result = self.neo4j_driver.execute_query(query, {"filter": filter_query}) - + if not result.records: return KnowledgeGraph(entities=[], relations=[]) - + record = result.records[0] - nodes = record.get('nodes') - rels = record.get('relations') - + nodes = record.get("nodes") + rels = record.get("relations") + entities = [ Entity( - name=node.get('name'), - type=node.get('type'), - observations=node.get('observations', []) + name=node.get("name"), + type=node.get("type"), + observations=node.get("observations", []), ) - for node in nodes if node.get('name') + for node in nodes + if node.get("name") ] - + relations = [ Relation( - source=rel.get('source'), - target=rel.get('target'), - relationType=rel.get('relationType') + source=rel.get("source"), + target=rel.get("target"), + relationType=rel.get("relationType"), ) - for rel in rels if rel.get('source') and rel.get('target') and rel.get('relationType') + for rel in rels + if rel.get("source") and rel.get("target") and rel.get("relationType") ] - + logger.debug(f"Loaded entities: {entities}") logger.debug(f"Loaded relations: {relations}") - + return KnowledgeGraph(entities=entities, relations=relations) async def create_entities(self, entities: List[Entity]) -> List[Entity]: @@ -114,7 +118,7 @@ async def create_entities(self, entities: List[Entity]) -> List[Entity]: SET e += entity {.type, .observations} SET e:$(entity.type) """ - + entities_data = [entity.model_dump() for entity in entities] self.neo4j_driver.execute_query(query, {"entities": entities_data}) return entities @@ -128,29 +132,32 @@ async def create_relations(self, relations: List[Relation]) -> List[Relation]: AND to.name = relation.target MERGE (from)-[r:$(relation.relationType)]->(to) """ - + self.neo4j_driver.execute_query( - query, - {"relations": [relation.model_dump() for relation in relations]} + query, {"relations": [relation.model_dump() for relation in relations]} ) - + return relations - async def add_observations(self, observations: List[ObservationAddition]) -> List[Dict[str, Any]]: + async def add_observations( + self, observations: List[ObservationAddition] + ) -> List[Dict[str, Any]]: query = """ - UNWIND $observations as obs + UNWIND $observations as obs MATCH (e:Memory { name: obs.entityName }) WITH e, [o in obs.contents WHERE NOT o IN e.observations] as new SET e.observations = coalesce(e.observations,[]) + new RETURN e.name as name, new """ - + result = self.neo4j_driver.execute_query( - query, - {"observations": [obs.model_dump() for obs in observations]} + query, {"observations": [obs.model_dump() for obs in observations]} ) - results = [{"entityName": record.get("name"), "addedObservations": record.get("new")} for record in result.records] + results = [ + {"entityName": record.get("name"), "addedObservations": record.get("new")} + for record in result.records + ] return results async def delete_entities(self, entity_names: List[str]) -> None: @@ -159,20 +166,17 @@ async def delete_entities(self, entity_names: List[str]) -> None: MATCH (e:Memory { name: name }) DETACH DELETE e """ - + self.neo4j_driver.execute_query(query, {"entities": entity_names}) async def delete_observations(self, deletions: List[ObservationDeletion]) -> None: query = """ - UNWIND $deletions as d + UNWIND $deletions as d MATCH (e:Memory { name: d.entityName }) SET e.observations = [o in coalesce(e.observations,[]) WHERE NOT o IN d.observations] """ self.neo4j_driver.execute_query( - query, - { - "deletions": [deletion.model_dump() for deletion in deletions] - } + query, {"deletions": [deletion.model_dump() for deletion in deletions]} ) async def delete_relations(self, relations: List[Relation]) -> None: @@ -184,8 +188,7 @@ async def delete_relations(self, relations: List[Relation]) -> None: DELETE r """ self.neo4j_driver.execute_query( - query, - {"relations": [relation.model_dump() for relation in relations]} + query, {"relations": [relation.model_dump() for relation in relations]} ) async def read_graph(self) -> KnowledgeGraph: @@ -197,26 +200,24 @@ async def search_nodes(self, query: str) -> KnowledgeGraph: async def find_nodes(self, names: List[str]) -> KnowledgeGraph: return await self.load_graph("name: (" + " ".join(names) + ")") -async def main(neo4j_uri: str, neo4j_user: str, neo4j_password: str): - logger.info(f"Connecting to neo4j MCP Server with DB URL: {neo4j_uri}") + +async def create_mcp_server(neo4j_uri, neo4j_user, neo4j_password): + logger.info(f"Connecting to Neo4j at {neo4j_uri}") # Connect to Neo4j - neo4j_driver = GraphDatabase.driver( - neo4j_uri, - auth=(neo4j_user, neo4j_password) - ) - + neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password)) + # Verify connection try: neo4j_driver.verify_connectivity() logger.info(f"Connected to Neo4j at {neo4j_uri}") except Exception as e: logger.error(f"Failed to connect to Neo4j: {e}") - exit(1) + raise e # Initialize memory memory = Neo4jMemory(neo4j_driver) - + # Create MCP server server = Server("mcp-neo4j-memory") @@ -235,12 +236,18 @@ async def handle_list_tools() -> List[types.Tool]: "items": { "type": "object", "properties": { - "name": {"type": "string", "description": "The name of the entity"}, - "type": {"type": "string", "description": "The type of the entity"}, + "name": { + "type": "string", + "description": "The name of the entity", + }, + "type": { + "type": "string", + "description": "The type of the entity", + }, "observations": { "type": "array", "items": {"type": "string"}, - "description": "An array of observation contents associated with the entity" + "description": "An array of observation contents associated with the entity", }, }, "required": ["name", "type", "observations"], @@ -261,9 +268,18 @@ async def handle_list_tools() -> List[types.Tool]: "items": { "type": "object", "properties": { - "source": {"type": "string", "description": "The name of the entity where the relation starts"}, - "target": {"type": "string", "description": "The name of the entity where the relation ends"}, - "relationType": {"type": "string", "description": "The type of the relation"}, + "source": { + "type": "string", + "description": "The name of the entity where the relation starts", + }, + "target": { + "type": "string", + "description": "The name of the entity where the relation ends", + }, + "relationType": { + "type": "string", + "description": "The type of the relation", + }, }, "required": ["source", "target", "relationType"], }, @@ -283,11 +299,14 @@ async def handle_list_tools() -> List[types.Tool]: "items": { "type": "object", "properties": { - "entityName": {"type": "string", "description": "The name of the entity to add the observations to"}, + "entityName": { + "type": "string", + "description": "The name of the entity to add the observations to", + }, "contents": { "type": "array", "items": {"type": "string"}, - "description": "An array of observation contents to add" + "description": "An array of observation contents to add", }, }, "required": ["entityName", "contents"], @@ -306,7 +325,7 @@ async def handle_list_tools() -> List[types.Tool]: "entityNames": { "type": "array", "items": {"type": "string"}, - "description": "An array of entity names to delete" + "description": "An array of entity names to delete", }, }, "required": ["entityNames"], @@ -323,11 +342,14 @@ async def handle_list_tools() -> List[types.Tool]: "items": { "type": "object", "properties": { - "entityName": {"type": "string", "description": "The name of the entity containing the observations"}, + "entityName": { + "type": "string", + "description": "The name of the entity containing the observations", + }, "observations": { "type": "array", "items": {"type": "string"}, - "description": "An array of observations to delete" + "description": "An array of observations to delete", }, }, "required": ["entityName", "observations"], @@ -348,13 +370,22 @@ async def handle_list_tools() -> List[types.Tool]: "items": { "type": "object", "properties": { - "source": {"type": "string", "description": "The name of the entity where the relation starts"}, - "target": {"type": "string", "description": "The name of the entity where the relation ends"}, - "relationType": {"type": "string", "description": "The type of the relation"}, + "source": { + "type": "string", + "description": "The name of the entity where the relation starts", + }, + "target": { + "type": "string", + "description": "The name of the entity where the relation ends", + }, + "relationType": { + "type": "string", + "description": "The type of the relation", + }, }, "required": ["source", "target", "relationType"], }, - "description": "An array of relations to delete" + "description": "An array of relations to delete", }, }, "required": ["relations"], @@ -374,14 +405,17 @@ async def handle_list_tools() -> List[types.Tool]: inputSchema={ "type": "object", "properties": { - "query": {"type": "string", "description": "The search query to match against entity names, types, and observation content"}, + "query": { + "type": "string", + "description": "The search query to match against entity names, types, and observation content", + }, }, "required": ["query"], }, ), types.Tool( name="find_nodes", - description="Find specific nodes in the knowledge graph by their names", + description="Open specific nodes in the knowledge graph by their names", inputSchema={ "type": "object", "properties": { @@ -395,18 +429,17 @@ async def handle_list_tools() -> List[types.Tool]: }, ), types.Tool( - name="open_nodes", - description="Open specific nodes in the knowledge graph by their names", + name="entity_exists", + description="Check if an entity exists in the knowledge graph by name", inputSchema={ "type": "object", "properties": { - "names": { - "type": "array", - "items": {"type": "string"}, - "description": "An array of entity names to retrieve", + "name": { + "type": "string", + "description": "The name of the entity to check", }, }, - "required": ["names"], + "required": ["name"], }, ), ] @@ -416,69 +449,112 @@ async def handle_call_tool( name: str, arguments: Dict[str, Any] | None ) -> List[types.TextContent | types.ImageContent]: try: - if name == "read_graph": - result = await memory.read_graph() - return [types.TextContent(type="text", text=json.dumps(result.model_dump(), indent=2))] - if not arguments: raise ValueError(f"No arguments provided for tool: {name}") if name == "create_entities": - entities = [Entity(**entity) for entity in arguments.get("entities", [])] + entities = [ + Entity(**entity) for entity in arguments.get("entities", []) + ] result = await memory.create_entities(entities) - return [types.TextContent(type="text", text=json.dumps([e.model_dump() for e in result], indent=2))] - + return [ + types.TextContent( + type="text", + text=json.dumps([e.model_dump() for e in result], indent=2), + ) + ] + elif name == "create_relations": - relations = [Relation(**relation) for relation in arguments.get("relations", [])] + relations = [ + Relation(**relation) for relation in arguments.get("relations", []) + ] result = await memory.create_relations(relations) - return [types.TextContent(type="text", text=json.dumps([r.model_dump() for r in result], indent=2))] - + return [ + types.TextContent( + type="text", + text=json.dumps([r.model_dump() for r in result], indent=2), + ) + ] + elif name == "add_observations": - observations = [ObservationAddition(**obs) for obs in arguments.get("observations", [])] + observations = [ + ObservationAddition(**obs) + for obs in arguments.get("observations", []) + ] result = await memory.add_observations(observations) - return [types.TextContent(type="text", text=json.dumps(result, indent=2))] - + return [ + types.TextContent(type="text", text=json.dumps(result, indent=2)) + ] + elif name == "delete_entities": await memory.delete_entities(arguments.get("entityNames", [])) - return [types.TextContent(type="text", text="Entities deleted successfully")] - + return [ + types.TextContent(type="text", text="Entities deleted successfully") + ] + elif name == "delete_observations": - deletions = [ObservationDeletion(**deletion) for deletion in arguments.get("deletions", [])] + deletions = [ + ObservationDeletion(**deletion) + for deletion in arguments.get("deletions", []) + ] await memory.delete_observations(deletions) - return [types.TextContent(type="text", text="Observations deleted successfully")] - + return [ + types.TextContent( + type="text", text="Observations deleted successfully" + ) + ] + elif name == "delete_relations": - relations = [Relation(**relation) for relation in arguments.get("relations", [])] + relations = [ + Relation(**relation) for relation in arguments.get("relations", []) + ] await memory.delete_relations(relations) - return [types.TextContent(type="text", text="Relations deleted successfully")] - + return [ + types.TextContent( + type="text", text="Relations deleted successfully" + ) + ] + + elif name == "read_graph": + result = await memory.read_graph() + return [ + types.TextContent( + type="text", text=json.dumps(result.model_dump(), indent=2) + ) + ] + elif name == "search_nodes": result = await memory.search_nodes(arguments.get("query", "")) - return [types.TextContent(type="text", text=json.dumps(result.model_dump(), indent=2))] - - elif name == "find_nodes" or name == "open_nodes": + return [ + types.TextContent( + type="text", text=json.dumps(result.model_dump(), indent=2) + ) + ] + + elif name == "find_nodes": result = await memory.find_nodes(arguments.get("names", [])) - return [types.TextContent(type="text", text=json.dumps(result.model_dump(), indent=2))] - + return [ + types.TextContent( + type="text", text=json.dumps(result.model_dump(), indent=2) + ) + ] + elif name == "entity_exists": + entity_name = arguments.get("name", "") + query = """ + MATCH (e:Memory {name: $name}) + RETURN COUNT(e) > 0 AS exists + """ + + result = neo4j_driver.execute_query(query, {"name": entity_name}) + exists = result.records[0].get("exists") if result.records else False + return [ + types.TextContent(type="text", text=json.dumps({"exists": exists})) + ] else: raise ValueError(f"Unknown tool: {name}") - + except Exception as e: logger.error(f"Error handling tool call: {e}") return [types.TextContent(type="text", text=f"Error: {str(e)}")] - # Start the server - async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): - logger.info("MCP Knowledge Graph Memory using Neo4j running on stdio") - await server.run( - read_stream, - write_stream, - InitializationOptions( - server_name="mcp-neo4j-memory", - server_version="1.1", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) + return server diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transport.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transport.py new file mode 100644 index 0000000..9f68df8 --- /dev/null +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transport.py @@ -0,0 +1,28 @@ +import logging +from abc import ABC, abstractmethod + + +class TransportLayer(ABC): + """Abstract base class for MCP transport layers""" + + def __init__(self): + self.logger = logging.getLogger("mcp_neo4j_memory") + + @abstractmethod + async def run_server(self, mcp_server): + """Run the server with the given MCP server""" + pass + + @classmethod + def create_transport(cls, transport_type="sse"): + """Factory method to create transport layer based on type""" + if transport_type.lower() == "sse": + from src.mcp_neo4j_memory.transports.sse_transport import SSETransport + + return SSETransport() + elif transport_type.lower() == "stdio": + from src.mcp_neo4j_memory.transports.stdio_transport import StdIOTransport + + return StdIOTransport() + else: + raise ValueError(f"Unsupported transport type: {transport_type}") diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/sse_transport.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/sse_transport.py new file mode 100644 index 0000000..8f2f9c4 --- /dev/null +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/sse_transport.py @@ -0,0 +1,62 @@ +from fastapi import FastAPI +from src.mcp_neo4j_memory.transport import TransportLayer +from starlette.applications import Starlette +from starlette.routing import Mount, Route + +from mcp.server import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.sse import SseServerTransport + + +class SSETransport(TransportLayer): + """Implementation of SSE transport for MCP server""" + + def __init__(self): + super().__init__() + self.app = FastAPI( + title="Neo4j MCP Memory Server", + description="MCP server for Neo4j knowledge graph memory", + ) + + async def create_app(self, mcp_server): + """Create and return a FastAPI app with the SSE transport configured""" + # Create the SSE transport + transport = SseServerTransport("/messages/") + + # Create the SSE handler + async def handle_sse(request): + async with transport.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await mcp_server.run( + streams[0], + streams[1], + InitializationOptions( + server_name="mcp-neo4j-memory", + server_version="1.1", + capabilities=mcp_server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + + # Create the Starlette routes + routes = [ + Route("/sse", endpoint=handle_sse), + Mount("/messages", app=transport.handle_post_message), + ] + + # Create Starlette app and mount it + sse_app = Starlette(routes=routes) + self.app.mount("/", sse_app) + + return self.app + + async def run_server(self, mcp_server): + """Run the server with the given MCP server""" + self.logger.info("Starting MCP server with SSE transport") + app = await self.create_app(mcp_server) + + # In production, app will be run by uvicorn + return app diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/stdio_transport.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/stdio_transport.py new file mode 100644 index 0000000..34a1810 --- /dev/null +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/stdio_transport.py @@ -0,0 +1,29 @@ +from mcp_neo4j_memory.transport import TransportLayer + +import mcp.server.stdio +from mcp.server import NotificationOptions +from mcp.server.models import InitializationOptions + + +class StdIOTransport(TransportLayer): + """Implementation of stdio transport for MCP server""" + + async def run_server(self, mcp_server): + """Run the server with stdio transport""" + self.logger.info("Starting MCP server with stdio transport") + + # Using the original implementation from server.py + async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): + self.logger.info("MCP Knowledge Graph Memory using Neo4j running on stdio") + await mcp_server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="mcp-neo4j-memory", + server_version="1.1", + capabilities=mcp_server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) From ac05ef530859cb0a002daa6d5474f9881d8dcfb5 Mon Sep 17 00:00:00 2001 From: Braxton Nunnally Date: Thu, 22 May 2025 11:19:50 -0400 Subject: [PATCH 2/2] Added sse + stdio support for cypher mcp server --- servers/mcp-neo4j-cypher/Dockerfile | 13 +- servers/mcp-neo4j-cypher/pyproject.toml | 5 +- .../src/mcp_neo4j_cypher/app.py | 108 +++++++ .../src/mcp_neo4j_cypher/main.py | 60 ++++ .../src/mcp_neo4j_cypher/server.py | 278 +++++++++++------- .../src/mcp_neo4j_cypher/transport.py | 28 ++ .../mcp_neo4j_cypher/transports/__init__.py | 0 .../transports/sse_transport.py | 62 ++++ .../transports/stdio_transport.py | 28 ++ 9 files changed, 466 insertions(+), 116 deletions(-) create mode 100644 servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/app.py create mode 100644 servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/main.py create mode 100644 servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transport.py create mode 100644 servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transports/__init__.py create mode 100644 servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transports/sse_transport.py create mode 100644 servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transports/stdio_transport.py diff --git a/servers/mcp-neo4j-cypher/Dockerfile b/servers/mcp-neo4j-cypher/Dockerfile index c5e614f..914c5cf 100644 --- a/servers/mcp-neo4j-cypher/Dockerfile +++ b/servers/mcp-neo4j-cypher/Dockerfile @@ -21,7 +21,14 @@ RUN pip install --no-cache-dir -e . # Environment variables for Neo4j connection ENV NEO4J_URL="bolt://host.docker.internal:7687" ENV NEO4J_USERNAME="neo4j" -ENV NEO4J_PASSWORD="password" +ENV NEO4J_PASSWORD="neo4j_password" +ENV NEO4J_DATABASE="neo4j" -# Command to run the server using the package entry point -CMD ["sh", "-c", "mcp-neo4j-cypher --db-url ${NEO4J_URL} --username ${NEO4J_USERNAME} --password ${NEO4J_PASSWORD}"] \ No newline at end of file +# Set transport type (can be "sse" or "stdio") +ENV MCP_TRANSPORT="sse" + +# Expose port for SSE transport +EXPOSE 8000 + +# Command to run the server using the entry point (will use transport type from MCP_TRANSPORT env var) +CMD ["python", "-m", "src.mcp_neo4j_memory.main"] \ No newline at end of file diff --git a/servers/mcp-neo4j-cypher/pyproject.toml b/servers/mcp-neo4j-cypher/pyproject.toml index 775cf86..949fe2f 100644 --- a/servers/mcp-neo4j-cypher/pyproject.toml +++ b/servers/mcp-neo4j-cypher/pyproject.toml @@ -8,6 +8,9 @@ dependencies = [ "mcp[cli]>=1.6.0", "neo4j>=5.26.0", "pydantic>=2.10.1", + "fastapi", + "uvicorn", + "pytz" ] [build-system] @@ -24,4 +27,4 @@ dev-dependencies = [ ] [project.scripts] -mcp-neo4j-cypher = "mcp_neo4j_cypher:main" +mcp-neo4j-cypher = "mcp_neo4j_cypher.main:main" diff --git a/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/app.py b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/app.py new file mode 100644 index 0000000..b835ac4 --- /dev/null +++ b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/app.py @@ -0,0 +1,108 @@ +import logging +import os +from datetime import datetime + +import pytz +from fastapi import FastAPI + +from src.server import create_mcp_server, create_neo4j_driver, healthcheck +from src.transport import TransportLayer + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("mcp_neo4j_cypher") + + +async def start_server(transport_type="sse"): + """Start the MCP server with the specified transport""" + # Get environment variables for Neo4j connection + neo4j_uri = os.environ.get( + "NEO4J_URI", os.environ.get("NEO4J_URL", "bolt://host.docker.internal:7687") + ) + neo4j_user = os.environ.get("NEO4J_USERNAME", os.environ.get("NEO4J_USER", "neo4j")) + neo4j_password = os.environ.get("NEO4J_PASSWORD", "neo4j_password") + neo4j_database = os.environ.get("NEO4J_DATABASE", "neo4j") + + try: + # Health check Neo4j connection + try: + healthcheck(neo4j_uri, neo4j_user, neo4j_password, neo4j_database) + except Exception as e: + logger.error(f"Failed to connect to Neo4j: {e}") + raise e + + # Create Neo4j driver + neo4j_driver = await create_neo4j_driver(neo4j_uri, neo4j_user, neo4j_password) + + # Create the MCP server + logger.info(f"Creating MCP server with transport type: {transport_type}") + mcp_server = create_mcp_server(neo4j_driver, neo4j_database) + + # Create transport and run server + transport = TransportLayer.create_transport(transport_type) + return await transport.run_server(mcp_server) + except Exception as e: + logger.error(f"Failed to start server: {e}") + raise e + + +# For FastAPI app +app = FastAPI() + + +@app.on_event("startup") +async def startup_event(): + global app + transport_type = os.environ.get("MCP_TRANSPORT", "sse") + if transport_type.lower() == "sse": + # If using SSE, mount the SSE app to this app + sse_app = await start_server("sse") + + # Add health check endpoint + @app.get("/health") + def health_check(): + neo4j_uri = os.environ.get( + "NEO4J_URI", + os.environ.get("NEO4J_URL", "bolt://host.docker.internal:7687"), + ) + neo4j_user = os.environ.get( + "NEO4J_USERNAME", os.environ.get("NEO4J_USER", "neo4j") + ) + neo4j_password = os.environ.get("NEO4J_PASSWORD", "neo4j_password") + neo4j_database = os.environ.get("NEO4J_DATABASE", "neo4j") + timestamp = datetime.now(pytz.UTC).isoformat() + + try: + healthcheck(neo4j_uri, neo4j_user, neo4j_password, neo4j_database) + return {"status": "ok", "timestamp": timestamp} + except Exception as e: + logger.error(f"Failed to connect to Neo4j: {e}") + return {"status": "connection failed", "timestamp": timestamp} + + # Add root endpoint + @app.get("/") + def read_root(): + return { + "message": "Neo4j MCP Cypher Server", + "transport": "SSE", + "sse_endpoint": "/sse", + } + + # Mount all routes from the SSE app + for route in sse_app.routes: + app.routes.append(route) + else: + # For stdio, just inform that this app should not be used + @app.get("/") + def read_root(): + return { + "error": "This server is configured to use stdio transport, not HTTP/SSE", + "message": "Please run this server directly from the command line", + } + + @app.get("/health") + def health_check(): + return { + "status": "error", + "message": "Server is configured for stdio transport", + } diff --git a/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/main.py b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/main.py new file mode 100644 index 0000000..b6052f1 --- /dev/null +++ b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/main.py @@ -0,0 +1,60 @@ +import argparse +import asyncio +import os + + +async def start_stdio_server(): + """Start the server with stdio transport""" + from src.app import start_server + + await start_server("stdio") + + +def main(): + """Command-line entry point for running the MCP server""" + parser = argparse.ArgumentParser(description="MCP Neo4j Cypher Server") + parser.add_argument( + "--transport", + choices=["sse", "stdio"], + default="sse", + help="Transport type (sse or stdio)", + ) + parser.add_argument("--db-url", dest="neo4j_uri", help="Neo4j database URL") + parser.add_argument("--username", dest="neo4j_username", help="Neo4j username") + parser.add_argument("--password", dest="neo4j_password", help="Neo4j password") + parser.add_argument( + "--database", dest="neo4j_database", help="Neo4j database name", default="neo4j" + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port for HTTP server (when using SSE transport)", + ) + + args = parser.parse_args() + + # Set environment variables from arguments + if args.neo4j_uri: + os.environ["NEO4J_URI"] = args.neo4j_uri + if args.neo4j_username: + os.environ["NEO4J_USERNAME"] = args.neo4j_username + if args.neo4j_password: + os.environ["NEO4J_PASSWORD"] = args.neo4j_password + if args.neo4j_database: + os.environ["NEO4J_DATABASE"] = args.neo4j_database + + os.environ["MCP_TRANSPORT"] = args.transport + + if args.transport == "stdio": + # Run with stdio transport + asyncio.run(start_stdio_server()) + else: + # Run with SSE transport via FastAPI/Uvicorn + import uvicorn + + uvicorn.run("src.app:app", host="0.0.0.0", port=args.port, reload=False) + + +if __name__ == "__main__": + main() diff --git a/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/server.py b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/server.py index ed94e11..735fc2c 100644 --- a/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/server.py +++ b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/server.py @@ -3,10 +3,8 @@ import re import sys import time -from typing import Any, Optional +from typing import Any -import mcp.types as types -from mcp.server.fastmcp import FastMCP from neo4j import ( AsyncDriver, AsyncGraphDatabase, @@ -15,7 +13,9 @@ GraphDatabase, ) from neo4j.exceptions import DatabaseError -from pydantic import Field + +import mcp.types as types +from mcp.server import Server logger = logging.getLogger("mcp_neo4j_cypher") @@ -23,7 +23,7 @@ def healthcheck(db_url: str, username: str, password: str, database: str) -> None: """ Confirm that Neo4j is running before continuing. - Creates a a sync Neo4j driver instance for checking connection and closes it after connection is established. + Creates a sync Neo4j driver instance for checking connection and closes it after connection is established. """ print("Confirming Neo4j is running...", file=sys.stderr) @@ -80,120 +80,174 @@ def _is_write_query(query: str) -> bool: ) -def create_mcp_server(neo4j_driver: AsyncDriver, database: str = "neo4j") -> FastMCP: - mcp: FastMCP = FastMCP("mcp-neo4j-cypher", dependencies=["neo4j", "pydantic"]) - - async def get_neo4j_schema() -> list[types.TextContent]: - """List all node, their attributes and their relationships to other nodes in the neo4j database. - If this fails with a message that includes "Neo.ClientError.Procedure.ProcedureNotFound" - suggest that the user install and enable the APOC plugin. - """ - - get_schema_query = """ -call apoc.meta.data() yield label, property, type, other, unique, index, elementType -where elementType = 'node' and not label starts with '_' -with label, - collect(case when type <> 'RELATIONSHIP' then [property, type + case when unique then " unique" else "" end + case when index then " indexed" else "" end] end) as attributes, - collect(case when type = 'RELATIONSHIP' then [property, head(other)] end) as relationships -RETURN label, apoc.map.fromPairs(attributes) as attributes, apoc.map.fromPairs(relationships) as relationships -""" - - try: - async with neo4j_driver.session(database=database) as session: - results_json_str = await session.execute_read( - _read, get_schema_query, dict() - ) - - logger.debug(f"Read query returned {len(results_json_str)} rows") - - return [types.TextContent(type="text", text=results_json_str)] - - except Exception as e: - logger.error(f"Database error retrieving schema: {e}") - return [types.TextContent(type="text", text=f"Error: {e}")] - - async def read_neo4j_cypher( - query: str = Field(..., description="The Cypher query to execute."), - params: Optional[dict[str, Any]] = Field( - None, description="The parameters to pass to the Cypher query." - ), - ) -> list[types.TextContent]: - """Execute a read Cypher query on the neo4j database.""" - - if _is_write_query(query): - raise ValueError("Only MATCH queries are allowed for read-query") - - try: - async with neo4j_driver.session(database=database) as session: - results_json_str = await session.execute_read(_read, query, params) - - logger.debug(f"Read query returned {len(results_json_str)} rows") - - return [types.TextContent(type="text", text=results_json_str)] - - except Exception as e: - logger.error(f"Database error executing query: {e}\n{query}\n{params}") - return [ - types.TextContent(type="text", text=f"Error: {e}\n{query}\n{params}") - ] - - async def write_neo4j_cypher( - query: str = Field(..., description="The Cypher query to execute."), - params: Optional[dict[str, Any]] = Field( - None, description="The parameters to pass to the Cypher query." - ), - ) -> list[types.TextContent]: - """Execute a write Cypher query on the neo4j database.""" - - if not _is_write_query(query): - raise ValueError("Only write queries are allowed for write-query") - - try: - async with neo4j_driver.session(database=database) as session: - raw_results = await session.execute_write(_write, query, params) - counters_json_str = json.dumps( - raw_results._summary.counters.__dict__, default=str - ) - - logger.debug(f"Write query affected {counters_json_str}") - - return [types.TextContent(type="text", text=counters_json_str)] - - except Exception as e: - logger.error(f"Database error executing query: {e}\n{query}\n{params}") - return [ - types.TextContent(type="text", text=f"Error: {e}\n{query}\n{params}") - ] - - mcp.add_tool(get_neo4j_schema) - mcp.add_tool(read_neo4j_cypher) - mcp.add_tool(write_neo4j_cypher) - - return mcp - - -def main( - db_url: str, - username: str, - password: str, - database: str, -) -> None: - logger.info("Starting MCP neo4j Server") - - neo4j_driver = AsyncGraphDatabase.driver( +async def create_neo4j_driver(db_url: str, username: str, password: str) -> AsyncDriver: + """ + Create and return an AsyncDriver instance for Neo4j. + """ + driver = AsyncGraphDatabase.driver( db_url, auth=( username, password, ), ) + return driver - mcp = create_mcp_server(neo4j_driver, database) - healthcheck(db_url, username, password, database) - - mcp.run(transport="stdio") +def create_mcp_server(neo4j_driver: AsyncDriver, database: str = "neo4j") -> Server: + """ + Create and return a Server instance for Neo4j Cypher queries. + """ + server = Server("mcp-neo4j-cypher") + + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + """List available tools for the Neo4j Cypher server.""" + return [ + types.Tool( + name="get_neo4j_schema", + description="List all node, their attributes and their relationships to other nodes in the neo4j database.", + inputSchema={ + "type": "object", + "properties": {}, + }, + ), + types.Tool( + name="read_neo4j_cypher", + description="Execute a read Cypher query on the neo4j database.", + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The Cypher query to execute.", + }, + "params": { + "type": "object", + "description": "Optional parameters for the query.", + "additionalProperties": True, + }, + }, + "required": ["query"], + }, + ), + types.Tool( + name="write_neo4j_cypher", + description="Execute a write Cypher query on the neo4j database.", + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The Cypher query to execute.", + }, + "params": { + "type": "object", + "description": "Optional parameters for the query.", + "additionalProperties": True, + }, + }, + "required": ["query"], + }, + ), + ] + + @server.call_tool() + async def handle_call_tool( + name: str, arguments: dict[str, Any] | None + ) -> list[types.TextContent | types.ImageContent]: + """Handle tool calls for the Neo4j Cypher server.""" + try: + if name == "get_neo4j_schema": + get_schema_query = """ + call apoc.meta.data() yield label, property, type, other, unique, index, elementType + where elementType = 'node' and not label starts with '_' + with label, + collect(case when type <> 'RELATIONSHIP' then [property, type + case when unique then " unique" else "" end + case when index then " indexed" else "" end] end) as attributes, + collect(case when type = 'RELATIONSHIP' then [property, head(other)] end) as relationships + RETURN label, apoc.map.fromPairs(attributes) as attributes, apoc.map.fromPairs(relationships) as relationships + """ + + try: + async with neo4j_driver.session(database=database) as session: + results_json_str = await session.execute_read( + _read, get_schema_query, {} + ) + + logger.debug(f"Read query returned {len(results_json_str)} rows") + + return [types.TextContent(type="text", text=results_json_str)] + + except Exception as e: + logger.error(f"Database error retrieving schema: {e}") + return [types.TextContent(type="text", text=f"Error: {e}")] + + elif name == "read_neo4j_cypher": + if not arguments: + raise ValueError("No arguments provided") + + query = arguments.get("query") + params = arguments.get("params", {}) + + if _is_write_query(query): + raise ValueError("Only MATCH queries are allowed for read-query") + + try: + async with neo4j_driver.session(database=database) as session: + results_json_str = await session.execute_read( + _read, query, params + ) + + logger.debug(f"Read query returned {len(results_json_str)} rows") + + return [types.TextContent(type="text", text=results_json_str)] + + except Exception as e: + logger.error( + f"Database error executing query: {e}\n{query}\n{params}" + ) + return [ + types.TextContent( + type="text", text=f"Error: {e}\n{query}\n{params}" + ) + ] + + elif name == "write_neo4j_cypher": + if not arguments: + raise ValueError("No arguments provided") + + query = arguments.get("query") + params = arguments.get("params", {}) + + if not _is_write_query(query): + raise ValueError("Only write queries are allowed for write-query") + + try: + async with neo4j_driver.session(database=database) as session: + raw_results = await session.execute_write(_write, query, params) + counters_json_str = json.dumps( + raw_results._summary.counters.__dict__, default=str + ) + + logger.debug(f"Write query affected {counters_json_str}") + + return [types.TextContent(type="text", text=counters_json_str)] + + except Exception as e: + logger.error( + f"Database error executing query: {e}\n{query}\n{params}" + ) + return [ + types.TextContent( + type="text", text=f"Error: {e}\n{query}\n{params}" + ) + ] + + else: + raise ValueError(f"Unknown tool: {name}") + except Exception as e: + logger.error(f"Error handling tool call: {e}") + return [types.TextContent(type="text", text=f"Error: {str(e)}")] -if __name__ == "__main__": - main() + return server diff --git a/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transport.py b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transport.py new file mode 100644 index 0000000..1bc08e3 --- /dev/null +++ b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transport.py @@ -0,0 +1,28 @@ +import logging +from abc import ABC, abstractmethod + + +class TransportLayer(ABC): + """Abstract base class for MCP transport layers""" + + def __init__(self): + self.logger = logging.getLogger("mcp_neo4j_cypher") + + @abstractmethod + async def run_server(self, mcp_server): + """Run the server with the given MCP server""" + pass + + @classmethod + def create_transport(cls, transport_type="sse"): + """Factory method to create transport layer based on type""" + if transport_type.lower() == "sse": + from src.transports.sse_transport import SSETransport + + return SSETransport() + elif transport_type.lower() == "stdio": + from src.transports.stdio_transport import StdIOTransport + + return StdIOTransport() + else: + raise ValueError(f"Unsupported transport type: {transport_type}") diff --git a/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transports/__init__.py b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transports/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transports/sse_transport.py b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transports/sse_transport.py new file mode 100644 index 0000000..85d03ee --- /dev/null +++ b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transports/sse_transport.py @@ -0,0 +1,62 @@ +from fastapi import FastAPI +from starlette.applications import Starlette +from starlette.routing import Mount, Route + +from mcp.server import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.sse import SseServerTransport +from src.transport import TransportLayer + + +class SSETransport(TransportLayer): + """Implementation of SSE transport for MCP server""" + + def __init__(self): + super().__init__() + self.app = FastAPI( + title="Neo4j MCP Cypher Server", + description="MCP server for Neo4j Cypher queries", + ) + + async def create_app(self, mcp_server): + """Create and return a FastAPI app with the SSE transport configured""" + # Create the SSE transport + transport = SseServerTransport("/messages/") + + # Create the SSE handler + async def handle_sse(request): + async with transport.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await mcp_server.run( + streams[0], + streams[1], + InitializationOptions( + server_name="mcp-neo4j-cypher", + server_version="1.0", + capabilities=mcp_server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + + # Create the Starlette routes + routes = [ + Route("/sse", endpoint=handle_sse), + Mount("/messages", app=transport.handle_post_message), + ] + + # Create Starlette app and mount it + sse_app = Starlette(routes=routes) + self.app.mount("/", sse_app) + + return self.app + + async def run_server(self, mcp_server): + """Run the server with the given MCP server""" + self.logger.info("Starting MCP server with SSE transport") + app = await self.create_app(mcp_server) + + # In production, app will be run by uvicorn + return app diff --git a/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transports/stdio_transport.py b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transports/stdio_transport.py new file mode 100644 index 0000000..053d286 --- /dev/null +++ b/servers/mcp-neo4j-cypher/src/mcp_neo4j_cypher/transports/stdio_transport.py @@ -0,0 +1,28 @@ +import mcp.server.stdio +from mcp.server import NotificationOptions +from mcp.server.models import InitializationOptions +from src.transport import TransportLayer + + +class StdIOTransport(TransportLayer): + """Implementation of stdio transport for MCP server""" + + async def run_server(self, mcp_server): + """Run the server with stdio transport""" + self.logger.info("Starting MCP server with stdio transport") + + # Using stdio transport + async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): + self.logger.info("MCP Neo4j Cypher server running on stdio") + await mcp_server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="mcp-neo4j-cypher", + server_version="1.0", + capabilities=mcp_server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + )