Skip to content

Commit 6bfa2f1

Browse files
authored
Neo4jWriter: batched upsert (#152)
* Update Cypher queries for nodes * Mypy * Set embeddings in same query (for nodes) * Fix e2e + mypy * Merge queries for relationships * Ruff * Unused imports * CHANGELOG update + elementId instead of elementID (seems to be the convention) * Fix relationship query * Batch insert nodes * Mypy * We can use CREATE since IDs are unique * Batch relationship insert * Docstrings, ruff * WIP: test multiple embeddings * Queries * CHANGELOG and doc * mypy * Do not assign an __Entity__ label to the lexical graph nodes * Fix import * Fix import * Fix CHANGELOG
1 parent 1f5eaa1 commit 6bfa2f1

File tree

8 files changed

+268
-147
lines changed

8 files changed

+268
-147
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424
- `Text2CypherTemplate` and `RAGTemplate` prompt templates now require `query_text` arg and will error if it is not present. Previous `query_text` aliases may be used, but will warn of deprecation.
2525
- Resolved issue where Neo4jWriter component would raise an error if the start or end node ID was not defined properly in the input.
2626
- Resolved issue where relationship types were not escaped in the insert Cypher query.
27-
- Improved query performance in Neo4jWriter.
27+
- Improved query performance in Neo4jWriter: created nodes now have a generic `__KGBuilder__` label and an index is created on the `__KGBuilder__.id` property. Moreover, insertion queries are now batched. Batch size can be controlled using the `batch_size` parameter in the `Neo4jWriter` component.
2828

2929
### Changed
3030
- Moved the Embedder class to the neo4j_graphrag.embeddings directory for better organization alongside other custom embedders.
3131
- Removed query argument from the GraphRAG class' `.search` method; users must now use `query_text`.
3232
- Neo4jWriter component now runs a single query to merge node and set its embeddings if any.
33+
- Nodes created by the `Neo4jWriter` now have an extra `__KGBuilder__` label. Nodes from the entity graph also have an `__Entity__` label.
3334
- Dropped support for Python 3.8 (end of life).
3435

3536
## 0.6.3

docs/source/user_guide_kg_builder.rst

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,13 @@ to a Neo4j database:
389389
graph = Neo4jGraph(nodes=[], relationships=[])
390390
await writer.run(graph)
391391
392-
See :ref:`neo4jgraph` for the description of the input type.
392+
To improve insert performance, two parameters can be tuned:
393+
394+
- `batch_size`: the number of nodes/relationships to be processed in each batch (default is 1000).
395+
- `max_concurrency`: the max number of concurrent queries (default is 5).
396+
397+
See :ref:`neo4jgraph`.
398+
393399

394400
It is possible to create a custom writer using the `KGWriter` interface:
395401

@@ -419,4 +425,4 @@ It is possible to create a custom writer using the `KGWriter` interface:
419425
The `validate_call` decorator is required when the input parameter contains a `pydantic` model.
420426

421427

422-
See :ref:`kgwritermodel` and :ref:`kgwriter` in API reference.
428+
See :ref:`kgwritermodel` and :ref:`kgwriter` in API reference.

src/neo4j_graphrag/experimental/components/kg_writer.py

Lines changed: 71 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@
1818
import inspect
1919
import logging
2020
from abc import abstractmethod
21-
from typing import Any, Dict, Literal, Optional, Tuple
21+
from typing import Any, Generator, Literal, Optional
2222

2323
import neo4j
2424
from pydantic import validate_call
2525

26+
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
27+
CHUNK_NODE_LABEL,
28+
DOCUMENT_NODE_LABEL,
29+
)
2630
from neo4j_graphrag.experimental.components.types import (
2731
Neo4jGraph,
2832
Neo4jNode,
@@ -34,14 +38,25 @@
3438
logger = logging.getLogger(__name__)
3539

3640

41+
def batched(rows: list[Any], batch_size: int) -> Generator[list[Any], None, None]:
    """Yield consecutive slices of ``rows`` holding at most ``batch_size`` items.

    Args:
        rows (list[Any]): The items to split into batches.
        batch_size (int): Maximum number of items per batch; must be >= 1.

    Yields:
        list[Any]: The next batch, in input order. The last batch may be
            shorter than ``batch_size``. Nothing is yielded for an empty input.

    Raises:
        ValueError: If ``batch_size`` is lower than 1. The error is raised when
            the generator is first iterated.
    """
    # A non-positive step would either make range() fail (0) or silently
    # produce no batch at all (negative), dropping every row without notice.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # Slicing clamps at the end of the list, so no explicit upper bound is needed.
    for start in range(0, len(rows), batch_size):
        yield rows[start : start + batch_size]
49+
50+
3751
class KGWriterModel(DataModel):
3852
"""Data model for the output of the Knowledge Graph writer.
3953
4054
Attributes:
41-
status (Literal["SUCCESS", "FAILURE"]): Whether or not the write operation was successful.
55+
status (Literal["SUCCESS", "FAILURE"]): Whether the write operation was successful.
4256
"""
4357

4458
status: Literal["SUCCESS", "FAILURE"]
59+
metadata: Optional[dict[str, Any]] = None
4560

4661

4762
class KGWriter(Component):
@@ -91,90 +106,85 @@ def __init__(
91106
self,
92107
driver: neo4j.driver,
93108
neo4j_database: Optional[str] = None,
109+
batch_size: int = 1000,
94110
max_concurrency: int = 5,
95111
):
96112
self.driver = driver
97113
self.neo4j_database = neo4j_database
114+
self.batch_size = batch_size
98115
self.max_concurrency = max_concurrency
99116

100117
def _db_setup(self) -> None:
101118
# create index on __KGBuilder__.id
119+
# used when creating the relationships
102120
self.driver.execute_query(
103-
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__Entity__) ON (n.id)"
121+
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
104122
)
105123

