Skip to content

Fix create_relations multiple add bug and linear command parser #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 89 additions & 83 deletions servers/mcp-neo4j-memory/src/mcp_neo4j_memory/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, neo4j_driver):

def create_fulltext_index(self):
try:
# TODO ,
# TODO ,
query = """
CREATE FULLTEXT INDEX search IF NOT EXISTS FOR (m:Memory) ON EACH [m.name, m.type, m.observations];
"""
Expand All @@ -64,26 +64,26 @@ async def load_graph(self, filter_query="*"):
CALL db.index.fulltext.queryNodes('search', $filter) yield node as entity, score
OPTIONAL MATCH (entity)-[r]-(other)
RETURN collect(distinct {
name: entity.name,
type: entity.type,
name: entity.name,
type: entity.type,
observations: entity.observations
}) as nodes,
collect(distinct {
source: startNode(r).name,
target: endNode(r).name,
source: startNode(r).name,
target: endNode(r).name,
relationType: type(r)
}) as relations
"""

result = self.neo4j_driver.execute_query(query, {"filter": filter_query})

if not result.records:
return KnowledgeGraph(entities=[], relations=[])

record = result.records[0]
nodes = record.get('nodes')
rels = record.get('relations')

entities = [
Entity(
name=node.get('name'),
Expand All @@ -92,7 +92,7 @@ async def load_graph(self, filter_query="*"):
)
for node in nodes if node.get('name')
]

relations = [
Relation(
source=rel.get('source'),
Expand All @@ -101,10 +101,10 @@ async def load_graph(self, filter_query="*"):
)
for rel in rels if rel.get('source') and rel.get('target') and rel.get('relationType')
]

logger.debug(f"Loaded entities: {entities}")
logger.debug(f"Loaded relations: {relations}")

return KnowledgeGraph(entities=entities, relations=relations)

async def create_entities(self, entities: List[Entity]) -> List[Entity]:
Expand All @@ -114,39 +114,37 @@ async def create_entities(self, entities: List[Entity]) -> List[Entity]:
SET e += entity {.type, .observations}
SET e:$(entity.type)
"""

entities_data = [entity.model_dump() for entity in entities]
self.neo4j_driver.execute_query(query, {"entities": entities_data})
return entities

async def create_relations(self, relations: List[Relation]) -> List[Relation]:
    """Merge the given relations between existing Memory nodes.

    The whole batch is sent in a single query. (An earlier version looped
    over `relations` and executed the full list once per relation, issuing
    N identical batch queries for N relations — the "multiple add" bug.)

    Cypher cannot parameterize a relationship type directly, so
    `apoc.merge.relationship` is used to merge a relationship whose type
    comes from the data.
    # NOTE(review): this requires the APOC plugin on the Neo4j server — confirm deployment.

    Args:
        relations: Relations to create; `source` and `target` must name
            existing `Memory` nodes (non-matching names are silently skipped
            by the MATCH clauses).

    Returns:
        The input relations, unchanged, for echoing back to the caller.
    """
    query = """
    UNWIND $relations AS relation
    MATCH (from:Memory { name: relation.source })
    MATCH (to:Memory { name: relation.target })
    CALL apoc.merge.relationship(from, relation.relationType, {}, {}, to) YIELD rel
    RETURN rel
    """
    # One execute_query call for the entire batch.
    self.neo4j_driver.execute_query(
        query,
        {"relations": [rel.model_dump() for rel in relations]}
    )

    return relations

async def add_observations(self, observations: List[ObservationAddition]) -> List[Dict[str, Any]]:
query = """
UNWIND $observations as obs
UNWIND $observations as obs
MATCH (e:Memory { name: obs.entityName })
WITH e, [o in obs.contents WHERE NOT o IN e.observations] as new
SET e.observations = coalesce(e.observations,[]) + new
RETURN e.name as name, new
"""

result = self.neo4j_driver.execute_query(
query,
query,
{"observations": [obs.model_dump() for obs in observations]}
)

Expand All @@ -159,17 +157,17 @@ async def delete_entities(self, entity_names: List[str]) -> None:
MATCH (e:Memory { name: name })
DETACH DELETE e
"""

self.neo4j_driver.execute_query(query, {"entities": entity_names})

async def delete_observations(self, deletions: List[ObservationDeletion]) -> None:
    """Remove specific observations from named entities.

    For each deletion, filters the entity's `observations` list down to the
    items NOT listed in `d.observations`; `coalesce(...,[])` keeps the
    filter safe when the entity has no observations property yet.

    Args:
        deletions: Each item names an entity (`entityName`) and the
            observation strings to remove from it.
    """
    query = """
    UNWIND $deletions as d
    MATCH (e:Memory { name: d.entityName })
    SET e.observations = [o in coalesce(e.observations,[]) WHERE NOT o IN d.observations]
    """
    self.neo4j_driver.execute_query(
        query,
        {
            "deletions": [deletion.model_dump() for deletion in deletions]
        }
    )
Expand All @@ -184,7 +182,7 @@ async def delete_relations(self, relations: List[Relation]) -> None:
DELETE r
"""
self.neo4j_driver.execute_query(
query,
query,
{"relations": [relation.model_dump() for relation in relations]}
)

