15
15
from __future__ import annotations
16
16
17
17
import asyncio
18
+ import inspect
18
19
import logging
19
20
from abc import abstractmethod
20
21
from typing import Any , Dict , Literal , Optional , Tuple
28
29
Neo4jRelationship ,
29
30
)
30
31
from neo4j_graphrag .experimental .pipeline .component import Component , DataModel
31
- from neo4j_graphrag .indexes import (
32
- async_upsert_vector ,
33
- async_upsert_vector_on_relationship ,
34
- upsert_vector ,
35
- upsert_vector_on_relationship ,
36
- )
37
32
from neo4j_graphrag .neo4j_queries import UPSERT_NODE_QUERY , UPSERT_RELATIONSHIP_QUERY
38
33
39
34
logger = logging .getLogger (__name__ )
@@ -102,15 +97,26 @@ def __init__(
102
97
self .neo4j_database = neo4j_database
103
98
self .max_concurrency = max_concurrency
104
99
100
def _db_setup(self) -> None:
    """Prepare the database before any node or relationship is written.

    Creates a range index named ``__entity__id`` on the ``id`` property of
    ``__Entity__`` nodes, unless an index with that name already exists.

    The statement is sent to ``self.neo4j_database`` so the index is created
    in the database this writer was configured with rather than in whatever
    database the driver selects by default.
    """
    # NOTE(review): per the neo4j driver docs, ``database_=None`` falls back
    # to the default database, so an unset ``neo4j_database`` keeps the
    # previous behaviour - confirm against the pinned driver version.
    self.driver.execute_query(
        "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__Entity__) ON (n.id)",
        database_=self.neo4j_database,
    )
105
+
106
async def _async_db_setup(self) -> None:
    """Prepare the database before any write, using an awaitable driver.

    Awaits a statement that creates a range index named ``__entity__id`` on
    the ``id`` property of ``__Entity__`` nodes, unless an index with that
    name already exists.

    The statement is sent to ``self.neo4j_database`` so the index is created
    in the database this writer was configured with rather than in whatever
    database the driver selects by default.
    """
    # NOTE(review): per the neo4j driver docs, ``database_=None`` falls back
    # to the default database, so an unset ``neo4j_database`` keeps the
    # previous behaviour - confirm against the pinned driver version.
    await self.driver.execute_query(
        "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__Entity__) ON (n.id)",
        database_=self.neo4j_database,
    )
111
+
105
112
def _get_node_query(self, node: Neo4jNode) -> Tuple[str, Dict[str, Any]]:
    """Build the Cypher statement and parameters used to upsert one node.

    Args:
        node (Neo4jNode): The node to upsert into the database.

    Returns:
        Tuple[str, Dict[str, Any]]: ``UPSERT_NODE_QUERY`` formatted with the
        node's label, and a parameter dict with the keys ``id``,
        ``properties`` and ``embeddings``.
    """
    # A missing or empty property mapping is sent as an empty dict;
    # embedding properties are forwarded exactly as stored on the node.
    node_properties = node.properties or {}
    params: Dict[str, Any] = dict(
        id=node.id,
        properties=node_properties,
        embeddings=node.embedding_properties,
    )
    cypher = UPSERT_NODE_QUERY.format(label=node.label)
    return cypher, params
115
121
116
122
def _upsert_node (self , node : Neo4jNode ) -> None :
@@ -120,18 +126,7 @@ def _upsert_node(self, node: Neo4jNode) -> None:
120
126
node (Neo4jNode): The node to upsert into the database.
121
127
"""
122
128
query , parameters = self ._get_node_query (node )
123
- result = self .driver .execute_query (query , parameters_ = parameters )
124
- node_id = result .records [0 ]["elementID(n)" ]
125
- # Add the embedding properties to the node
126
- if node .embedding_properties :
127
- for prop , vector in node .embedding_properties .items ():
128
- upsert_vector (
129
- driver = self .driver ,
130
- node_id = node_id ,
131
- embedding_property = prop ,
132
- vector = vector ,
133
- neo4j_database = self .neo4j_database ,
134
- )
129
+ self .driver .execute_query (query , parameters_ = parameters )
135
130
136
131
async def _async_upsert_node (
137
132
self ,
@@ -145,35 +140,18 @@ async def _async_upsert_node(
145
140
"""
146
141
async with sem :
147
142
query , parameters = self ._get_node_query (node )
148
- result = await self .driver .execute_query (query , parameters_ = parameters )
149
- node_id = result .records [0 ]["elementID(n)" ]
150
- # Add the embedding properties to the node
151
- if node .embedding_properties :
152
- for prop , vector in node .embedding_properties .items ():
153
- await async_upsert_vector (
154
- driver = self .driver ,
155
- node_id = node_id ,
156
- embedding_property = prop ,
157
- vector = vector ,
158
- neo4j_database = self .neo4j_database ,
159
- )
143
+ await self .driver .execute_query (query , parameters_ = parameters )
160
144
161
145
def _get_rel_query(self, rel: Neo4jRelationship) -> Tuple[str, Dict[str, Any]]:
    """Build the Cypher statement and parameters used to upsert one relationship.

    Args:
        rel (Neo4jRelationship): The relationship to upsert into the database.

    Returns:
        Tuple[str, Dict[str, Any]]: ``UPSERT_RELATIONSHIP_QUERY`` formatted
        with the relationship type, and a parameter dict with the keys
        ``start_node_id``, ``end_node_id``, ``properties`` and ``embeddings``.
    """
    # A missing or empty property mapping is sent as an empty dict;
    # embedding properties are forwarded exactly as stored on the relationship.
    rel_properties = rel.properties or {}
    params: Dict[str, Any] = dict(
        start_node_id=rel.start_node_id,
        end_node_id=rel.end_node_id,
        properties=rel_properties,
        embeddings=rel.embedding_properties,
    )
    cypher = UPSERT_RELATIONSHIP_QUERY.format(type=rel.type)
    return cypher, params
179
157
@@ -184,18 +162,7 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
184
162
rel (Neo4jRelationship): The relationship to upsert into the database.
185
163
"""
186
164
query , parameters = self ._get_rel_query (rel )
187
- result = self .driver .execute_query (query , parameters_ = parameters )
188
- rel_id = result .records [0 ]["elementID(r)" ]
189
- # Add the embedding properties to the relationship
190
- if rel .embedding_properties :
191
- for prop , vector in rel .embedding_properties .items ():
192
- upsert_vector_on_relationship (
193
- driver = self .driver ,
194
- rel_id = rel_id ,
195
- embedding_property = prop ,
196
- vector = vector ,
197
- neo4j_database = self .neo4j_database ,
198
- )
165
+ self .driver .execute_query (query , parameters_ = parameters )
199
166
200
167
async def _async_upsert_relationship (
201
168
self , rel : Neo4jRelationship , sem : asyncio .Semaphore
@@ -207,18 +174,7 @@ async def _async_upsert_relationship(
207
174
"""
208
175
async with sem :
209
176
query , parameters = self ._get_rel_query (rel )
210
- result = await self .driver .execute_query (query , parameters_ = parameters )
211
- rel_id = result .records [0 ]["elementID(r)" ]
212
- # Add the embedding properties to the relationship
213
- if rel .embedding_properties :
214
- for prop , vector in rel .embedding_properties .items ():
215
- await async_upsert_vector_on_relationship (
216
- driver = self .driver ,
217
- rel_id = rel_id ,
218
- embedding_property = prop ,
219
- vector = vector ,
220
- neo4j_database = self .neo4j_database ,
221
- )
177
+ await self .driver .execute_query (query , parameters_ = parameters )
222
178
223
179
@validate_call
224
180
async def run (self , graph : Neo4jGraph ) -> KGWriterModel :
@@ -228,7 +184,8 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
228
184
graph (Neo4jGraph): The knowledge graph to upsert into the database.
229
185
"""
230
186
try :
231
- if isinstance (self .driver , neo4j .AsyncDriver ):
187
+ if inspect .iscoroutinefunction (self .driver .execute_query ):
188
+ await self ._async_db_setup ()
232
189
sem = asyncio .Semaphore (self .max_concurrency )
233
190
node_tasks = [
234
191
self ._async_upsert_node (node , sem ) for node in graph .nodes
@@ -241,6 +198,8 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
241
198
]
242
199
await asyncio .gather (* rel_tasks )
243
200
else :
201
+ self ._db_setup ()
202
+
244
203
for node in graph .nodes :
245
204
self ._upsert_node (node )
246
205
0 commit comments