diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py index b397849..f59437a 100644 --- a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py @@ -47,7 +47,7 @@ def __init__(self, neo4j_driver): def create_fulltext_index(self): try: - # TODO , + # TODO , query = """ CREATE FULLTEXT INDEX search IF NOT EXISTS FOR (m:Memory) ON EACH [m.name, m.type, m.observations]; """ @@ -64,26 +64,26 @@ async def load_graph(self, filter_query="*"): CALL db.index.fulltext.queryNodes('search', $filter) yield node as entity, score OPTIONAL MATCH (entity)-[r]-(other) RETURN collect(distinct { - name: entity.name, - type: entity.type, + name: entity.name, + type: entity.type, observations: entity.observations }) as nodes, collect(distinct { - source: startNode(r).name, - target: endNode(r).name, + source: startNode(r).name, + target: endNode(r).name, relationType: type(r) }) as relations """ - + result = self.neo4j_driver.execute_query(query, {"filter": filter_query}) - + if not result.records: return KnowledgeGraph(entities=[], relations=[]) - + record = result.records[0] nodes = record.get('nodes') rels = record.get('relations') - + entities = [ Entity( name=node.get('name'), @@ -92,7 +92,7 @@ async def load_graph(self, filter_query="*"): ) for node in nodes if node.get('name') ] - + relations = [ Relation( source=rel.get('source'), @@ -101,10 +101,10 @@ async def load_graph(self, filter_query="*"): ) for rel in rels if rel.get('source') and rel.get('target') and rel.get('relationType') ] - + logger.debug(f"Loaded entities: {entities}") logger.debug(f"Loaded relations: {relations}") - + return KnowledgeGraph(entities=entities, relations=relations) async def create_entities(self, entities: List[Entity]) -> List[Entity]: @@ -114,39 +114,37 @@ async def create_entities(self, entities: List[Entity]) -> List[Entity]: SET e += entity {.type, .observations} SET e:$(entity.type) """ - + entities_data = [entity.model_dump() for entity in entities] self.neo4j_driver.execute_query(query, {"entities": entities_data}) return entities async def create_relations(self, relations: List[Relation]) -> List[Relation]: - for relation in relations: - query = """ - UNWIND $relations as relation - MATCH (from:Memory),(to:Memory) - WHERE from.name = relation.source - AND to.name = relation.target - MERGE (from)-[r:$(relation.relationType)]->(to) - """ - - self.neo4j_driver.execute_query( - query, - {"relations": [relation.model_dump() for relation in relations]} - ) - + query = """ + UNWIND $relations AS relation + MATCH (from:Memory { name: relation.source }) + MATCH (to:Memory { name: relation.target }) + CALL apoc.merge.relationship(from, relation.relationType, {}, {}, to) YIELD rel + RETURN rel + """ + self.neo4j_driver.execute_query( + query, + {"relations": [rel.model_dump() for rel in relations]} + ) + return relations async def add_observations(self, observations: List[ObservationAddition]) -> List[Dict[str, Any]]: query = """ - UNWIND $observations as obs + UNWIND $observations as obs MATCH (e:Memory { name: obs.entityName }) WITH e, [o in obs.contents WHERE NOT o IN e.observations] as new SET e.observations = coalesce(e.observations,[]) + new RETURN e.name as name, new """ - + result = self.neo4j_driver.execute_query( - query, + query, {"observations": [obs.model_dump() for obs in observations]} ) @@ -159,17 +157,17 @@ async def delete_entities(self, entity_names: List[str]) -> None: MATCH (e:Memory { name: name }) DETACH DELETE e """ - + self.neo4j_driver.execute_query(query, {"entities": entity_names}) async def delete_observations(self, deletions: List[ObservationDeletion]) -> None: query = """ - UNWIND $deletions as d + UNWIND $deletions as d MATCH (e:Memory { name: d.entityName }) SET e.observations = [o in coalesce(e.observations,[]) WHERE NOT o IN d.observations] """ self.neo4j_driver.execute_query( - query, + query, { "deletions": [deletion.model_dump() for deletion in deletions] } @@ -184,7 +182,7 @@ async def delete_relations(self, relations: List[Relation]) -> None: DELETE r """ self.neo4j_driver.execute_query( - query, + query, {"relations": [relation.model_dump() for relation in relations]} ) @@ -205,7 +203,7 @@ async def main(neo4j_uri: str, neo4j_user: str, neo4j_password: str): neo4j_uri, auth=(neo4j_user, neo4j_password) ) - + # Verify connection try: neo4j_driver.verify_connectivity() @@ -216,7 +214,7 @@ async def main(neo4j_uri: str, neo4j_user: str, neo4j_password: str): # Initialize memory memory = Neo4jMemory(neo4j_driver) - + # Create MCP server server = Server("mcp-neo4j-memory") @@ -399,55 +397,63 @@ async def handle_list_tools() -> List[types.Tool]: @server.call_tool() async def handle_call_tool( name: str, arguments: Dict[str, Any] | None - ) -> List[types.TextContent | types.ImageContent]: + ) -> List[types.TextContent]: try: - if not arguments: - raise ValueError(f"No arguments provided for tool: {name}") - - if name == "create_entities": - entities = [Entity(**entity) for entity in arguments.get("entities", [])] - result = await memory.create_entities(entities) - return [types.TextContent(type="text", text=json.dumps([e.model_dump() for e in result], indent=2))] - - elif name == "create_relations": - relations = [Relation(**relation) for relation in arguments.get("relations", [])] - result = await memory.create_relations(relations) - return [types.TextContent(type="text", text=json.dumps([r.model_dump() for r in result], indent=2))] - - elif name == "add_observations": - observations = [ObservationAddition(**obs) for obs in arguments.get("observations", [])] - result = await memory.add_observations(observations) - return [types.TextContent(type="text", text=json.dumps(result, indent=2))] - - elif name == "delete_entities": - await memory.delete_entities(arguments.get("entityNames", [])) - return [types.TextContent(type="text", text="Entities deleted successfully")] - - elif name == "delete_observations": - deletions = [ObservationDeletion(**deletion) for deletion in arguments.get("deletions", [])] - await memory.delete_observations(deletions) - return [types.TextContent(type="text", text="Observations deleted successfully")] - - elif name == "delete_relations": - relations = [Relation(**relation) for relation in arguments.get("relations", [])] - await memory.delete_relations(relations) - return [types.TextContent(type="text", text="Relations deleted successfully")] - - elif name == "read_graph": - result = await memory.read_graph() - return [types.TextContent(type="text", text=json.dumps(result.model_dump(), indent=2))] - - elif name == "search_nodes": - result = await memory.search_nodes(arguments.get("query", "")) - return [types.TextContent(type="text", text=json.dumps(result.model_dump(), indent=2))] - - elif name == "find_nodes": - result = await memory.find_nodes(arguments.get("names", [])) - return [types.TextContent(type="text", text=json.dumps(result.model_dump(), indent=2))] - - else: + # Prepare args for handler + args = arguments or {} + + # Dispatch table: tool name -> (method, arg_key, model_cls, formatter) + dispatch = { + "create_entities": ( + memory.create_entities, "entities", Entity, + lambda res: [types.TextContent(type="text", text=json.dumps([e.model_dump() for e in res], indent=2))] + ), + "create_relations": ( + memory.create_relations, "relations", Relation, + lambda res: [types.TextContent(type="text", text=json.dumps([r.model_dump() for r in res], indent=2))] + ), + "add_observations": ( + memory.add_observations, "observations", ObservationAddition, + lambda res: [types.TextContent(type="text", text=json.dumps(res, indent=2))] + ), + "delete_entities": ( + memory.delete_entities, "entityNames", None, + lambda _: [types.TextContent(type="text", text="Entities deleted successfully")] + ), + "delete_observations": ( + memory.delete_observations, "deletions", ObservationDeletion, + lambda _: [types.TextContent(type="text", text="Observations deleted successfully")] + ), + "delete_relations": ( + memory.delete_relations, "relations", Relation, + lambda _: [types.TextContent(type="text", text="Relations deleted successfully")] + ), + "read_graph": ( + memory.read_graph, None, None, + lambda res: [types.TextContent(type="text", text=json.dumps(res.model_dump(), indent=2))] + ), + "search_nodes": ( + memory.search_nodes, "query", None, + lambda res: [types.TextContent(type="text", text=json.dumps(res.model_dump(), indent=2))] + ), + "find_nodes": ( + memory.find_nodes, "names", None, + lambda res: [types.TextContent(type="text", text=json.dumps(res.model_dump(), indent=2))] + ), + } + if name not in dispatch: raise ValueError(f"Unknown tool: {name}") - + method, arg_key, model_cls, formatter = dispatch[name] + # Prepare parameters for memory method + if arg_key: + raw = args.get(arg_key, []) + param = [model_cls(**item) for item in raw] if model_cls else raw + result = await method(param) + else: + result = await method() + # Format and return + return formatter(result) + except Exception as e: logger.error(f"Error handling tool call: {e}") return [types.TextContent(type="text", text=f"Error: {str(e)}")] diff --git a/servers/mcp-neo4j-memory/tests/test_neo4j_memory_integration.py b/servers/mcp-neo4j-memory/tests/test_neo4j_memory_integration.py index 2ac9b6e..4b218c8 100644 --- a/servers/mcp-neo4j-memory/tests/test_neo4j_memory_integration.py +++ b/servers/mcp-neo4j-memory/tests/test_neo4j_memory_integration.py @@ -10,20 +10,20 @@ def neo4j_driver(): uri = os.environ.get("NEO4J_URI", "neo4j://localhost:7687") user = os.environ.get("NEO4J_USERNAME", "neo4j") password = os.environ.get("NEO4J_PASSWORD", "password") - + driver = GraphDatabase.driver(uri, auth=(user, password)) - + # Verify connection try: driver.verify_connectivity() except Exception as e: pytest.skip(f"Could not connect to Neo4j: {e}") - + yield driver - + # Clean up test data after tests driver.execute_query("MATCH (n:Memory) DETACH DELETE n") - + driver.close() @pytest.fixture(scope="function") @@ -41,13 +41,13 @@ async def test_create_and_read_entities(memory): # Create entities in the graph created_entities = await memory.create_entities(test_entities) assert len(created_entities) == 2 - + # Read the graph graph = await memory.read_graph() - + # Verify entities were created assert len(graph.entities) == 2 - + # Check if entities have correct data entities_by_name = {entity.name: entity for entity in graph.entities} assert "Alice" in entities_by_name @@ -64,19 +64,19 @@ async def test_create_and_read_relations(memory): Entity(name="Bob", type="Person", observations=[]) ] await memory.create_entities(test_entities) - + # Create test relation test_relations = [ Relation(source="Alice", target="Bob", relationType="KNOWS") ] - + # Create relation in the graph created_relations = await memory.create_relations(test_relations) assert len(created_relations) == 1 - + # Read the graph graph = await memory.read_graph() - + # Verify relation was created assert len(graph.relations) == 1 relation = graph.relations[0] @@ -89,22 +89,22 @@ async def test_add_observations(memory): # Create test entity test_entity = Entity(name="Charlie", type="Person", observations=["Initial observation"]) await memory.create_entities([test_entity]) - + # Add observations observation_additions = [ ObservationAddition(entityName="Charlie", contents=["New observation 1", "New observation 2"]) ] - + result = await memory.add_observations(observation_additions) assert len(result) == 1 - + # Read the graph graph = await memory.read_graph() - + # Find Charlie charlie = next((e for e in graph.entities if e.name == "Charlie"), None) assert charlie is not None - + # Verify observations were added assert "Initial observation" in charlie.observations assert "New observation 1" in charlie.observations @@ -114,26 +114,26 @@ async def test_add_observations(memory): async def test_delete_observations(memory): # Create test entity with observations test_entity = Entity( - name="Dave", - type="Person", + name="Dave", + type="Person", observations=["Observation 1", "Observation 2", "Observation 3"] ) await memory.create_entities([test_entity]) - + # Delete specific observations observation_deletions = [ ObservationDeletion(entityName="Dave", observations=["Observation 2"]) ] - + await memory.delete_observations(observation_deletions) - + # Read the graph graph = await memory.read_graph() - + # Find Dave dave = next((e for e in graph.entities if e.name == "Dave"), None) assert dave is not None - + # Verify observation was deleted assert "Observation 1" in dave.observations assert "Observation 2" not in dave.observations @@ -147,13 +147,13 @@ async def test_delete_entities(memory): Entity(name="Frank", type="Person", observations=[]) ] await memory.create_entities(test_entities) - + # Delete one entity await memory.delete_entities(["Eve"]) - + # Read the graph graph = await memory.read_graph() - + # Verify Eve was deleted but Frank remains entity_names = [e.name for e in graph.entities] assert "Eve" not in entity_names @@ -167,23 +167,23 @@ async def test_delete_relations(memory): Entity(name="Hank", type="Person", observations=[]) ] await memory.create_entities(test_entities) - + # Create test relations test_relations = [ Relation(source="Grace", target="Hank", relationType="KNOWS"), Relation(source="Grace", target="Hank", relationType="WORKS_WITH") ] await memory.create_relations(test_relations) - + # Delete one relation relations_to_delete = [ Relation(source="Grace", target="Hank", relationType="KNOWS") ] await memory.delete_relations(relations_to_delete) - + # Read the graph graph = await memory.read_graph() - + # Verify only the WORKS_WITH relation remains assert len(graph.relations) == 1 assert graph.relations[0].relationType == "WORKS_WITH" @@ -197,10 +197,10 @@ async def test_search_nodes(memory): Entity(name="Coffee", type="Beverage", observations=["Hot drink"]) ] await memory.create_entities(test_entities) - + # Search for coffee-related nodes result = await memory.search_nodes("coffee") - + # Verify search results entity_names = [e.name for e in result.entities] assert "Ian" in entity_names @@ -216,12 +216,12 @@ async def test_find_nodes(memory): Entity(name="Mike", type="Person", observations=[]) ] await memory.create_entities(test_entities) - + # Open specific nodes result = await memory.find_nodes(["Kevin", "Laura"]) - + # Verify only requested nodes are returned entity_names = [e.name for e in result.entities] assert "Kevin" in entity_names assert "Laura" in entity_names - assert "Mike" not in entity_names \ No newline at end of file + assert "Mike" not in entity_names \ No newline at end of file