Skip to content

Commit 4c32d1d

Browse files
committed
Rename property, clean db
1 parent 75c445d commit 4c32d1d

File tree

5 files changed

+142
-133
lines changed

5 files changed

+142
-133
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
#### Other
3030

31-
- The `id` property on `__KG_Builder__` nodes is removed.
31+
- The reserved `id` property on `__KGBuilder__` nodes is removed.
3232
- The `chunk_index` property on `__Entity__` nodes is removed. Use the `FROM_CHUNK` relationship instead.
3333

3434
## 1.7.0

src/neo4j_graphrag/experimental/components/kg_writer.py

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,9 @@
2929
)
3030
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
3131
from neo4j_graphrag.neo4j_queries import (
32-
UPSERT_NODE_QUERY,
33-
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE,
34-
UPSERT_RELATIONSHIP_QUERY,
35-
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
32+
upsert_node_query,
33+
upsert_relationship_query,
34+
db_cleaning_query,
3635
)
3736
from neo4j_graphrag.utils.version_utils import (
3837
get_version,
@@ -117,16 +116,19 @@ def __init__(
117116
driver: neo4j.Driver,
118117
neo4j_database: Optional[str] = None,
119118
batch_size: int = 1000,
119+
clean_db: bool = True,
120120
):
121121
self.driver = driver_config.override_user_agent(driver)
122122
self.neo4j_database = neo4j_database
123123
self.batch_size = batch_size
124+
self._clean_db = clean_db
124125
version_tuple, _, _ = get_version(self.driver, self.neo4j_database)
125126
self.is_version_5_23_or_above = is_version_5_23_or_above(version_tuple)
126127

127128
def _db_setup(self) -> None:
128-
# not used for now
129-
pass
129+
self.driver.execute_query("""
130+
CREATE INDEX __entity__tmp_internal_id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.__tmp_internal_id)
131+
""")
130132

131133
@staticmethod
132134
def _nodes_to_rows(
@@ -144,62 +146,53 @@ def _nodes_to_rows(
144146

145147
def _upsert_nodes(
146148
self, nodes: list[Neo4jNode], lexical_graph_config: LexicalGraphConfig
147-
) -> dict[str, str]:
148-
"""Upserts a single node into the Neo4j database."
149+
) -> None:
150+
"""Upserts a batch of nodes into the Neo4j database.
149151
150152
Args:
151153
nodes (list[Neo4jNode]): The nodes batch to upsert into the database.
152154
"""
153155
parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)}
154-
query = (
155-
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE
156-
if self.is_version_5_23_or_above
157-
else UPSERT_NODE_QUERY
156+
query = upsert_node_query(
157+
support_variable_scope_clause=self.is_version_5_23_or_above
158158
)
159-
records, _, _ = self.driver.execute_query(
159+
self.driver.execute_query(
160160
query,
161161
parameters_=parameters,
162162
database_=self.neo4j_database,
163163
)
164-
return {r["_internal_id"]: r["element_id"] for r in records}
164+
return None
165165

166166
@staticmethod
167167
def _relationships_to_rows(
168-
relationships: list[Neo4jRelationship], node_id_mapping: dict[str, str]
168+
relationships: list[Neo4jRelationship],
169169
) -> list[dict[str, Any]]:
170-
return [
171-
{
172-
**relationship.model_dump(),
173-
"start_node_element_id": node_id_mapping.get(
174-
relationship.start_node_id, ""
175-
),
176-
"end_node_element_id": node_id_mapping.get(
177-
relationship.end_node_id, ""
178-
),
179-
}
180-
for relationship in relationships
181-
]
182-
183-
def _upsert_relationships(
184-
self, rels: list[Neo4jRelationship], node_id_mapping: dict[str, str]
185-
) -> None:
186-
"""Upserts a single relationship into the Neo4j database.
170+
return [relationship.model_dump() for relationship in relationships]
171+
172+
def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
173+
"""Upserts a batch of relationships into the Neo4j database.
187174
188175
Args:
189176
rels (list[Neo4jRelationship]): The relationships batch to upsert into the database.
190177
"""
191-
parameters = {"rows": self._relationships_to_rows(rels, node_id_mapping)}
192-
query = (
193-
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE
194-
if self.is_version_5_23_or_above
195-
else UPSERT_RELATIONSHIP_QUERY
178+
parameters = {"rows": self._relationships_to_rows(rels)}
179+
query = upsert_relationship_query(
180+
support_variable_scope_clause=self.is_version_5_23_or_above
196181
)
197182
self.driver.execute_query(
198183
query,
199184
parameters_=parameters,
200185
database_=self.neo4j_database,
201186
)
202187

188+
def _db_cleaning(self) -> None:
189+
query = db_cleaning_query(
190+
support_variable_scope_clause=self.is_version_5_23_or_above,
191+
batch_size=self.batch_size,
192+
)
193+
with self.driver.session() as session:
194+
session.run(query)
195+
203196
@validate_call
204197
async def run(
205198
self,
@@ -215,14 +208,14 @@ async def run(
215208
try:
216209
self._db_setup()
217210

218-
node_id_mapping = {}
219-
220211
for batch in batched(graph.nodes, self.batch_size):
221-
batch_mapping = self._upsert_nodes(batch, lexical_graph_config)
222-
node_id_mapping.update(batch_mapping)
212+
self._upsert_nodes(batch, lexical_graph_config)
223213

224214
for batch in batched(graph.relationships, self.batch_size):
225-
self._upsert_relationships(batch, node_id_mapping)
215+
self._upsert_relationships(batch)
216+
217+
if self._clean_db:
218+
self._db_cleaning()
226219

227220
return KGWriterModel(
228221
status="SUCCESS",

src/neo4j_graphrag/neo4j_queries.py

Lines changed: 84 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -52,61 +52,89 @@
5252
"YIELD node, score"
5353
)
5454

55-
UPSERT_NODE_QUERY = (
56-
"UNWIND $rows AS row "
57-
"CREATE (n:__KGBuilder__) "
58-
"SET n += row.properties "
59-
"WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node "
60-
"WITH node as n, row CALL { "
61-
"WITH n, row WITH n, row WHERE row.embedding_properties IS NOT NULL "
62-
"UNWIND keys(row.embedding_properties) as emb "
63-
"CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) "
64-
"RETURN count(*) as nbEmb "
65-
"} "
66-
"RETURN row.id as _internal_id, elementId(n) as element_id"
67-
)
6855

69-
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE = (
70-
"UNWIND $rows AS row "
71-
"CREATE (n:__KGBuilder__) "
72-
"SET n += row.properties "
73-
"WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node "
74-
"WITH node as n, row CALL (n, row) { "
75-
"WITH n, row WITH n, row WHERE row.embedding_properties IS NOT NULL "
76-
"UNWIND keys(row.embedding_properties) as emb "
77-
"CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) "
78-
"RETURN count(*) as nbEmb "
79-
"} "
80-
"RETURN row.id as _internal_id, elementId(n) as element_id"
81-
)
56+
def _call_subquery_syntax(
57+
support_variable_scope_clause: bool, variable_list: list[str]
58+
) -> str:
59+
"""A helper function to return the CALL subquery syntax:
60+
- Either CALL { WITH <variables>
61+
- or CALL (variables) {
62+
"""
63+
variables = ",".join(variable_list)
64+
if support_variable_scope_clause:
65+
return f"CALL ({variables}) {{ "
66+
if variables:
67+
return f"CALL {{ WITH {variables} "
68+
return "CALL { "
69+
70+
71+
def upsert_node_query(support_variable_scope_clause: bool) -> str:
72+
"""Build the Cypher query to upsert a batch of nodes:
73+
- Create the new node
74+
- Set its label(s) and properties
75+
- Set its embedding properties if any
76+
- Return the node elementId
77+
"""
78+
call_prefix = _call_subquery_syntax(
79+
support_variable_scope_clause, variable_list=["n", "row"]
80+
)
81+
return (
82+
"UNWIND $rows AS row "
83+
"CREATE (n:__KGBuilder__ {__tmp_internal_id: row.id}) "
84+
"SET n += row.properties "
85+
"WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node "
86+
"WITH node as n, row "
87+
f"{call_prefix} "
88+
"WITH n, row WHERE row.embedding_properties IS NOT NULL "
89+
"UNWIND keys(row.embedding_properties) as emb "
90+
"CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) "
91+
"RETURN count(*) as nbEmb "
92+
"} "
93+
"RETURN elementId(n) as element_id"
94+
)
8295

83-
UPSERT_RELATIONSHIP_QUERY = (
84-
"UNWIND $rows as row "
85-
"MATCH (start:__KGBuilder__), (end:__KGBuilder__) "
86-
"WHERE elementId(start) = row.start_node_element_id AND elementId(end) = row.end_node_element_id "
87-
"WITH start, end, row "
88-
"CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel "
89-
"WITH rel, row CALL { "
90-
"WITH rel, row WITH rel, row WHERE row.embedding_properties IS NOT NULL "
91-
"UNWIND keys(row.embedding_properties) as emb "
92-
"CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) "
93-
"} "
94-
"RETURN elementId(rel)"
95-
)
9696

97-
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE = (
98-
"UNWIND $rows as row "
99-
"MATCH (start:__KGBuilder__), (end:__KGBuilder__) "
100-
"WHERE elementId(start) = row.start_node_element_id AND elementId(end) = row.end_node_element_id "
101-
"WITH start, end, row "
102-
"CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel "
103-
"WITH rel, row CALL (rel, row) { "
104-
"WITH rel, row WITH rel, row WHERE row.embedding_properties IS NOT NULL "
105-
"UNWIND keys(row.embedding_properties) as emb "
106-
"CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) "
107-
"} "
108-
"RETURN elementId(rel)"
109-
)
97+
def upsert_relationship_query(support_variable_scope_clause: bool) -> str:
98+
"""Build the Cypher query to upsert a batch of relationships:
99+
- Create the new relationship:
100+
only one relationship of a specific type is allowed between the same two nodes
101+
- Set its properties
102+
- Set its embedding properties if any
103+
- Return the node elementId
104+
"""
105+
call_prefix = _call_subquery_syntax(
106+
support_variable_scope_clause, variable_list=["rel", "row"]
107+
)
108+
return (
109+
"UNWIND $rows as row "
110+
"MATCH (start:__KGBuilder__ {__tmp_internal_id: row.start_node_id}), "
111+
" (end:__KGBuilder__ {__tmp_internal_id: row.end_node_id}) "
112+
"WITH start, end, row "
113+
"CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel "
114+
"WITH rel, row "
115+
f"{call_prefix} "
116+
"WITH rel, row WHERE row.embedding_properties IS NOT NULL "
117+
"UNWIND keys(row.embedding_properties) as emb "
118+
"CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) "
119+
"} "
120+
"RETURN elementId(rel)"
121+
)
122+
123+
124+
def db_cleaning_query(support_variable_scope_clause: bool, batch_size: int) -> str:
125+
"""Removes the temporary __tmp_internal_id property from all nodes."""
126+
call_prefix = _call_subquery_syntax(
127+
support_variable_scope_clause, variable_list=["n"]
128+
)
129+
return (
130+
"MATCH (n:__KGBuilder__) "
131+
"WHERE n.__tmp_internal_id IS NOT NULL "
132+
f"{call_prefix} "
133+
" SET n.__tmp_internal_id = NULL "
134+
"} "
135+
f"IN TRANSACTIONS OF {batch_size} ROWS"
136+
)
137+
110138

111139
# Deprecated, remove along with upsert_vector
112140
UPSERT_VECTOR_ON_NODE_QUERY = (
@@ -150,13 +178,15 @@ def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str:
150178
Construct a cypher query for hybrid search.
151179
152180
Args:
153-
neo4j_version_is_5_23_or_above (bool): Whether or not the Neo4j version is 5.23 or above;
181+
neo4j_version_is_5_23_or_above (bool): Whether the Neo4j version is 5.23 or above;
154182
determines which call syntax is used.
155183
156184
Returns:
157185
str: The constructed Cypher query string.
158186
"""
159-
call_prefix = "CALL () { " if neo4j_version_is_5_23_or_above else "CALL { "
187+
call_prefix = _call_subquery_syntax(
188+
neo4j_version_is_5_23_or_above, variable_list=[]
189+
)
160190
query_body = (
161191
f"{NODE_VECTOR_INDEX_QUERY} "
162192
"WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score "

tests/e2e/docker-compose.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ services:
3232
- 7474:7474
3333
environment:
3434
NEO4J_AUTH: neo4j/password
35-
NEO4J_ACCEPT_LICENSE_AGREEMENT: "eval"
35+
NEO4J_ACCEPT_LICENSE_AGREEMENT: "yes"
3636
NEO4J_PLUGINS: "[\"apoc\"]"
37+
NEO4J_server_memory_heap_max__size: 6G
3738
qdrant:
3839
image: qdrant/qdrant
3940
ports:

0 commit comments

Comments
 (0)