Skip to content

Commit ff1c6ee

Browse files
authored
Neo4jWriter improvements (#151)
* Update Cypher queries for nodes * Mypy * Set embeddings in same query (for nodes) * Fix e2e + mypy * Merge queries for relationships * Ruff * Unused imports * CHANGELOG update + elementId instead of elementID (seems to be the convention)
1 parent 89411ca commit ff1c6ee

File tree

6 files changed

+139
-119
lines changed

6 files changed

+139
-119
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
### Fixed
1919
- Resolved import issue with the Vertex AI Embeddings class.
20+
- Resolved issue where Neo4jWriter component would raise an error if the start or end node ID was not defined properly in the input.
2021

2122
### Changed
2223
- Moved the Embedder class to the neo4j_graphrag.embeddings directory for better organization alongside other custom embedders.
24+
- Neo4jWriter component now runs a single query to merge node and set its embeddings if any.
2325

2426
## 0.6.3
2527
### Changed

src/neo4j_graphrag/experimental/components/kg_writer.py

Lines changed: 29 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import asyncio
18+
import inspect
1819
import logging
1920
from abc import abstractmethod
2021
from typing import Any, Dict, Literal, Optional, Tuple
@@ -28,12 +29,6 @@
2829
Neo4jRelationship,
2930
)
3031
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
31-
from neo4j_graphrag.indexes import (
32-
async_upsert_vector,
33-
async_upsert_vector_on_relationship,
34-
upsert_vector,
35-
upsert_vector_on_relationship,
36-
)
3732
from neo4j_graphrag.neo4j_queries import UPSERT_NODE_QUERY, UPSERT_RELATIONSHIP_QUERY
3833

3934
logger = logging.getLogger(__name__)
@@ -102,15 +97,26 @@ def __init__(
10297
self.neo4j_database = neo4j_database
10398
self.max_concurrency = max_concurrency
10499

100+
def _db_setup(self) -> None:
101+
# create index on __Entity__.id
102+
self.driver.execute_query(
103+
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__Entity__) ON (n.id)"
104+
)
105+
106+
async def _async_db_setup(self) -> None:
107+
# create index on __Entity__.id
108+
await self.driver.execute_query(
109+
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__Entity__) ON (n.id)"
110+
)
111+
105112
def _get_node_query(self, node: Neo4jNode) -> Tuple[str, Dict[str, Any]]:
106113
# Create the initial node
107-
parameters = {"id": node.id}
108-
if node.properties:
109-
parameters.update(node.properties)
110-
properties = (
111-
"{" + ", ".join(f"{key}: ${key}" for key in parameters.keys()) + "}"
112-
)
113-
query = UPSERT_NODE_QUERY.format(label=node.label, properties=properties)
114+
parameters = {
115+
"id": node.id,
116+
"properties": node.properties or {},
117+
"embeddings": node.embedding_properties,
118+
}
119+
query = UPSERT_NODE_QUERY.format(label=node.label)
114120
return query, parameters
115121

116122
def _upsert_node(self, node: Neo4jNode) -> None:
@@ -120,18 +126,7 @@ def _upsert_node(self, node: Neo4jNode) -> None:
120126
node (Neo4jNode): The node to upsert into the database.
121127
"""
122128
query, parameters = self._get_node_query(node)
123-
result = self.driver.execute_query(query, parameters_=parameters)
124-
node_id = result.records[0]["elementID(n)"]
125-
# Add the embedding properties to the node
126-
if node.embedding_properties:
127-
for prop, vector in node.embedding_properties.items():
128-
upsert_vector(
129-
driver=self.driver,
130-
node_id=node_id,
131-
embedding_property=prop,
132-
vector=vector,
133-
neo4j_database=self.neo4j_database,
134-
)
129+
self.driver.execute_query(query, parameters_=parameters)
135130

136131
async def _async_upsert_node(
137132
self,
@@ -145,35 +140,18 @@ async def _async_upsert_node(
145140
"""
146141
async with sem:
147142
query, parameters = self._get_node_query(node)
148-
result = await self.driver.execute_query(query, parameters_=parameters)
149-
node_id = result.records[0]["elementID(n)"]
150-
# Add the embedding properties to the node
151-
if node.embedding_properties:
152-
for prop, vector in node.embedding_properties.items():
153-
await async_upsert_vector(
154-
driver=self.driver,
155-
node_id=node_id,
156-
embedding_property=prop,
157-
vector=vector,
158-
neo4j_database=self.neo4j_database,
159-
)
143+
await self.driver.execute_query(query, parameters_=parameters)
160144

161145
def _get_rel_query(self, rel: Neo4jRelationship) -> Tuple[str, Dict[str, Any]]:
162146
# Create the initial relationship
163147
parameters = {
164148
"start_node_id": rel.start_node_id,
165149
"end_node_id": rel.end_node_id,
150+
"properties": rel.properties or {},
151+
"embeddings": rel.embedding_properties,
166152
}
167-
if rel.properties:
168-
properties = (
169-
"{" + ", ".join(f"{key}: ${key}" for key in rel.properties.keys()) + "}"
170-
)
171-
parameters.update(rel.properties)
172-
else:
173-
properties = "{}"
174153
query = UPSERT_RELATIONSHIP_QUERY.format(
175154
type=rel.type,
176-
properties=properties,
177155
)
178156
return query, parameters
179157