106124
async def _async_db_setup(self) -> None:
107125
# create index on __KGBuilder__.id
126+
# used when creating the relationships
108127
await self.driver.execute_query(
109-
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__Entity__) ON (n.id)"
128+
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
110129
)
111130

112-
def _get_node_query(self, node: Neo4jNode) -> Tuple[str, Dict[str, Any]]:
113-
# Create the initial node
114-
parameters = {
115-
"id": node.id,
116-
"properties": node.properties or {},
117-
"embeddings": node.embedding_properties,
118-
}
119-
query = UPSERT_NODE_QUERY.format(label=node.label)
120-
return query, parameters
121-
122-
def _upsert_node(self, node: Neo4jNode) -> None:
131+
@staticmethod
def _nodes_to_rows(nodes: list[Neo4jNode]) -> list[dict[str, Any]]:
    """Serialize nodes into the row dicts passed as the ``rows`` query parameter.

    Each row is the node's ``model_dump()`` output plus a ``labels`` key. The
    labels list starts with the node's own label; ``__Entity__`` is appended
    unless that label is the chunk or document node label.

    Args:
        nodes (list[Neo4jNode]): The nodes to serialize.

    Returns:
        list[dict[str, Any]]: One row per node, in input order.
    """
    lexical_labels = (CHUNK_NODE_LABEL, DOCUMENT_NODE_LABEL)

    def to_row(node: Neo4jNode) -> dict[str, Any]:
        extra = [] if node.label in lexical_labels else ["__Entity__"]
        return {**node.model_dump(), "labels": [node.label, *extra]}

    return [to_row(node) for node in nodes]
142+
143+
def _upsert_nodes(self, nodes: list[Neo4jNode]) -> None:
123144
"""Upserts a single node into the Neo4j database."
124145
125146
Args:
126-
node (Neo4jNode): The node to upsert into the database.
147+
nodes (list[Neo4jNode]): The nodes batch to upsert into the database.
127148
"""
128-
query, parameters = self._get_node_query(node)
129-
self.driver.execute_query(query, parameters_=parameters)
149+
parameters = {"rows": self._nodes_to_rows(nodes)}
150+
self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)
130151

131-
async def _async_upsert_node(
152+
async def _async_upsert_nodes(
132153
self,
133-
node: Neo4jNode,
154+
nodes: list[Neo4jNode],
134155
sem: asyncio.Semaphore,
135156
) -> None:
136157
"""Asynchronously upserts a single node into the Neo4j database."
137158
138159
Args:
139-
node (Neo4jNode): The node to upsert into the database.
160+
nodes (list[Neo4jNode]): The nodes batch to upsert into the database.
140161
"""
141162
async with sem:
142-
query, parameters = self._get_node_query(node)
143-
await self.driver.execute_query(query, parameters_=parameters)
144-
145-
def _get_rel_query(self, rel: Neo4jRelationship) -> Tuple[str, Dict[str, Any]]:
146-
# Create the initial relationship
147-
parameters = {
148-
"start_node_id": rel.start_node_id,
149-
"end_node_id": rel.end_node_id,
150-
"properties": rel.properties or {},
151-
"embeddings": rel.embedding_properties,
152-
}
153-
query = UPSERT_RELATIONSHIP_QUERY.format(
154-
type=rel.type,
155-
)
156-
return query, parameters
163+
parameters = {"rows": self._nodes_to_rows(nodes)}
164+
await self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)
157165

