Skip to content

Commit 1a928a5

Browse files
authored
KG Creation Pipeline Speed Up (#117)
* Made LLM calls concurrent in KG creation pipeline * Renamed test_kb_builder_pipeline_e2e.py to test_kg_builder_pipeline_e2e.py * Added an asynchronous Neo4jWriter * Merged async and non-async Neo4j writers * Fixed neo4j writer docstring * Refactored Neo4jWriter
1 parent 060d709 commit 1a928a5

File tree

9 files changed

+313
-56
lines changed

9 files changed

+313
-56
lines changed

docs/source/api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ Database Interaction
213213

214214
.. autofunction:: neo4j_genai.indexes.upsert_vector_on_relationship
215215

216+
.. autofunction:: neo4j_genai.indexes.async_upsert_vector
217+
218+
.. autofunction:: neo4j_genai.indexes.async_upsert_vector_on_relationship
216219

217220
******
218221
Errors

src/neo4j_genai/experimental/components/entity_relation_extractor.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ class LLMEntityRelationExtractor(EntityRelationExtractor):
232232
prompt_template (ERExtractionTemplate | str): A custom prompt template to use for extraction.
233233
create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the extracted entities and relations. Defaults to True.
234234
on_error (OnError): What to do when an error occurs during extraction. Defaults to raising an error.
235+
max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM.
235236
236237
Example:
237238
@@ -255,9 +256,11 @@ def __init__(
255256
prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(),
256257
create_lexical_graph: bool = True,
257258
on_error: OnError = OnError.RAISE,
259+
max_concurrency: int = 5,
258260
) -> None:
259261
super().__init__(on_error=on_error, create_lexical_graph=create_lexical_graph)
260262
self.llm = llm # with response_format={ "type": "json_object" },
263+
self.max_concurrency = max_concurrency
261264
if isinstance(prompt_template, str):
262265
template = PromptTemplate(prompt_template, expected_inputs=[])
263266
else:
@@ -271,7 +274,7 @@ async def extract_for_chunk(
271274
prompt = self.prompt_template.format(
272275
text=chunk.text, schema=schema.model_dump(), examples=examples
273276
)
274-
llm_result = self.llm.invoke(prompt)
277+
llm_result = await self.llm.ainvoke(prompt)
275278
try:
276279
result = json.loads(llm_result.content)
277280
except json.JSONDecodeError:
@@ -322,12 +325,20 @@ def combine_chunk_graphs(self, chunk_graphs: List[Neo4jGraph]) -> Neo4jGraph:
322325
return graph
323326

324327
async def run_for_chunk(
325-
self, schema: SchemaConfig, examples: str, chunk_index: int, chunk: TextChunk
328+
self,
329+
schema: SchemaConfig,
330+
examples: str,
331+
chunk_index: int,
332+
chunk: TextChunk,
333+
sem: asyncio.Semaphore,
326334
) -> Neo4jGraph:
327335
"""Run extraction and post processing for a single chunk"""
328-
chunk_graph = await self.extract_for_chunk(schema, examples, chunk_index, chunk)
329-
await self.post_process_chunk(chunk_graph, chunk_index, chunk)
330-
return chunk_graph
336+
async with sem:
337+
chunk_graph = await self.extract_for_chunk(
338+
schema, examples, chunk_index, chunk
339+
)
340+
await self.post_process_chunk(chunk_graph, chunk_index, chunk)
341+
return chunk_graph
331342

332343
@validate_call
333344
async def run(
@@ -341,8 +352,9 @@ async def run(
341352
schema = schema or SchemaConfig(entities={}, relations={}, potential_schema=[])
342353
examples = examples or ""
343354
self._id_prefix = str(datetime.now().timestamp())
355+
sem = asyncio.Semaphore(self.max_concurrency)
344356
tasks = [
345-
self.run_for_chunk(schema, examples, chunk_index, chunk)
357+
self.run_for_chunk(schema, examples, chunk_index, chunk, sem)
346358
for chunk_index, chunk in enumerate(chunks.chunks)
347359
]
348360
chunk_graphs = await asyncio.gather(*tasks)

src/neo4j_genai/experimental/components/kg_writer.py

Lines changed: 96 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
import asyncio
1718
import logging
1819
from abc import abstractmethod
19-
from typing import Literal, Optional
20+
from typing import Any, Dict, Literal, Optional, Tuple
2021

2122
import neo4j
2223
from pydantic import validate_call
@@ -27,7 +28,12 @@
2728
Neo4jRelationship,
2829
)
2930
from neo4j_genai.experimental.pipeline.component import Component, DataModel
30-
from neo4j_genai.indexes import upsert_vector, upsert_vector_on_relationship
31+
from neo4j_genai.indexes import (
32+
async_upsert_vector,
33+
async_upsert_vector_on_relationship,
34+
upsert_vector,
35+
upsert_vector_on_relationship,
36+
)
3137
from neo4j_genai.neo4j_queries import UPSERT_NODE_QUERY, UPSERT_RELATIONSHIP_QUERY
3238

3339
logger = logging.getLogger(__name__)
@@ -64,20 +70,21 @@ class Neo4jWriter(KGWriter):
6470
Args:
6571
driver (neo4j.driver): The Neo4j driver to connect to the database.
6672
neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided.
73+
max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM.
6774
6875
Example:
6976
7077
.. code-block:: python
7178
72-
from neo4j import GraphDatabase
79+
from neo4j import AsyncGraphDatabase
7380
from neo4j_genai.experimental.components.kg_writer import Neo4jWriter
7481
from neo4j_genai.experimental.pipeline import Pipeline
7582
7683
URI = "neo4j://localhost:7687"
7784
AUTH = ("neo4j", "password")
7885
DATABASE = "neo4j"
7986
80-
driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
87+
driver = AsyncGraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
8188
writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE)
8289
8390
pipeline = Pipeline()
@@ -89,16 +96,13 @@ def __init__(
8996
self,
9097
driver: neo4j.driver,
9198
neo4j_database: Optional[str] = None,
99+
max_concurrency: int = 5,
92100
):
93101
self.driver = driver
94102
self.neo4j_database = neo4j_database
103+
self.max_concurrency = max_concurrency
95104

96-
def _upsert_node(self, node: Neo4jNode) -> None:
97-
"""Upserts a single node into the Neo4j database."
98-
99-
Args:
100-
node (Neo4jNode): The node to upsert into the database.
101-
"""
105+
def _get_node_query(self, node: Neo4jNode) -> Tuple[str, Dict[str, Any]]:
102106
# Create the initial node
103107
parameters = {"id": node.id}
104108
if node.properties:
@@ -107,6 +111,15 @@ def _upsert_node(self, node: Neo4jNode) -> None:
107111
"{" + ", ".join(f"{key}: ${key}" for key in parameters.keys()) + "}"
108112
)
109113
query = UPSERT_NODE_QUERY.format(label=node.label, properties=properties)
114+
return query, parameters
115+
116+
def _upsert_node(self, node: Neo4jNode) -> None:
117+
"""Upserts a single node into the Neo4j database."
118+
119+
Args:
120+
node (Neo4jNode): The node to upsert into the database.
121+
"""
122+
query, parameters = self._get_node_query(node)
110123
result = self.driver.execute_query(query, parameters_=parameters)
111124
node_id = result.records[0]["elementID(n)"]
112125
# Add the embedding properties to the node
@@ -120,12 +133,32 @@ def _upsert_node(self, node: Neo4jNode) -> None:
120133
neo4j_database=self.neo4j_database,
121134
)
122135