@@ -184,18 +162,7 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
184162
rel (Neo4jRelationship): The relationship to upsert into the database.
185163
"""
186164
query, parameters = self._get_rel_query(rel)
187-
result = self.driver.execute_query(query, parameters_=parameters)
188-
rel_id = result.records[0]["elementID(r)"]
189-
# Add the embedding properties to the relationship
190-
if rel.embedding_properties:
191-
for prop, vector in rel.embedding_properties.items():
192-
upsert_vector_on_relationship(
193-
driver=self.driver,
194-
rel_id=rel_id,
195-
embedding_property=prop,
196-
vector=vector,
197-
neo4j_database=self.neo4j_database,
198-
)
165+
self.driver.execute_query(query, parameters_=parameters)
199166

200167
async def _async_upsert_relationship(
201168
self, rel: Neo4jRelationship, sem: asyncio.Semaphore
@@ -207,18 +174,7 @@ async def _async_upsert_relationship(
207174
"""
208175
async with sem:
209176
query, parameters = self._get_rel_query(rel)
210-
result = await self.driver.execute_query(query, parameters_=parameters)
211-
rel_id = result.records[0]["elementID(r)"]
212-
# Add the embedding properties to the relationship
213-
if rel.embedding_properties:
214-
for prop, vector in rel.embedding_properties.items():
215-
await async_upsert_vector_on_relationship(
216-
driver=self.driver,
217-
rel_id=rel_id,
218-
embedding_property=prop,
219-
vector=vector,
220-
neo4j_database=self.neo4j_database,
221-
)
177+
await self.driver.execute_query(query, parameters_=parameters)
222178

223179
@validate_call
224180
async def run(self, graph: Neo4jGraph) -> KGWriterModel:
@@ -228,7 +184,8 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
228184
graph (Neo4jGraph): The knowledge graph to upsert into the database.
229185
"""
230186
try:
231-
if isinstance(self.driver, neo4j.AsyncDriver):
187+
if inspect.iscoroutinefunction(self.driver.execute_query):
188+
await self._async_db_setup()
232189
sem = asyncio.Semaphore(self.max_concurrency)
233190
node_tasks = [
234191
self._async_upsert_node(node, sem) for node in graph.nodes
@@ -241,6 +198,8 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
241198
]
242199
await asyncio.gather(*rel_tasks)
243200
else:
201+
self._db_setup()
202+
244203
for node in graph.nodes:
245204
self._upsert_node(node)
246205

src/neo4j_graphrag/neo4j_queries.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,28 @@
4141
"YIELD node, score"
4242
)
4343

44-
UPSERT_NODE_QUERY = "MERGE (n:`{label}` {properties}) RETURN elementID(n)"
44+
UPSERT_NODE_QUERY = (
45+
"MERGE (n:__Entity__ {{id: $id}}) "
46+
"WITH n SET n:`{label}`, n += $properties "
47+
"WITH n CALL {{ "
48+
"WITH n WITH n WHERE $embeddings IS NOT NULL "
49+
"UNWIND keys($embeddings) as emb "
50+
"CALL db.create.setNodeVectorProperty(n, emb, $embeddings[emb]) "
51+
"}} "
52+
"RETURN elementId(n)"
53+
)
4554

4655
UPSERT_RELATIONSHIP_QUERY = (
47-
"MATCH (start {{ id: $start_node_id }}), (end {{ id: $end_node_id }}) "
48-
"MERGE (start)-[r:{type} {properties}]->(end) "
49-
"RETURN elementID(r)"
56+
"MATCH (start {{ id: $start_node_id }}) "
57+
"MATCH (end {{ id: $end_node_id }}) "
58+
"MERGE (start)-[r:{type}]->(end) "
59+
"WITH r SET r += $properties "
60+
"WITH r CALL {{ "
61+
"WITH r WITH r WHERE $embeddings IS NOT NULL "
62+
"UNWIND keys($embeddings) as emb "
63+
"CALL db.create.setRelationshipVectorProperty(r, emb, $embeddings[emb]) "
64+
"}} "
65+
"RETURN elementId(r)"
5066
)
5167

5268
UPSERT_VECTOR_ON_NODE_QUERY = (

tests/e2e/test_kg_writer_component_e2e.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
5454
assert "a" and "b" and "r" in record.keys()
5555

5656
node_a = record["a"]
57-
assert start_node.label == list(node_a.labels)[0]
57+
assert start_node.label in list(node_a.labels)
5858
assert start_node.id == str(node_a.get("id"))
5959
if start_node.properties:
6060
for key, val in start_node.properties.items():
@@ -66,7 +66,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
6666
assert node_a.get(key) == [1.0, 2.0, 3.0]
6767

6868
node_b = record["b"]
69-
assert end_node.label == list(node_b.labels)[0]
69+
assert end_node.label in list(node_b.labels)
7070
assert end_node.id == str(node_b.get("id"))
7171
if end_node.properties:
7272
for key, val in end_node.properties.items():

tests/unit/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ def driver() -> MagicMock:
3434
return MagicMock(spec=neo4j.Driver)
3535

3636

37+
@pytest.fixture(scope="function")
38+
def async_driver() -> MagicMock:
39+
return MagicMock(spec=neo4j.AsyncDriver)
40+
41+
3742
@pytest.fixture(scope="function")
3843
def embedder() -> MagicMock:
3944
return MagicMock(spec=Embedder)

0 commit comments

Comments
 (0)