158-
def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
166+
def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
159167
"""Upserts a single relationship into the Neo4j database.
160168
161169
Args:
162-
rel (Neo4jRelationship): The relationship to upsert into the database.
170+
rels (list[Neo4jRelationship]): The relationships batch to upsert into the database.
163171
"""
164-
query, parameters = self._get_rel_query(rel)
165-
self.driver.execute_query(query, parameters_=parameters)
172+
parameters = {"rows": [rel.model_dump() for rel in rels]}
173+
self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters)
166174

167-
async def _async_upsert_relationship(
168-
self, rel: Neo4jRelationship, sem: asyncio.Semaphore
175+
async def _async_upsert_relationships(
176+
self, rels: list[Neo4jRelationship], sem: asyncio.Semaphore
169177
) -> None:
170178
"""Asynchronously upserts a single relationship into the Neo4j database.
171179
172180
Args:
173-
rel (Neo4jRelationship): The relationship to upsert into the database.
181+
rels (list[Neo4jRelationship]): The relationships batch to upsert into the database.
174182
"""
175183
async with sem:
176-
query, parameters = self._get_rel_query(rel)
177-
await self.driver.execute_query(query, parameters_=parameters)
184+
parameters = {"rows": [rel.model_dump() for rel in rels]}
185+
await self.driver.execute_query(
186+
UPSERT_RELATIONSHIP_QUERY, parameters_=parameters
187+
)
178188

179189
@validate_call
180190
async def run(self, graph: Neo4jGraph) -> KGWriterModel:
@@ -188,25 +198,32 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
188198
await self._async_db_setup()
189199
sem = asyncio.Semaphore(self.max_concurrency)
190200
node_tasks = [
191-
self._async_upsert_node(node, sem) for node in graph.nodes
201+
self._async_upsert_nodes(batch, sem)
202+
for batch in batched(graph.nodes, self.batch_size)
192203
]
193204
await asyncio.gather(*node_tasks)
194205

195206
rel_tasks = [
196-
self._async_upsert_relationship(rel, sem)
197-
for rel in graph.relationships
207+
self._async_upsert_relationships(batch, sem)
208+
for batch in batched(graph.relationships, self.batch_size)
198209
]
199210
await asyncio.gather(*rel_tasks)
200211
else:
201212
self._db_setup()
202213

203-
for node in graph.nodes:
204-
self._upsert_node(node)
214+
for batch in batched(graph.nodes, self.batch_size):
215+
self._upsert_nodes(batch)
205216

206-
for rel in graph.relationships:
207-
self._upsert_relationship(rel)
217+
for batch in batched(graph.relationships, self.batch_size):
218+
self._upsert_relationships(batch)
208219

209-
return KGWriterModel(status="SUCCESS")
220+
return KGWriterModel(
221+
status="SUCCESS",
222+
metadata={
223+
"node_count": len(graph.nodes),
224+
"relationship_count": len(graph.relationships),
225+
},
226+
)
210227
except neo4j.exceptions.ClientError as e:
211228
logger.exception(e)
212-
return KGWriterModel(status="FAILURE")
229+
return KGWriterModel(status="FAILURE", metadata={"error": str(e)})

src/neo4j_graphrag/experimental/components/types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ class Neo4jNode(BaseModel):
5151
Attributes:
5252
id (str): The ID of the node.
5353
label (str): The label of the node.
54-
properties (Optional[dict[str, Any]]): A dictionary of properties attached to the node.
54+
properties (dict[str, Any]): A dictionary of properties attached to the node.
5555
embedding_properties (Optional[dict[str, list[float]]]): A list of embedding properties attached to the node.
5656
"""
5757

5858
id: str
5959
label: str
60-
properties: Optional[dict[str, Any]] = None
60+
properties: dict[str, Any] = {}
6161
embedding_properties: Optional[dict[str, list[float]]] = None
6262

6363
@field_validator("properties", "embedding_properties")
@@ -77,14 +77,14 @@ class Neo4jRelationship(BaseModel):
7777
start_node_id (str): The ID of the start node.
7878
end_node_id (str): The ID of the end node.
7979
type (str): The relationship type.
80-
properties (Optional[dict[str, Any]]): A dictionary of properties attached to the relationship.
80+
properties (dict[str, Any]): A dictionary of properties attached to the relationship.
8181
embedding_properties (Optional[dict[str, list[float]]]): A list of embedding properties attached to the relationship.
8282
"""
8383

8484
start_node_id: str
8585
end_node_id: str
8686
type: str
87-
properties: Optional[dict[str, Any]] = None
87+
properties: dict[str, Any] = {}
8888
embedding_properties: Optional[dict[str, list[float]]] = None
8989

9090

src/neo4j_graphrag/neo4j_queries.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,27 +42,31 @@
4242
)
4343