Expand All @@ -205,7 +203,7 @@ async def main(neo4j_uri: str, neo4j_user: str, neo4j_password: str):
neo4j_uri,
auth=(neo4j_user, neo4j_password)
)

# Verify connection
try:
neo4j_driver.verify_connectivity()
Expand All @@ -216,7 +214,7 @@ async def main(neo4j_uri: str, neo4j_user: str, neo4j_password: str):

# Initialize memory
memory = Neo4jMemory(neo4j_driver)

# Create MCP server
server = Server("mcp-neo4j-memory")

Expand Down Expand Up @@ -399,55 +397,63 @@ async def handle_list_tools() -> List[types.Tool]:
@server.call_tool()
async def handle_call_tool(
    name: str, arguments: Dict[str, Any] | None
) -> List[types.TextContent]:
    """Dispatch an MCP tool call to the matching Neo4jMemory method.

    Replaces a long if/elif chain with a dispatch table mapping each tool
    name to (method, argument key, model class, default, formatter).

    Args:
        name: Tool name as registered with the MCP server.
        arguments: Raw argument dict from the client, or None.

    Returns:
        A single-element list of TextContent holding the JSON-serialized
        result, or an error message if anything raised.
    """
    try:
        if not arguments:
            # NOTE(review): this guard rejects argument-less calls, which makes
            # the `arguments or {}` fallback below unreachable and prevents
            # calling zero-arg tools like read_graph without a dummy payload —
            # confirm whether clients always send a non-empty dict.
            raise ValueError(f"No arguments provided for tool: {name}")

        args = arguments or {}

        def _json_models(res):
            # Serialize a list of pydantic models.
            return [types.TextContent(type="text", text=json.dumps([m.model_dump() for m in res], indent=2))]

        def _json_graph(res):
            # Serialize a single model (KnowledgeGraph).
            return [types.TextContent(type="text", text=json.dumps(res.model_dump(), indent=2))]

        def _message(text):
            # Fixed confirmation message, ignoring the result.
            return lambda _res: [types.TextContent(type="text", text=text)]

        # Dispatch table: tool name ->
        #   (method, arg_key, model_cls, default, formatter)
        # arg_key None means the method takes no arguments; model_cls None
        # means the raw value is passed through without model validation.
        dispatch = {
            "create_entities": (
                memory.create_entities, "entities", Entity, [], _json_models
            ),
            "create_relations": (
                memory.create_relations, "relations", Relation, [], _json_models
            ),
            "add_observations": (
                memory.add_observations, "observations", ObservationAddition, [],
                lambda res: [types.TextContent(type="text", text=json.dumps(res, indent=2))]
            ),
            "delete_entities": (
                memory.delete_entities, "entityNames", None, [],
                _message("Entities deleted successfully")
            ),
            "delete_observations": (
                memory.delete_observations, "deletions", ObservationDeletion, [],
                _message("Observations deleted successfully")
            ),
            "delete_relations": (
                memory.delete_relations, "relations", Relation, [],
                _message("Relations deleted successfully")
            ),
            "read_graph": (
                memory.read_graph, None, None, None, _json_graph
            ),
            # search_nodes takes a query STRING, so its default is "" (the
            # pre-refactor behavior), not [].
            "search_nodes": (
                memory.search_nodes, "query", None, "", _json_graph
            ),
            "find_nodes": (
                memory.find_nodes, "names", None, [], _json_graph
            ),
        }
        if name not in dispatch:
            raise ValueError(f"Unknown tool: {name}")

        method, arg_key, model_cls, default, formatter = dispatch[name]
        if arg_key:
            raw = args.get(arg_key, default)
            # Validate items through the pydantic model when one is declared.
            param = [model_cls(**item) for item in raw] if model_cls else raw
            result = await method(param)
        else:
            result = await method()
        return formatter(result)

    except Exception as e:
        logger.error(f"Error handling tool call: {e}")
        return [types.TextContent(type="text", text=f"Error: {str(e)}")]
Expand Down
Loading