123-
def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
124-
"""Upserts a single relationship into the Neo4j database.
136+
async def _async_upsert_node(
137+
self,
138+
node: Neo4jNode,
139+
sem: asyncio.Semaphore,
140+
) -> None:
141+
"""Asynchronously upserts a single node into the Neo4j database."
125142
126143
Args:
127-
rel (Neo4jRelationship): The relationship to upsert into the database.
144+
node (Neo4jNode): The node to upsert into the database.
128145
"""
146+
async with sem:
147+
query, parameters = self._get_node_query(node)
148+
result = await self.driver.execute_query(query, parameters_=parameters)
149+
node_id = result.records[0]["elementID(n)"]
150+
# Add the embedding properties to the node
151+
if node.embedding_properties:
152+
for prop, vector in node.embedding_properties.items():
153+
await async_upsert_vector(
154+
driver=self.driver,
155+
node_id=node_id,
156+
embedding_property=prop,
157+
vector=vector,
158+
neo4j_database=self.neo4j_database,
159+
)
160+
161+
def _get_rel_query(self, rel: Neo4jRelationship) -> Tuple[str, Dict[str, Any]]:
129162
# Create the initial relationship
130163
parameters = {
131164
"start_node_id": rel.start_node_id,
@@ -142,6 +175,15 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
142175
type=rel.type,
143176
properties=properties,
144177
)
178+
return query, parameters
179+
180+
def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
181+
"""Upserts a single relationship into the Neo4j database.
182+
183+
Args:
184+
rel (Neo4jRelationship): The relationship to upsert into the database.
185+
"""
186+
query, parameters = self._get_rel_query(rel)
145187
result = self.driver.execute_query(query, parameters_=parameters)
146188
rel_id = result.records[0]["elementID(r)"]
147189
# Add the embedding properties to the relationship
@@ -155,6 +197,29 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
155197
neo4j_database=self.neo4j_database,
156198
)
157199

200+
async def _async_upsert_relationship(
201+
self, rel: Neo4jRelationship, sem: asyncio.Semaphore
202+
) -> None:
203+
"""Asynchronously upserts a single relationship into the Neo4j database.
204+
205+
Args:
206+
rel (Neo4jRelationship): The relationship to upsert into the database.
207+
"""
208+
async with sem:
209+
query, parameters = self._get_rel_query(rel)
210+
result = await self.driver.execute_query(query, parameters_=parameters)
211+
rel_id = result.records[0]["elementID(r)"]
212+
# Add the embedding properties to the relationship
213+
if rel.embedding_properties:
214+
for prop, vector in rel.embedding_properties.items():
215+
await async_upsert_vector_on_relationship(
216+
driver=self.driver,
217+
rel_id=rel_id,
218+
embedding_property=prop,
219+
vector=vector,
220+
neo4j_database=self.neo4j_database,
221+
)
222+
158223
@validate_call
159224
async def run(self, graph: Neo4jGraph) -> KGWriterModel:
160225
"""Upserts a knowledge graph into a Neo4j database.
@@ -163,11 +228,24 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
163228
graph (Neo4jGraph): The knowledge graph to upsert into the database.
164229
"""
165230
try:
166-
for node in graph.nodes:
167-
self._upsert_node(node)
168-
169-
for rel in graph.relationships:
170-
self._upsert_relationship(rel)
231+
if isinstance(self.driver, neo4j.AsyncDriver):
232+
sem = asyncio.Semaphore(self.max_concurrency)
233+
node_tasks = [
234+
self._async_upsert_node(node, sem) for node in graph.nodes
235+
]
236+
await asyncio.gather(*node_tasks)
237+
238+
rel_tasks = [
239+
self._async_upsert_relationship(rel, sem)
240+
for rel in graph.relationships
241+
]
242+
await asyncio.gather(*rel_tasks)
243+
else:
244+
for node in graph.nodes:
245+
self._upsert_node(node)
246+
247+
for rel in graph.relationships:
248+
self._upsert_relationship(rel)
171249

172250
return KGWriterModel(status="SUCCESS")
173251
except neo4j.exceptions.ClientError as e:

0 commit comments

Comments
 (0)