14
14
# limitations under the License.
15
15
from __future__ import annotations
16
16
17
+ import asyncio
17
18
import logging
18
19
from abc import abstractmethod
19
- from typing import Literal , Optional
20
+ from typing import Any , Dict , Literal , Optional , Tuple
20
21
21
22
import neo4j
22
23
from pydantic import validate_call
27
28
Neo4jRelationship ,
28
29
)
29
30
from neo4j_genai .experimental .pipeline .component import Component , DataModel
30
- from neo4j_genai .indexes import upsert_vector , upsert_vector_on_relationship
31
+ from neo4j_genai .indexes import (
32
+ async_upsert_vector ,
33
+ async_upsert_vector_on_relationship ,
34
+ upsert_vector ,
35
+ upsert_vector_on_relationship ,
36
+ )
31
37
from neo4j_genai .neo4j_queries import UPSERT_NODE_QUERY , UPSERT_RELATIONSHIP_QUERY
32
38
33
39
logger = logging .getLogger (__name__ )
@@ -64,20 +70,21 @@ class Neo4jWriter(KGWriter):
64
70
Args:
65
71
driver (neo4j.driver): The Neo4j driver to connect to the database.
66
72
neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided.
73
+ max_concurrency (int): The maximum number of concurrent tasks which can be used to write to the Neo4j database. Only takes effect when an async driver is used.
67
74
68
75
Example:
69
76
70
77
.. code-block:: python
71
78
72
- from neo4j import GraphDatabase
79
+ from neo4j import AsyncGraphDatabase
73
80
from neo4j_genai.experimental.components.kg_writer import Neo4jWriter
74
81
from neo4j_genai.experimental.pipeline import Pipeline
75
82
76
83
URI = "neo4j://localhost:7687"
77
84
AUTH = ("neo4j", "password")
78
85
DATABASE = "neo4j"
79
86
80
- driver = GraphDatabase .driver(URI, auth=AUTH, database=DATABASE)
87
+ driver = AsyncGraphDatabase .driver(URI, auth=AUTH, database=DATABASE)
81
88
writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE)
82
89
83
90
pipeline = Pipeline()
@@ -89,16 +96,13 @@ def __init__(
89
96
self ,
90
97
driver : neo4j .driver ,
91
98
neo4j_database : Optional [str ] = None ,
99
+ max_concurrency : int = 5 ,
92
100
):
93
101
self .driver = driver
94
102
self .neo4j_database = neo4j_database
103
+ self .max_concurrency = max_concurrency
95
104
96
- def _upsert_node (self , node : Neo4jNode ) -> None :
97
- """Upserts a single node into the Neo4j database."
98
-
99
- Args:
100
- node (Neo4jNode): The node to upsert into the database.
101
- """
105
+ def _get_node_query (self , node : Neo4jNode ) -> Tuple [str , Dict [str , Any ]]:
102
106
# Create the initial node
103
107
parameters = {"id" : node .id }
104
108
if node .properties :
@@ -107,6 +111,15 @@ def _upsert_node(self, node: Neo4jNode) -> None:
107
111
"{" + ", " .join (f"{ key } : ${ key } " for key in parameters .keys ()) + "}"
108
112
)
109
113
query = UPSERT_NODE_QUERY .format (label = node .label , properties = properties )
114
+ return query , parameters
115
+
116
+ def _upsert_node (self , node : Neo4jNode ) -> None :
117
+ """Upserts a single node into the Neo4j database."
118
+
119
+ Args:
120
+ node (Neo4jNode): The node to upsert into the database.
121
+ """
122
+ query , parameters = self ._get_node_query (node )
110
123
result = self .driver .execute_query (query , parameters_ = parameters )
111
124
node_id = result .records [0 ]["elementID(n)" ]
112
125
# Add the embedding properties to the node
@@ -120,12 +133,32 @@ def _upsert_node(self, node: Neo4jNode) -> None:
120
133
neo4j_database = self .neo4j_database ,
121
134
)
122
135
123
- def _upsert_relationship (self , rel : Neo4jRelationship ) -> None :
124
- """Upserts a single relationship into the Neo4j database.
136
async def _async_upsert_node(
    self,
    node: Neo4jNode,
    sem: asyncio.Semaphore,
) -> None:
    """Asynchronously upserts a single node into the Neo4j database.

    Args:
        node (Neo4jNode): The node to upsert into the database.
        sem (asyncio.Semaphore): Semaphore held for the whole upsert (the node
            write and every embedding write), bounding how many upserts run
            at the same time.
    """
    async with sem:
        query, parameters = self._get_node_query(node)
        # Target the configured database explicitly so the node is written to
        # the same database the embedding upserts below are sent to. When
        # neo4j_database is None this is the driver's default behaviour.
        result = await self.driver.execute_query(
            query,
            parameters_=parameters,
            database_=self.neo4j_database,
        )
        node_id = result.records[0]["elementID(n)"]
        # Add the embedding properties to the node
        if node.embedding_properties:
            for prop, vector in node.embedding_properties.items():
                await async_upsert_vector(
                    driver=self.driver,
                    node_id=node_id,
                    embedding_property=prop,
                    vector=vector,
                    neo4j_database=self.neo4j_database,
                )
160
+
161
+ def _get_rel_query (self , rel : Neo4jRelationship ) -> Tuple [str , Dict [str , Any ]]:
129
162
# Create the initial relationship
130
163
parameters = {
131
164
"start_node_id" : rel .start_node_id ,
@@ -142,6 +175,15 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
142
175
type = rel .type ,
143
176
properties = properties ,
144
177
)
178
+ return query , parameters
179
+
180
+ def _upsert_relationship (self , rel : Neo4jRelationship ) -> None :
181
+ """Upserts a single relationship into the Neo4j database.
182
+
183
+ Args:
184
+ rel (Neo4jRelationship): The relationship to upsert into the database.
185
+ """
186
+ query , parameters = self ._get_rel_query (rel )
145
187
result = self .driver .execute_query (query , parameters_ = parameters )
146
188
rel_id = result .records [0 ]["elementID(r)" ]
147
189
# Add the embedding properties to the relationship
@@ -155,6 +197,29 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
155
197
neo4j_database = self .neo4j_database ,
156
198
)
157
199
200
async def _async_upsert_relationship(
    self, rel: Neo4jRelationship, sem: asyncio.Semaphore
) -> None:
    """Asynchronously upserts a single relationship into the Neo4j database.

    Args:
        rel (Neo4jRelationship): The relationship to upsert into the database.
        sem (asyncio.Semaphore): Semaphore held for the whole upsert (the
            relationship write and every embedding write), bounding how many
            upserts run at the same time.
    """
    async with sem:
        query, parameters = self._get_rel_query(rel)
        # Target the configured database explicitly so the relationship is
        # written to the same database the embedding upserts below are sent
        # to. When neo4j_database is None this is the driver's default.
        result = await self.driver.execute_query(
            query,
            parameters_=parameters,
            database_=self.neo4j_database,
        )
        rel_id = result.records[0]["elementID(r)"]
        # Add the embedding properties to the relationship
        if rel.embedding_properties:
            for prop, vector in rel.embedding_properties.items():
                await async_upsert_vector_on_relationship(
                    driver=self.driver,
                    rel_id=rel_id,
                    embedding_property=prop,
                    vector=vector,
                    neo4j_database=self.neo4j_database,
                )
222
+
158
223
@validate_call
159
224
async def run (self , graph : Neo4jGraph ) -> KGWriterModel :
160
225
"""Upserts a knowledge graph into a Neo4j database.
@@ -163,11 +228,24 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
163
228
graph (Neo4jGraph): The knowledge graph to upsert into the database.
164
229
"""
165
230
try :
166
- for node in graph .nodes :
167
- self ._upsert_node (node )
168
-
169
- for rel in graph .relationships :
170
- self ._upsert_relationship (rel )
231
+ if isinstance (self .driver , neo4j .AsyncDriver ):
232
+ sem = asyncio .Semaphore (self .max_concurrency )
233
+ node_tasks = [
234
+ self ._async_upsert_node (node , sem ) for node in graph .nodes
235
+ ]
236
+ await asyncio .gather (* node_tasks )
237
+
238
+ rel_tasks = [
239
+ self ._async_upsert_relationship (rel , sem )
240
+ for rel in graph .relationships
241
+ ]
242
+ await asyncio .gather (* rel_tasks )
243
+ else :
244
+ for node in graph .nodes :
245
+ self ._upsert_node (node )
246
+
247
+ for rel in graph .relationships :
248
+ self ._upsert_relationship (rel )
171
249
172
250
return KGWriterModel (status = "SUCCESS" )
173
251
except neo4j .exceptions .ClientError as e :
0 commit comments