diff --git a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py index b397849..dd0038c 100644 --- a/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py +++ b/servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py @@ -121,18 +121,23 @@ async def create_entities(self, entities: List[Entity]) -> List[Entity]: async def create_relations(self, relations: List[Relation]) -> List[Relation]: for relation in relations: - query = """ - UNWIND $relations as relation - MATCH (from:Memory),(to:Memory) - WHERE from.name = relation.source - AND to.name = relation.target - MERGE (from)-[r:$(relation.relationType)]->(to) + # Safely escape the relation type + rel_type = self._safe_relation_type(relation.relationType) + + # Use escaped relation type directly in query + query = f""" + MATCH (from:Memory), (to:Memory) + WHERE from.name = $source + AND to.name = $target + MERGE (from)-[r:{rel_type}]->(to) """ - self.neo4j_driver.execute_query( - query, - {"relations": [relation.model_dump() for relation in relations]} - ) + params = { + "source": relation.source, + "target": relation.target + } + + self.neo4j_driver.execute_query(query, params) return relations @@ -176,17 +181,24 @@ async def delete_observations(self, deletions: List[ObservationDeletion]) -> Non ) async def delete_relations(self, relations: List[Relation]) -> None: - query = """ - UNWIND $relations as relation - MATCH (source:Memory)-[r:$(relation.relationType)]->(target:Memory) - WHERE source.name = relation.source - AND target.name = relation.target - DELETE r - """ - self.neo4j_driver.execute_query( - query, - {"relations": [relation.model_dump() for relation in relations]} - ) + for relation in relations: + # Safely escape the relation type + rel_type = self._safe_relation_type(relation.relationType) + + # Use escaped relation type directly in query + query = f""" + MATCH (source:Memory)-[r:{rel_type}]->(target:Memory) + WHERE source.name = $source + AND target.name = $target + DELETE r + """ + + params = { + "source": relation.source, + "target": relation.target + } + + self.neo4j_driver.execute_query(query, params) async def read_graph(self) -> KnowledgeGraph: return await self.load_graph() @@ -196,7 +208,35 @@ async def search_nodes(self, query: str) -> KnowledgeGraph: async def find_nodes(self, names: List[str]) -> KnowledgeGraph: return await self.load_graph("name: (" + " ".join(names) + ")") - + + def _safe_relation_type(self, rel_type: str) -> str: + """ + Sanitizes and validates relation types for safe inclusion in Cypher queries. + + Neo4j doesn't support parameterized relation types in standard Cypher syntax, + which creates a potential Cypher injection vulnerability. While APOC procedures like + apoc.create.relationship would allow for parameterization, they may not be available + in all Neo4j instances and require additional installation and configuration. + + This method provides a defensive approach by: + 1. Removing all non-alphanumeric/underscore characters that could be used for injection + 2. Ensuring the relation type is never empty + 3. Ensuring the relation type starts with a letter + """ + # Neo4j relation types must be alphanumeric plus underscores + import re + safe_rel_type = re.sub(r'[^\w]', '_', rel_type) + + # Ensure it's not empty + if not safe_rel_type: + safe_rel_type = "UNKNOWN" + + # Ensure it starts with a letter + if not safe_rel_type[0].isalpha(): + safe_rel_type = "R_" + safe_rel_type + + return safe_rel_type + async def main(neo4j_uri: str, neo4j_user: str, neo4j_password: str): logger.info(f"Connecting to neo4j MCP Server with DB URL: {neo4j_uri}")