Skip to content

Commit 57529d4

Browse files
authored
Remove async driver support from the KG creation pipeline (#201)
* Removed async driver support from Neo4jWriter * Removes neo4j_graphrag.utils.execute_query * Actually removes neo4j_graphrag.utils.execute_query * Updated CHANGELOG * Removed references to max_concurrency
1 parent 9391662 commit 57529d4

File tree

13 files changed

+39
-271
lines changed

13 files changed

+39
-271
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
### Changed
1111
- Vector and Hybrid retrievers used with `return_properties` now also return the node labels (`nodeLabels`) and the node's element ID (`id`).
1212
- `HybridRetriever` now filters out the embedding property index in `self.vector_index_name` from the retriever result by default.
13+
- Removed support for neo4j.AsyncDriver in the KG creation pipeline, affecting Neo4jWriter and related components.
14+
- Updated examples and unit tests to reflect the removal of async driver support.
1315

1416

1517
## 1.1.0

docs/source/user_guide_kg_builder.rst

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,10 +433,8 @@ to a Neo4j database:
433433
graph = Neo4jGraph(nodes=[], relationships=[])
434434
await writer.run(graph)
435435
436-
To improve insert performances, it is possible to act on two parameters:
437-
438-
- `batch_size`: the number of nodes/relationships to be processed in each batch (default is 1000).
439-
- `max_concurrency`: the max number of concurrent queries (default is 5).
436+
Adjust the batch_size parameter of `Neo4jWriter` to optimize insert performance.
437+
This parameter controls the number of nodes or relationships inserted per batch, with a default value of 1000.
440438

441439
See :ref:`neo4jgraph`.
442440

examples/customize/build_graph/components/resolvers/custom_resolver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
a specific signature for the run method, which makes it very flexible.
33
"""
44

5-
from typing import Any, Optional, Union
5+
from typing import Any, Optional
66

77
import neo4j
88
from neo4j_graphrag.experimental.components.resolver import EntityResolver
@@ -12,7 +12,7 @@
1212
class MyEntityResolver(EntityResolver):
1313
def __init__(
1414
self,
15-
driver: Union[neo4j.Driver, neo4j.AsyncDriver],
15+
driver: neo4j.Driver,
1616
filter_query: Optional[str] = None,
1717
) -> None:
1818
super().__init__(driver, filter_query)

examples/customize/build_graph/components/writers/neo4j_writer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@ async def main(driver: neo4j.Driver, graph: Neo4jGraph) -> KGWriterModel:
1111
driver,
1212
# optionally, configure the neo4j database
1313
# neo4j_database="neo4j",
14-
# you can tune batch_size and max_concurrency to
14+
# you can tune batch_size to
1515
# improve speed
1616
# batch_size=1000,
17-
# max_concurrency=5,
1817
)
1918
result = await writer.run(graph=graph)
2019
return result

examples/customize/build_graph/pipeline/kg_builder_from_pdf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040

4141
async def define_and_run_pipeline(
42-
neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface
42+
neo4j_driver: neo4j.Driver, llm: LLMInterface
4343
) -> PipelineResult:
4444
from neo4j_graphrag.experimental.pipeline import Pipeline
4545

@@ -131,11 +131,11 @@ async def main() -> PipelineResult:
131131
"response_format": {"type": "json_object"},
132132
},
133133
)
134-
driver = neo4j.AsyncGraphDatabase.driver(
134+
driver = neo4j.GraphDatabase.driver(
135135
"bolt://localhost:7687", auth=("neo4j", "password")
136136
)
137137
res = await define_and_run_pipeline(driver, llm)
138-
await driver.close()
138+
driver.close()
139139
await llm.async_client.close()
140140
return res
141141

examples/customize/build_graph/pipeline/kg_builder_from_text.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040

4141
async def define_and_run_pipeline(
42-
neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface
42+
neo4j_driver: neo4j.Driver, llm: LLMInterface
4343
) -> PipelineResult:
4444
"""This is where we define and run the KG builder pipeline, instantiating a few
4545
components:
@@ -148,11 +148,11 @@ async def main() -> PipelineResult:
148148
"response_format": {"type": "json_object"},
149149
},
150150
)
151-
driver = neo4j.AsyncGraphDatabase.driver(
151+
driver = neo4j.GraphDatabase.driver(
152152
"bolt://localhost:7687", auth=("neo4j", "password")
153153
)
154154
res = await define_and_run_pipeline(driver, llm)
155-
await driver.close()
155+
driver.close()
156156
await llm.async_client.close()
157157
return res
158158

examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141

4242
async def define_and_run_pipeline(
43-
neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface
43+
neo4j_driver: neo4j.Driver, llm: LLMInterface
4444
) -> None:
4545
"""This is where we define and run the KG builder pipeline, instantiating a few
4646
components:
@@ -144,11 +144,11 @@ async def main() -> None:
144144
"response_format": {"type": "json_object"},
145145
},
146146
)
147-
driver = neo4j.AsyncGraphDatabase.driver(
147+
driver = neo4j.GraphDatabase.driver(
148148
"bolt://localhost:7687", auth=("neo4j", "password")
149149
)
150150
await define_and_run_pipeline(driver, llm)
151-
await driver.close()
151+
driver.close()
152152
await llm.async_client.close()
153153

154154

examples/kg_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646

4747
async def define_and_run_pipeline(
48-
neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface
48+
neo4j_driver: neo4j.Driver, llm: LLMInterface
4949
) -> PipelineResult:
5050
from neo4j_graphrag.experimental.pipeline import Pipeline
5151

@@ -137,11 +137,11 @@ async def main() -> PipelineResult:
137137
"response_format": {"type": "json_object"},
138138
},
139139
)
140-
driver = neo4j.AsyncGraphDatabase.driver(
140+
driver = neo4j.GraphDatabase.driver(
141141
"bolt://localhost:7687", auth=("neo4j", "password")
142142
)
143143
res = await define_and_run_pipeline(driver, llm)
144-
await driver.close()
144+
driver.close()
145145
await llm.async_client.close()
146146
return res
147147

src/neo4j_graphrag/experimental/components/kg_writer.py

Lines changed: 11 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
import asyncio
18-
import inspect
1917
import logging
2018
from abc import abstractmethod
2119
from typing import Any, Generator, Literal, Optional
@@ -87,21 +85,21 @@ class Neo4jWriter(KGWriter):
8785
Args:
8886
driver (neo4j.driver): The Neo4j driver to connect to the database.
8987
neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided.
90-
max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM.
88+
batch_size (int): The number of nodes or relationships to write to the database in a batch. Defaults to 1000.
9189
9290
Example:
9391
9492
.. code-block:: python
9593
96-
from neo4j import AsyncGraphDatabase
94+
from neo4j import GraphDatabase
9795
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
9896
from neo4j_graphrag.experimental.pipeline import Pipeline
9997
10098
URI = "neo4j://localhost:7687"
10199
AUTH = ("neo4j", "password")
102100
DATABASE = "neo4j"
103101
104-
driver = AsyncGraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
102+
driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
105103
writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE)
106104
107105
pipeline = Pipeline()
@@ -111,15 +109,13 @@ class Neo4jWriter(KGWriter):
111109

112110
def __init__(
113111
self,
114-
driver: neo4j.driver,
112+
driver: neo4j.Driver,
115113
neo4j_database: Optional[str] = None,
116114
batch_size: int = 1000,
117-
max_concurrency: int = 5,
118115
):
119116
self.driver = driver
120117
self.neo4j_database = neo4j_database
121118
self.batch_size = batch_size
122-
self.max_concurrency = max_concurrency
123119
self.is_version_5_23_or_above = self._check_if_version_5_23_or_above()
124120

125121
def _db_setup(self) -> None:
@@ -129,13 +125,6 @@ def _db_setup(self) -> None:
129125
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
130126
)
131127

132-
async def _async_db_setup(self) -> None:
133-
# create index on __Entity__.id
134-
# used when creating the relationships
135-
await self.driver.execute_query(
136-
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
137-
)
138-
139128
@staticmethod
140129
def _nodes_to_rows(
141130
nodes: list[Neo4jNode], lexical_graph_config: LexicalGraphConfig
@@ -166,23 +155,6 @@ def _upsert_nodes(
166155
else:
167156
self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)
168157

169-
async def _async_upsert_nodes(
170-
self,
171-
nodes: list[Neo4jNode],
172-
lexical_graph_config: LexicalGraphConfig,
173-
sem: asyncio.Semaphore,
174-
) -> None:
175-
"""Asynchronously upserts a single node into the Neo4j database."
176-
177-
Args:
178-
nodes (list[Neo4jNode]): The nodes batch to upsert into the database.
179-
"""
180-
async with sem:
181-
parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)}
182-
await self.driver.execute_query(
183-
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
184-
)
185-
186158
def _get_version(self) -> tuple[int, ...]:
187159
records, _, _ = self.driver.execute_query(
188160
"CALL dbms.components()", database_=self.neo4j_database
@@ -220,26 +192,6 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
220192
else:
221193
self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters)
222194

223-
async def _async_upsert_relationships(
224-
self, rels: list[Neo4jRelationship], sem: asyncio.Semaphore
225-
) -> None:
226-
"""Asynchronously upserts a single relationship into the Neo4j database.
227-
228-
Args:
229-
rels (list[Neo4jRelationship]): The relationships batch to upsert into the database.
230-
"""
231-
async with sem:
232-
parameters = {"rows": [rel.model_dump() for rel in rels]}
233-
if self.is_version_5_23_or_above:
234-
await self.driver.execute_query(
235-
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
236-
parameters_=parameters,
237-
)
238-
else:
239-
await self.driver.execute_query(
240-
UPSERT_RELATIONSHIP_QUERY, parameters_=parameters
241-
)
242-
243195
@validate_call
244196
async def run(
245197
self,
@@ -253,28 +205,13 @@ async def run(
253205
lexical_graph_config (LexicalGraphConfig):
254206
"""
255207
try:
256-
if inspect.iscoroutinefunction(self.driver.execute_query):
257-
await self._async_db_setup()
258-
sem = asyncio.Semaphore(self.max_concurrency)
259-
node_tasks = [
260-
self._async_upsert_nodes(batch, lexical_graph_config, sem)
261-
for batch in batched(graph.nodes, self.batch_size)
262-
]
263-
await asyncio.gather(*node_tasks)
264-
265-
rel_tasks = [
266-
self._async_upsert_relationships(batch, sem)
267-
for batch in batched(graph.relationships, self.batch_size)
268-
]
269-
await asyncio.gather(*rel_tasks)
270-
else:
271-
self._db_setup()
272-
273-
for batch in batched(graph.nodes, self.batch_size):
274-
self._upsert_nodes(batch, lexical_graph_config)
275-
276-
for batch in batched(graph.relationships, self.batch_size):
277-
self._upsert_relationships(batch)
208+
self._db_setup()
209+
210+
for batch in batched(graph.nodes, self.batch_size):
211+
self._upsert_nodes(batch, lexical_graph_config)
212+
213+
for batch in batched(graph.relationships, self.batch_size):
214+
self._upsert_relationships(batch)
278215

279216
return KGWriterModel(
280217
status="SUCCESS",

src/neo4j_graphrag/experimental/components/resolver.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import abc
16-
from typing import Any, Optional, Union
16+
from typing import Any, Optional
1717

1818
import neo4j
1919

2020
from neo4j_graphrag.experimental.components.types import ResolutionStats
2121
from neo4j_graphrag.experimental.pipeline import Component
22-
from neo4j_graphrag.utils import execute_query
2322

2423

2524
class EntityResolver(Component, abc.ABC):
@@ -32,7 +31,7 @@ class EntityResolver(Component, abc.ABC):
3231

3332
def __init__(
3433
self,
35-
driver: Union[neo4j.Driver, neo4j.AsyncDriver],
34+
driver: neo4j.Driver,
3635
filter_query: Optional[str] = None,
3736
) -> None:
3837
self.driver = driver
@@ -56,22 +55,22 @@ class SinglePropertyExactMatchResolver(EntityResolver):
5655
5756
.. code-block:: python
5857
59-
from neo4j import AsyncGraphDatabase
58+
from neo4j import GraphDatabase
6059
from neo4j_graphrag.experimental.components.resolver import SinglePropertyExactMatchResolver
6160
6261
URI = "neo4j://localhost:7687"
6362
AUTH = ("neo4j", "password")
6463
DATABASE = "neo4j"
6564
66-
driver = AsyncGraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
65+
driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
6766
resolver = SinglePropertyExactMatchResolver(driver=driver, neo4j_database=DATABASE)
6867
await resolver.run() # no expected parameters
6968
7069
"""
7170

7271
def __init__(
7372
self,
74-
driver: Union[neo4j.Driver, neo4j.AsyncDriver],
73+
driver: neo4j.Driver,
7574
filter_query: Optional[str] = None,
7675
resolve_property: str = "name",
7776
neo4j_database: Optional[str] = None,
@@ -94,11 +93,7 @@ async def run(self) -> ResolutionStats:
9493
if self.filter_query:
9594
match_query += self.filter_query
9695
stat_query = f"{match_query} RETURN count(entity) as c"
97-
records, _, _ = await execute_query(
98-
self.driver,
99-
stat_query,
100-
database_=self.database,
101-
)
96+
records, _, _ = self.driver.execute_query(stat_query, database_=self.database)
10297
number_of_nodes_to_resolve = records[0].get("c")
10398
if number_of_nodes_to_resolve == 0:
10499
return ResolutionStats(
@@ -130,10 +125,8 @@ async def run(self) -> ResolutionStats:
130125
"YIELD node "
131126
"RETURN count(node) as c "
132127
)
133-
records, _, _ = await execute_query(
134-
self.driver,
135-
merge_nodes_query,
136-
database_=self.database,
128+
records, _, _ = self.driver.execute_query(
129+
merge_nodes_query, database_=self.database
137130
)
138131
number_of_created_nodes = records[0].get("c")
139132
return ResolutionStats(

src/neo4j_graphrag/utils.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,11 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
from typing import Any, Optional
18-
19-
import neo4j
17+
from typing import Optional
2018

2119

2220
def validate_search_query_input(
2321
query_text: Optional[str] = None, query_vector: Optional[list[float]] = None
2422
) -> None:
2523
if not (bool(query_vector) ^ bool(query_text)):
2624
raise ValueError("You must provide exactly one of query_vector or query_text.")
27-
28-
29-
async def execute_query(
30-
driver: neo4j.Driver | neo4j.AsyncDriver, query: str, **kwargs: Any
31-
) -> Any:
32-
if isinstance(driver, neo4j.AsyncDriver):
33-
result = await driver.execute_query(query, **kwargs)
34-
else:
35-
result = driver.execute_query(query, **kwargs)
36-
return result

0 commit comments

Comments
 (0)