4444
UPSERT_NODE_QUERY = (
45-
"MERGE (n:__Entity__ {{id: $id}}) "
46-
"WITH n SET n:`{label}`, n += $properties "
47-
"WITH n CALL {{ "
48-
"WITH n WITH n WHERE $embeddings IS NOT NULL "
49-
"UNWIND keys($embeddings) as emb "
50-
"CALL db.create.setNodeVectorProperty(n, emb, $embeddings[emb]) "
51-
"}} "
45+
"UNWIND $rows AS row "
46+
"CREATE (n:__KGBuilder__ {id: row.id}) "
47+
"SET n += row.properties "
48+
"WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node "
49+
"WITH node as n, row CALL { "
50+
"WITH n, row WITH n, row WHERE row.embedding_properties IS NOT NULL "
51+
"UNWIND keys(row.embedding_properties) as emb "
52+
"CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) "
53+
"RETURN count(*) as nbEmb "
54+
"} "
5255
"RETURN elementId(n)"
5356
)
5457

5558
UPSERT_RELATIONSHIP_QUERY = (
56-
"MATCH (start:__Entity__ {{ id: $start_node_id }}) "
57-
"MATCH (end:__Entity__ {{ id: $end_node_id }}) "
58-
"MERGE (start)-[r:`{type}`]->(end) "
59-
"WITH r SET r += $properties "
60-
"WITH r CALL {{ "
61-
"WITH r WITH r WHERE $embeddings IS NOT NULL "
62-
"UNWIND keys($embeddings) as emb "
63-
"CALL db.create.setRelationshipVectorProperty(r, emb, $embeddings[emb]) "
64-
"}} "
65-
"RETURN elementId(r)"
59+
"UNWIND $rows as row "
60+
"MATCH (start:__KGBuilder__ {id: row.start_node_id}) "
61+
"MATCH (end:__KGBuilder__ {id: row.end_node_id}) "
62+
"WITH start, end, row "
63+
"CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel "
64+
"WITH rel, row CALL { "
65+
"WITH rel, row WITH rel, row WHERE row.embedding_properties IS NOT NULL "
66+
"UNWIND keys(row.embedding_properties) as emb "
67+
"CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) "
68+
"} "
69+
"RETURN elementId(rel)"
6670
)
6771

6872
UPSERT_VECTOR_ON_NODE_QUERY = (

tests/e2e/test_kg_builder_pipeline_e2e.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,12 @@ async def test_pipeline_builder_happy_path(
256256
# result must be success
257257
assert isinstance(res, PipelineResult)
258258
assert res.run_id is not None
259-
assert res.result == {"writer": {"status": "SUCCESS"}}
259+
assert res.result == {
260+
"writer": {
261+
"status": "SUCCESS",
262+
"metadata": {"node_count": 7, "relationship_count": 10},
263+
}
264+
}
260265
# check component's results
261266
chunks = await kg_builder_pipeline.store.get_result_for_component(
262267
res.run_id, "splitter"
@@ -290,6 +295,7 @@ async def test_pipeline_builder_happy_path(
290295
}
291296
# then check content of neo4j db
292297
created_nodes = driver.execute_query("MATCH (n) RETURN n")
298+
print(created_nodes.records)
293299
assert len(created_nodes.records) == 7
294300
created_rels = driver.execute_query("MATCH ()-[r]->() RETURN r")
295301
assert len(created_rels.records) == 10
@@ -462,7 +468,12 @@ async def test_pipeline_builder_failing_chunk_do_not_raise(
462468
# result must be success
463469
assert isinstance(res, PipelineResult)
464470
assert res.run_id is not None
465-
assert res.result == {"writer": {"status": "SUCCESS"}}
471+
assert res.result == {
472+
"writer": {
473+
"status": "SUCCESS",
474+
"metadata": {"node_count": 6, "relationship_count": 7},
475+
}
476+
}
466477
# check component's results
467478
chunks = await kg_builder_pipeline.store.get_result_for_component(
468479
res.run_id, "splitter"

0 commit comments

Comments
 (0)