From d5127dc78cf3e03f230274c255ffc364e7b1e226 Mon Sep 17 00:00:00 2001 From: Braxton Nunnally Date: Thu, 8 May 2025 14:48:46 -0400 Subject: [PATCH] Added sse support for ne04j-memory server --- servers/mcp-neo4j-memory/Dockerfile | 12 +- servers/mcp-neo4j-memory/pyproject.toml | 5 +- .../src/mcp_neo4j_memory/__init__.py | 6 +- .../src/mcp_neo4j_memory/app.py | 79 +++++ .../src/mcp_neo4j_memory/main.py | 55 +++ .../src/mcp_neo4j_memory/server.py | 324 +++++++++++------- .../src/mcp_neo4j_memory/transport.py | 28 ++ .../transports/sse_transport.py | 62 ++++ .../transports/stdio_transport.py | 29 ++ 9 files changed, 468 insertions(+), 132 deletions(-) create mode 100644 servers/mcp-neo4j-memory/src/mcp_neo4j_memory/app.py create mode 100644 servers/mcp-neo4j-memory/src/mcp_neo4j_memory/main.py create mode 100644 servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transport.py create mode 100644 servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/sse_transport.py create mode 100644 servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/stdio_transport.py diff --git a/servers/mcp-neo4j-memory/Dockerfile b/servers/mcp-neo4j-memory/Dockerfile index 29a9b4e..1d7a4bd 100644 --- a/servers/mcp-neo4j-memory/Dockerfile +++ b/servers/mcp-neo4j-memory/Dockerfile @@ -10,7 +10,7 @@ RUN pip install --no-cache-dir hatchling COPY pyproject.toml /app/ # Install runtime dependencies -RUN pip install --no-cache-dir mcp>=0.10.0 neo4j>=5.26.0 +RUN pip install --no-cache-dir mcp>=0.10.0 neo4j>=5.26.0 fastapi uvicorn # Copy the source code COPY src/ /app/src/ @@ -22,7 +22,13 @@ RUN pip install --no-cache-dir -e . # Environment variables for Neo4j connection ENV NEO4J_URL="bolt://host.docker.internal:7687" ENV NEO4J_USERNAME="neo4j" -ENV NEO4J_PASSWORD="password" +ENV NEO4J_PASSWORD="neo4j_password" + +# Set transport type (can be "sse" or "stdio") +ENV MCP_TRANSPORT="sse" + +# Expose port for SSE transport +EXPOSE 8090 # Command to run the server using the package entry point -CMD ["sh", "-c", "mcp-neo4j-memory --db-url ${NEO4J_URL} --username ${NEO4J_USERNAME} --password ${NEO4J_PASSWORD}"] \ No newline at end of file +CMD ["python", "-m", "src.mcp_neo4j_memory.main"] \ No newline at end of file diff --git a/servers/mcp-neo4j-memory/pyproject.toml b/servers/mcp-neo4j-memory/pyproject.toml index c5ed6f7..283d90f 100644 --- a/servers/mcp-neo4j-memory/pyproject.toml +++ b/servers/mcp-neo4j-memory/pyproject.toml @@ -1,12 +1,13 @@ [project] name = "mcp-neo4j-memory" -version = "0.1.3" +version = "0.1.1" description = "MCP Neo4j Knowledge Graph Memory Server" readme = "README.md" requires-python = ">=3.10" dependencies = [ "mcp>=0.10.0", "neo4j>=5.26.0", + "fastapi" ] [build-system] @@ -21,7 +22,7 @@ dev-dependencies = [ ] [project.scripts] -mcp-neo4j-memory = "mcp_neo4j_memory:main" +mcp-neo4j-memory = "mcp_neo4j_memory.main:main" [tool.pytest.ini_options] pythonpath = [ diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/__init__.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/__init__.py index 123507a..39d6ae4 100644 --- a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/__init__.py +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/__init__.py @@ -8,13 +8,13 @@ def main(): """Main entry point for the package.""" parser = argparse.ArgumentParser(description='Neo4j Cypher MCP Server') parser.add_argument('--db-url', - default=os.getenv("NEO4J_URL", "bolt://localhost:7687"), + default="bolt://localhost:7687", help='Neo4j connection URL') parser.add_argument('--username', - default=os.getenv("NEO4J_USERNAME", "neo4j"), + default="neo4j", help='Neo4j username') parser.add_argument('--password', - default=os.getenv("NEO4J_PASSWORD", "password"), + default="password", help='Neo4j password') args = parser.parse_args() diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/app.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/app.py new file mode 100644 index 0000000..ff2cf04 --- /dev/null +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/app.py @@ -0,0 +1,79 @@ +import logging +import os + +from fastapi import FastAPI + +from src.mcp_neo4j_memory.server import create_mcp_server +from src.mcp_neo4j_memory.transport import TransportLayer + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("mcp_neo4j_memory") + + +async def start_server(transport_type="sse"): + """Start the MCP server with the specified transport""" + # Get environment variables for Neo4j connection + neo4j_uri = os.environ.get( + "NEO4J_URI", os.environ.get("NEO4J_URL", "bolt://host.docker.internal:7687") + ) + neo4j_user = os.environ.get("NEO4J_USERNAME", os.environ.get("NEO4J_USER", "neo4j")) + neo4j_password = os.environ.get("NEO4J_PASSWORD", "neo4j_password") + + try: + # Create the MCP server + logger.info(f"Creating MCP server with transport type: {transport_type}") + mcp_server = await create_mcp_server(neo4j_uri, neo4j_user, neo4j_password) + + # Create transport and run server + transport = TransportLayer.create_transport(transport_type) + return await transport.run_server(mcp_server) + except Exception as e: + logger.error(f"Failed to start server: {e}") + raise e + + +# For FastAPI app +app = FastAPI() + + +@app.on_event("startup") +async def startup_event(): + global app + transport_type = os.environ.get("MCP_TRANSPORT", "sse") + if transport_type.lower() == "sse": + # If using SSE, mount the SSE app to this app + sse_app = await start_server("sse") + + # Add health check endpoint + @app.get("/health") + def health_check(): + return {"status": "ok"} + + # Add root endpoint + @app.get("/") + def read_root(): + return { + "message": "Neo4j MCP Memory Server", + "transport": "SSE", + "sse_endpoint": "/sse", + } + + # Mount all routes from the SSE app + for route in sse_app.routes: + app.routes.append(route) + else: + # For stdio, just inform that this app should not be used + @app.get("/") + def read_root(): + return { + "error": "This server is configured to use stdio transport, not HTTP/SSE", + "message": "Please run this server directly from the command line", + } + + @app.get("/health") + def health_check(): + return { + "status": "error", + "message": "Server is configured for stdio transport", + } diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/main.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/main.py new file mode 100644 index 0000000..2f56648 --- /dev/null +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/main.py @@ -0,0 +1,55 @@ +import argparse +import asyncio +import os + + +async def start_stdio_server(): + """Start the server with stdio transport""" + from src.mcp_neo4j_memory.app import start_server + + await start_server("stdio") + + +def main(): + """Command-line entry point for running the MCP server""" + parser = argparse.ArgumentParser(description="MCP Neo4j Memory Server") + parser.add_argument( + "--transport", + choices=["sse", "stdio"], + default="sse", + help="Transport type (sse or stdio)", + ) + parser.add_argument("--db-url", dest="neo4j_uri", help="Neo4j database URL") + parser.add_argument("--username", dest="neo4j_username", help="Neo4j username") + parser.add_argument("--password", dest="neo4j_password", help="Neo4j password") + parser.add_argument( + "--port", + type=int, + default=8090, + help="Port for HTTP server (when using SSE transport)", + ) + + args = parser.parse_args() + + # Set environment variables from arguments + if args.neo4j_uri: + os.environ["NEO4J_URI"] = args.neo4j_uri + if args.neo4j_username: + os.environ["NEO4J_USERNAME"] = args.neo4j_username + if args.neo4j_password: + os.environ["NEO4J_PASSWORD"] = args.neo4j_password + + os.environ["MCP_TRANSPORT"] = args.transport + + if args.transport == "stdio": + # Run with stdio transport + asyncio.run(start_stdio_server()) + else: + # Run with SSE transport via FastAPI/Uvicorn + import uvicorn + + uvicorn.run("src.mcp_neo4j_memory.app:app", host="0.0.0.0", port=args.port, reload=False) + + +if __name__ == "__main__": + main() diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py index 947a263..78dd483 100644 --- a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py @@ -1,45 +1,47 @@ -import os -import logging import json -from typing import Any, Dict, List, Optional -from contextlib import asynccontextmanager +import logging +from typing import Any, Dict, List import neo4j from neo4j import GraphDatabase from pydantic import BaseModel import mcp.types as types -from mcp.server import NotificationOptions, Server -from mcp.server.models import InitializationOptions -import mcp.server.stdio +from mcp.server import Server # Set up logging -logger = logging.getLogger('mcp_neo4j_memory') +logger = logging.getLogger("mcp_neo4j_memory") logger.setLevel(logging.INFO) + # Models for our knowledge graph class Entity(BaseModel): name: str type: str observations: List[str] + class Relation(BaseModel): source: str target: str relationType: str + class KnowledgeGraph(BaseModel): entities: List[Entity] relations: List[Relation] + class ObservationAddition(BaseModel): entityName: str contents: List[str] + class ObservationDeletion(BaseModel): entityName: str observations: List[str] + class Neo4jMemory: def __init__(self, neo4j_driver): self.neo4j_driver = neo4j_driver @@ -47,7 +49,7 @@ def __init__(self, neo4j_driver): def create_fulltext_index(self): try: - # TODO , + # TODO , query = """ CREATE FULLTEXT INDEX search IF NOT EXISTS FOR (m:Memory) ON EACH [m.name, m.type, m.observations]; """ @@ -64,47 +66,49 @@ async def load_graph(self, filter_query="*"): CALL db.index.fulltext.queryNodes('search', $filter) yield node as entity, score OPTIONAL MATCH (entity)-[r]-(other) RETURN collect(distinct { - name: entity.name, - type: entity.type, + name: entity.name, + type: entity.type, observations: entity.observations }) as nodes, collect(distinct { - source: startNode(r).name, - target: endNode(r).name, + source: startNode(r).name, + target: endNode(r).name, relationType: type(r) }) as relations """ - + result = self.neo4j_driver.execute_query(query, {"filter": filter_query}) - + if not result.records: return KnowledgeGraph(entities=[], relations=[]) - + record = result.records[0] - nodes = record.get('nodes') - rels = record.get('relations') - + nodes = record.get("nodes") + rels = record.get("relations") + entities = [ Entity( - name=node.get('name'), - type=node.get('type'), - observations=node.get('observations', []) + name=node.get("name"), + type=node.get("type"), + observations=node.get("observations", []), ) - for node in nodes if node.get('name') + for node in nodes + if node.get("name") ] - + relations = [ Relation( - source=rel.get('source'), - target=rel.get('target'), - relationType=rel.get('relationType') + source=rel.get("source"), + target=rel.get("target"), + relationType=rel.get("relationType"), ) - for rel in rels if rel.get('source') and rel.get('target') and rel.get('relationType') + for rel in rels + if rel.get("source") and rel.get("target") and rel.get("relationType") ] - + logger.debug(f"Loaded entities: {entities}") logger.debug(f"Loaded relations: {relations}") - + return KnowledgeGraph(entities=entities, relations=relations) async def create_entities(self, entities: List[Entity]) -> List[Entity]: @@ -114,7 +118,7 @@ async def create_entities(self, entities: List[Entity]) -> List[Entity]: SET e += entity {.type, .observations} SET e:$(entity.type) """ - + entities_data = [entity.model_dump() for entity in entities] self.neo4j_driver.execute_query(query, {"entities": entities_data}) return entities @@ -128,29 +132,32 @@ async def create_relations(self, relations: List[Relation]) -> List[Relation]: AND to.name = relation.target MERGE (from)-[r:$(relation.relationType)]->(to) """ - + self.neo4j_driver.execute_query( - query, - {"relations": [relation.model_dump() for relation in relations]} + query, {"relations": [relation.model_dump() for relation in relations]} ) - + return relations - async def add_observations(self, observations: List[ObservationAddition]) -> List[Dict[str, Any]]: + async def add_observations( + self, observations: List[ObservationAddition] + ) -> List[Dict[str, Any]]: query = """ - UNWIND $observations as obs + UNWIND $observations as obs MATCH (e:Memory { name: obs.entityName }) WITH e, [o in obs.contents WHERE NOT o IN e.observations] as new SET e.observations = coalesce(e.observations,[]) + new RETURN e.name as name, new """ - + result = self.neo4j_driver.execute_query( - query, - {"observations": [obs.model_dump() for obs in observations]} + query, {"observations": [obs.model_dump() for obs in observations]} ) - results = [{"entityName": record.get("name"), "addedObservations": record.get("new")} for record in result.records] + results = [ + {"entityName": record.get("name"), "addedObservations": record.get("new")} + for record in result.records + ] return results async def delete_entities(self, entity_names: List[str]) -> None: @@ -159,20 +166,17 @@ async def delete_entities(self, entity_names: List[str]) -> None: MATCH (e:Memory { name: name }) DETACH DELETE e """ - + self.neo4j_driver.execute_query(query, {"entities": entity_names}) async def delete_observations(self, deletions: List[ObservationDeletion]) -> None: query = """ - UNWIND $deletions as d + UNWIND $deletions as d MATCH (e:Memory { name: d.entityName }) SET e.observations = [o in coalesce(e.observations,[]) WHERE NOT o IN d.observations] """ self.neo4j_driver.execute_query( - query, - { - "deletions": [deletion.model_dump() for deletion in deletions] - } + query, {"deletions": [deletion.model_dump() for deletion in deletions]} ) async def delete_relations(self, relations: List[Relation]) -> None: @@ -184,8 +188,7 @@ async def delete_relations(self, relations: List[Relation]) -> None: DELETE r """ self.neo4j_driver.execute_query( - query, - {"relations": [relation.model_dump() for relation in relations]} + query, {"relations": [relation.model_dump() for relation in relations]} ) async def read_graph(self) -> KnowledgeGraph: @@ -197,26 +200,24 @@ async def search_nodes(self, query: str) -> KnowledgeGraph: async def find_nodes(self, names: List[str]) -> KnowledgeGraph: return await self.load_graph("name: (" + " ".join(names) + ")") -async def main(neo4j_uri: str, neo4j_user: str, neo4j_password: str): - logger.info(f"Connecting to neo4j MCP Server with DB URL: {neo4j_uri}") + +async def create_mcp_server(neo4j_uri, neo4j_user, neo4j_password): + logger.info(f"Connecting to Neo4j at {neo4j_uri}") # Connect to Neo4j - neo4j_driver = GraphDatabase.driver( - neo4j_uri, - auth=(neo4j_user, neo4j_password) - ) - + neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password)) + # Verify connection try: neo4j_driver.verify_connectivity() logger.info(f"Connected to Neo4j at {neo4j_uri}") except Exception as e: logger.error(f"Failed to connect to Neo4j: {e}") - exit(1) + raise e # Initialize memory memory = Neo4jMemory(neo4j_driver) - + # Create MCP server server = Server("mcp-neo4j-memory") @@ -235,12 +236,18 @@ async def handle_list_tools() -> List[types.Tool]: "items": { "type": "object", "properties": { - "name": {"type": "string", "description": "The name of the entity"}, - "type": {"type": "string", "description": "The type of the entity"}, + "name": { + "type": "string", + "description": "The name of the entity", + }, + "type": { + "type": "string", + "description": "The type of the entity", + }, "observations": { "type": "array", "items": {"type": "string"}, - "description": "An array of observation contents associated with the entity" + "description": "An array of observation contents associated with the entity", }, }, "required": ["name", "type", "observations"], @@ -261,9 +268,18 @@ async def handle_list_tools() -> List[types.Tool]: "items": { "type": "object", "properties": { - "source": {"type": "string", "description": "The name of the entity where the relation starts"}, - "target": {"type": "string", "description": "The name of the entity where the relation ends"}, - "relationType": {"type": "string", "description": "The type of the relation"}, + "source": { + "type": "string", + "description": "The name of the entity where the relation starts", + }, + "target": { + "type": "string", + "description": "The name of the entity where the relation ends", + }, + "relationType": { + "type": "string", + "description": "The type of the relation", + }, }, "required": ["source", "target", "relationType"], }, @@ -283,11 +299,14 @@ async def handle_list_tools() -> List[types.Tool]: "items": { "type": "object", "properties": { - "entityName": {"type": "string", "description": "The name of the entity to add the observations to"}, + "entityName": { + "type": "string", + "description": "The name of the entity to add the observations to", + }, "contents": { "type": "array", "items": {"type": "string"}, - "description": "An array of observation contents to add" + "description": "An array of observation contents to add", }, }, "required": ["entityName", "contents"], @@ -306,7 +325,7 @@ async def handle_list_tools() -> List[types.Tool]: "entityNames": { "type": "array", "items": {"type": "string"}, - "description": "An array of entity names to delete" + "description": "An array of entity names to delete", }, }, "required": ["entityNames"], @@ -323,11 +342,14 @@ async def handle_list_tools() -> List[types.Tool]: "items": { "type": "object", "properties": { - "entityName": {"type": "string", "description": "The name of the entity containing the observations"}, + "entityName": { + "type": "string", + "description": "The name of the entity containing the observations", + }, "observations": { "type": "array", "items": {"type": "string"}, - "description": "An array of observations to delete" + "description": "An array of observations to delete", }, }, "required": ["entityName", "observations"], @@ -348,13 +370,22 @@ async def handle_list_tools() -> List[types.Tool]: "items": { "type": "object", "properties": { - "source": {"type": "string", "description": "The name of the entity where the relation starts"}, - "target": {"type": "string", "description": "The name of the entity where the relation ends"}, - "relationType": {"type": "string", "description": "The type of the relation"}, + "source": { + "type": "string", + "description": "The name of the entity where the relation starts", + }, + "target": { + "type": "string", + "description": "The name of the entity where the relation ends", + }, + "relationType": { + "type": "string", + "description": "The type of the relation", + }, }, "required": ["source", "target", "relationType"], }, - "description": "An array of relations to delete" + "description": "An array of relations to delete", }, }, "required": ["relations"], @@ -374,14 +405,17 @@ async def handle_list_tools() -> List[types.Tool]: inputSchema={ "type": "object", "properties": { - "query": {"type": "string", "description": "The search query to match against entity names, types, and observation content"}, + "query": { + "type": "string", + "description": "The search query to match against entity names, types, and observation content", + }, }, "required": ["query"], }, ), types.Tool( name="find_nodes", - description="Find specific nodes in the knowledge graph by their names", + description="Open specific nodes in the knowledge graph by their names", inputSchema={ "type": "object", "properties": { @@ -395,18 +429,17 @@ async def handle_list_tools() -> List[types.Tool]: }, ), types.Tool( - name="open_nodes", - description="Open specific nodes in the knowledge graph by their names", + name="entity_exists", + description="Check if an entity exists in the knowledge graph by name", inputSchema={ "type": "object", "properties": { - "names": { - "type": "array", - "items": {"type": "string"}, - "description": "An array of entity names to retrieve", + "name": { + "type": "string", + "description": "The name of the entity to check", }, }, - "required": ["names"], + "required": ["name"], }, ), ] @@ -416,69 +449,112 @@ async def handle_call_tool( name: str, arguments: Dict[str, Any] | None ) -> List[types.TextContent | types.ImageContent]: try: - if name == "read_graph": - result = await memory.read_graph() - return [types.TextContent(type="text", text=json.dumps(result.model_dump(), indent=2))] - if not arguments: raise ValueError(f"No arguments provided for tool: {name}") if name == "create_entities": - entities = [Entity(**entity) for entity in arguments.get("entities", [])] + entities = [ + Entity(**entity) for entity in arguments.get("entities", []) + ] result = await memory.create_entities(entities) - return [types.TextContent(type="text", text=json.dumps([e.model_dump() for e in result], indent=2))] - + return [ + types.TextContent( + type="text", + text=json.dumps([e.model_dump() for e in result], indent=2), + ) + ] + elif name == "create_relations": - relations = [Relation(**relation) for relation in arguments.get("relations", [])] + relations = [ + Relation(**relation) for relation in arguments.get("relations", []) + ] result = await memory.create_relations(relations) - return [types.TextContent(type="text", text=json.dumps([r.model_dump() for r in result], indent=2))] - + return [ + types.TextContent( + type="text", + text=json.dumps([r.model_dump() for r in result], indent=2), + ) + ] + elif name == "add_observations": - observations = [ObservationAddition(**obs) for obs in arguments.get("observations", [])] + observations = [ + ObservationAddition(**obs) + for obs in arguments.get("observations", []) + ] result = await memory.add_observations(observations) - return [types.TextContent(type="text", text=json.dumps(result, indent=2))] - + return [ + types.TextContent(type="text", text=json.dumps(result, indent=2)) + ] + elif name == "delete_entities": await memory.delete_entities(arguments.get("entityNames", [])) - return [types.TextContent(type="text", text="Entities deleted successfully")] - + return [ + types.TextContent(type="text", text="Entities deleted successfully") + ] + elif name == "delete_observations": - deletions = [ObservationDeletion(**deletion) for deletion in arguments.get("deletions", [])] + deletions = [ + ObservationDeletion(**deletion) + for deletion in arguments.get("deletions", []) + ] await memory.delete_observations(deletions) - return [types.TextContent(type="text", text="Observations deleted successfully")] - + return [ + types.TextContent( + type="text", text="Observations deleted successfully" + ) + ] + elif name == "delete_relations": - relations = [Relation(**relation) for relation in arguments.get("relations", [])] + relations = [ + Relation(**relation) for relation in arguments.get("relations", []) + ] await memory.delete_relations(relations) - return [types.TextContent(type="text", text="Relations deleted successfully")] - + return [ + types.TextContent( + type="text", text="Relations deleted successfully" + ) + ] + + elif name == "read_graph": + result = await memory.read_graph() + return [ + types.TextContent( + type="text", text=json.dumps(result.model_dump(), indent=2) + ) + ] + elif name == "search_nodes": result = await memory.search_nodes(arguments.get("query", "")) - return [types.TextContent(type="text", text=json.dumps(result.model_dump(), indent=2))] - - elif name == "find_nodes" or name == "open_nodes": + return [ + types.TextContent( + type="text", text=json.dumps(result.model_dump(), indent=2) + ) + ] + + elif name == "find_nodes": result = await memory.find_nodes(arguments.get("names", [])) - return [types.TextContent(type="text", text=json.dumps(result.model_dump(), indent=2))] - + return [ + types.TextContent( + type="text", text=json.dumps(result.model_dump(), indent=2) + ) + ] + elif name == "entity_exists": + entity_name = arguments.get("name", "") + query = """ + MATCH (e:Memory {name: $name}) + RETURN COUNT(e) > 0 AS exists + """ + + result = neo4j_driver.execute_query(query, {"name": entity_name}) + exists = result.records[0].get("exists") if result.records else False + return [ + types.TextContent(type="text", text=json.dumps({"exists": exists})) + ] else: raise ValueError(f"Unknown tool: {name}") - + except Exception as e: logger.error(f"Error handling tool call: {e}") return [types.TextContent(type="text", text=f"Error: {str(e)}")] - # Start the server - async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): - logger.info("MCP Knowledge Graph Memory using Neo4j running on stdio") - await server.run( - read_stream, - write_stream, - InitializationOptions( - server_name="mcp-neo4j-memory", - server_version="1.1", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) + return server diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transport.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transport.py new file mode 100644 index 0000000..9f68df8 --- /dev/null +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transport.py @@ -0,0 +1,28 @@ +import logging +from abc import ABC, abstractmethod + + +class TransportLayer(ABC): + """Abstract base class for MCP transport layers""" + + def __init__(self): + self.logger = logging.getLogger("mcp_neo4j_memory") + + @abstractmethod + async def run_server(self, mcp_server): + """Run the server with the given MCP server""" + pass + + @classmethod + def create_transport(cls, transport_type="sse"): + """Factory method to create transport layer based on type""" + if transport_type.lower() == "sse": + from src.mcp_neo4j_memory.transports.sse_transport import SSETransport + + return SSETransport() + elif transport_type.lower() == "stdio": + from src.mcp_neo4j_memory.transports.stdio_transport import StdIOTransport + + return StdIOTransport() + else: + raise ValueError(f"Unsupported transport type: {transport_type}") diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/sse_transport.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/sse_transport.py new file mode 100644 index 0000000..8f2f9c4 --- /dev/null +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/sse_transport.py @@ -0,0 +1,62 @@ +from fastapi import FastAPI +from src.mcp_neo4j_memory.transport import TransportLayer +from starlette.applications import Starlette +from starlette.routing import Mount, Route + +from mcp.server import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.sse import SseServerTransport + + +class SSETransport(TransportLayer): + """Implementation of SSE transport for MCP server""" + + def __init__(self): + super().__init__() + self.app = FastAPI( + title="Neo4j MCP Memory Server", + description="MCP server for Neo4j knowledge graph memory", + ) + + async def create_app(self, mcp_server): + """Create and return a FastAPI app with the SSE transport configured""" + # Create the SSE transport + transport = SseServerTransport("/messages/") + + # Create the SSE handler + async def handle_sse(request): + async with transport.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await mcp_server.run( + streams[0], + streams[1], + InitializationOptions( + server_name="mcp-neo4j-memory", + server_version="1.1", + capabilities=mcp_server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + + # Create the Starlette routes + routes = [ + Route("/sse", endpoint=handle_sse), + Mount("/messages", app=transport.handle_post_message), + ] + + # Create Starlette app and mount it + sse_app = Starlette(routes=routes) + self.app.mount("/", sse_app) + + return self.app + + async def run_server(self, mcp_server): + """Run the server with the given MCP server""" + self.logger.info("Starting MCP server with SSE transport") + app = await self.create_app(mcp_server) + + # In production, app will be run by uvicorn + return app diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/stdio_transport.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/stdio_transport.py new file mode 100644 index 0000000..34a1810 --- /dev/null +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/transports/stdio_transport.py @@ -0,0 +1,29 @@ +from mcp_neo4j_memory.transport import TransportLayer + +import mcp.server.stdio +from mcp.server import NotificationOptions +from mcp.server.models import InitializationOptions + + +class StdIOTransport(TransportLayer): + """Implementation of stdio transport for MCP server""" + + async def run_server(self, mcp_server): + """Run the server with stdio transport""" + self.logger.info("Starting MCP server with stdio transport") + + # Using the original implementation from server.py + async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): + self.logger.info("MCP Knowledge Graph Memory using Neo4j running on stdio") + await mcp_server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="mcp-neo4j-memory", + server_version="1.1", + capabilities=mcp_server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + )