Skip to content

Commit b65b34f

Browse files
authored
Async pipeline improvements (#123)
* Add failing test * Define a "run_id" in Orchestrator - save results per run_id * Make unit test work * Make intermediate results accessible from outside pipeline for investigation * Remove unused imports * Update examples and CHANGELOG * Cleaning: remove deprecated code * Fix ruff * Fix examples * Fix examples again * Move status to store * PR reviews * Removing useless status assignment * Remove unused import * Move status to store * Return RunStatus from method * Fix bad merge * Fix comments * Deal with None statuses in the method dedicated to fetching status - Remove unused statuses * Fix error message * Update error message
1 parent 19fbace commit b65b34f

File tree

5 files changed

+165
-102
lines changed

5 files changed

+165
-102
lines changed

src/neo4j_genai/experimental/pipeline/pipeline.py

Lines changed: 48 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,7 @@
4848

4949
class RunStatus(enum.Enum):
5050
UNKNOWN = "UNKNOWN"
51-
SCHEDULED = "SCHEDULED"
52-
WAITING = "WAITING"
5351
RUNNING = "RUNNING"
54-
SKIP = "SKIP"
5552
DONE = "DONE"
5653

5754

@@ -76,37 +73,6 @@ def __init__(self, name: str, component: Component):
7673
"""
7774
super().__init__(name, {})
7875
self.component = component
79-
self.status: dict[str, RunStatus] = {}
80-
self._lock = asyncio.Lock()
81-
"""This lock is used to make sure we're not trying
82-
to update the status in //. This should prevent the task to
83-
be executed multiple times because the status was not known
84-
by the orchestrator.
85-
"""
86-
87-
async def set_status(self, run_id: str, status: RunStatus) -> None:
88-
"""Set a new status
89-
90-
Args:
91-
run_id (str): Unique ID for the current pipeline run
92-
status (RunStatus): New status
93-
94-
Raises:
95-
PipelineStatusUpdateError if the new status is not
96-
compatible with the current one.
97-
"""
98-
async with self._lock:
99-
current_status = self.status.get(run_id)
100-
if status == current_status:
101-
raise PipelineStatusUpdateError()
102-
if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
103-
# can't go back to RUNNING from DONE
104-
raise PipelineStatusUpdateError()
105-
self.status[run_id] = status
106-
107-
async def read_status(self, run_id: str) -> RunStatus:
108-
async with self._lock:
109-
return self.status.get(run_id, RunStatus.UNKNOWN)
11076

11177
async def execute(self, **kwargs: Any) -> RunResult | None:
11278
"""Execute the task
@@ -163,31 +129,52 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
163129
None
164130
"""
165131
input_config = await self.get_input_config_for_task(task)
166-
inputs = self.get_component_inputs(task.name, input_config, data)
132+
inputs = await self.get_component_inputs(task.name, input_config, data)
167133
try:
168-
await task.set_status(self.run_id, RunStatus.RUNNING)
134+
await self.set_task_status(task.name, RunStatus.RUNNING)
169135
except PipelineStatusUpdateError:
170-
logger.info(
171-
f"Component {task.name} already running or done {task.status.get(self.run_id)}"
172-
)
136+
logger.info(f"Component {task.name} already running or done")
173137
return None
174138
res = await task.run(inputs)
175-
await task.set_status(self.run_id, RunStatus.DONE)
139+
await self.set_task_status(task.name, RunStatus.DONE)
176140
if res:
177141
await self.on_task_complete(data=data, task=task, result=res)
178142

143+
async def set_task_status(self, task_name: str, status: RunStatus) -> None:
144+
"""Set a new status
145+
146+
Args:
147+
task_name (str): Name of the component
148+
status (RunStatus): New status
149+
150+
Raises:
151+
PipelineStatusUpdateError if the new status is not
152+
compatible with the current one.
153+
"""
154+
# prevent the method from being called by two concurrent async calls
155+
async with asyncio.Lock():
156+
current_status = await self.get_status_for_component(task_name)
157+
if status == current_status:
158+
raise PipelineStatusUpdateError(f"Status is already '{status}'")
159+
if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
160+
raise PipelineStatusUpdateError("Can't go from DONE to RUNNING")
161+
return await self.pipeline.store.add_status_for_component(
162+
self.run_id, task_name, status.value
163+
)
164+
179165
async def on_task_complete(
180166
self, data: dict[str, Any], task: TaskPipelineNode, result: RunResult
181167
) -> None:
182168
"""When a given task is complete, it will call this method
183169
to find the next tasks to run.
184170
"""
185-
# first call the method for the pipeline
186-
# this is where the results can be saved
171+
# first save this component's results
187172
res_to_save = None
188173
if result.result:
189174
res_to_save = result.result.model_dump()
190-
self.add_result_for_component(task.name, res_to_save, is_final=task.is_leaf())
175+
await self.add_result_for_component(
176+
task.name, res_to_save, is_final=task.is_leaf()
177+
)
191178
# then get the next tasks to be executed
192179
# and run them in //
193180
await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)])
@@ -200,8 +187,7 @@ async def check_dependencies_complete(self, task: TaskPipelineNode) -> None:
200187
"""
201188
dependencies = self.pipeline.previous_edges(task.name)
202189
for d in dependencies:
203-
start_node = self.pipeline.get_node_by_name(d.start)
204-
d_status = await start_node.read_status(self.run_id)
190+
d_status = await self.get_status_for_component(d.start)
205191
if d_status != RunStatus.DONE:
206192
logger.warning(
207193
f"Missing dependency {d.start} for {task.name} (status: {d_status})"
@@ -223,7 +209,7 @@ async def next(
223209
for next_edge in possible_next:
224210
next_node = self.pipeline.get_node_by_name(next_edge.end)
225211
# check status
226-
next_node_status = await next_node.read_status(self.run_id)
212+
next_node_status = await self.get_status_for_component(next_node.name)
227213
if next_node_status in [RunStatus.RUNNING, RunStatus.DONE]:
228214
# already running
229215
continue
@@ -251,8 +237,7 @@ async def get_input_config_for_task(self, task: TaskPipelineNode) -> dict[str, s
251237
# make sure dependencies are satisfied
252238
# and save the inputs defs that needs to be propagated from parent components
253239
for prev_edge in self.pipeline.previous_edges(task.name):
254-
prev_node = self.pipeline.get_node_by_name(prev_edge.start)
255-
prev_status = await prev_node.read_status(self.run_id)
240+
prev_status = await self.get_status_for_component(prev_edge.start)
256241
if prev_status != RunStatus.DONE:
257242
logger.critical(f"Missing dependency {prev_edge.start}")
258243
raise PipelineMissingDependencyError(f"{prev_edge.start} not ready")
@@ -261,7 +246,7 @@ async def get_input_config_for_task(self, task: TaskPipelineNode) -> dict[str, s
261246
input_config.update(**prev_edge_data)
262247
return input_config
263248

264-
def get_component_inputs(
249+
async def get_component_inputs(
265250
self,
266251
component_name: str,
267252
input_config: dict[str, Any],
@@ -287,7 +272,7 @@ def get_component_inputs(
287272
# component as input
288273
component = mapping
289274
output_param = None
290-
component_result = self.get_results_for_component(component)
275+
component_result = await self.get_results_for_component(component)
291276
if output_param is not None:
292277
value = component_result.get(output_param)
293278
else:
@@ -299,25 +284,31 @@ def get_component_inputs(
299284
component_inputs[parameter] = value
300285
return component_inputs
301286

302-
def add_result_for_component(
287+
async def add_result_for_component(
303288
self, name: str, result: dict[str, Any] | None, is_final: bool = False
304289
) -> None:
305290
"""This is where we save the results in the result store and, optionally,
306291
in the final result store.
307292
"""
308-
self.pipeline.store.add_result_for_component(self.run_id, name, result)
293+
await self.pipeline.store.add_result_for_component(self.run_id, name, result)
309294
if is_final:
310295
# The pipeline only returns the results
311296
# of the leaf nodes
312297
# TODO: make this configurable in the future.
313-
existing_results = self.pipeline.final_results.get(self.run_id) or {}
298+
existing_results = await self.pipeline.final_results.get(self.run_id) or {}
314299
existing_results[name] = result
315-
self.pipeline.final_results.add(
300+
await self.pipeline.final_results.add(
316301
self.run_id, existing_results, overwrite=True
317302
)
318303

319-
def get_results_for_component(self, name: str) -> Any:
320-
return self.pipeline.store.get_result_for_component(self.run_id, name)
304+
async def get_results_for_component(self, name: str) -> Any:
305+
return await self.pipeline.store.get_result_for_component(self.run_id, name)
306+
307+
async def get_status_for_component(self, name: str) -> RunStatus:
308+
status = await self.pipeline.store.get_status_for_component(self.run_id, name)
309+
if status is None:
310+
return RunStatus.UNKNOWN
311+
return RunStatus(status)
321312

322313
async def run(self, data: dict[str, Any]) -> None:
323314
"""Run the pipeline, starting from the root nodes
@@ -500,5 +491,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
500491
)
501492
return PipelineResult(
502493
run_id=orchestrator.run_id,
503-
result=self.final_results.get(orchestrator.run_id),
494+
result=await self.final_results.get(orchestrator.run_id),
504495
)

src/neo4j_genai/experimental/pipeline/stores.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919
from __future__ import annotations
2020

2121
import abc
22+
import asyncio
2223
from typing import Any
2324

2425

2526
class Store(abc.ABC):
2627
"""An interface to save component outputs"""
2728

2829
@abc.abstractmethod
29-
def add(self, key: str, value: Any, overwrite: bool = True) -> None:
30+
async def add(self, key: str, value: Any, overwrite: bool = True) -> None:
3031
"""
3132
Args:
3233
key (str): The key to access the data.
@@ -41,7 +42,7 @@ def add(self, key: str, value: Any, overwrite: bool = True) -> None:
4142
pass
4243

4344
@abc.abstractmethod
44-
def get(self, key: str) -> Any:
45+
async def get(self, key: str) -> Any:
4546
"""Retrieve value for `key`.
4647
If key not found, returns None.
4748
"""
@@ -62,16 +63,32 @@ def empty(self) -> None:
6263

6364
class ResultStore(Store, abc.ABC):
6465
@staticmethod
65-
def get_key(run_id: str, task_name: str) -> str:
66-
return f"{run_id}:{task_name}"
66+
def get_key(run_id: str, task_name: str, suffix: str = "") -> str:
67+
key = f"{run_id}:{task_name}"
68+
if suffix:
69+
key += f":{suffix}"
70+
return key
71+
72+
async def add_status_for_component(
73+
self,
74+
run_id: str,
75+
task_name: str,
76+
status: str,
77+
) -> None:
78+
await self.add(
79+
self.get_key(run_id, task_name, "status"), status, overwrite=True
80+
)
81+
82+
async def get_status_for_component(self, run_id: str, task_name: str) -> Any:
83+
return await self.get(self.get_key(run_id, task_name, "status"))
6784

68-
def add_result_for_component(
85+
async def add_result_for_component(
6986
self, run_id: str, task_name: str, result: Any, overwrite: bool = False
7087
) -> None:
71-
self.add(self.get_key(run_id, task_name), result, overwrite=overwrite)
88+
await self.add(self.get_key(run_id, task_name), result, overwrite=overwrite)
7289

73-
def get_result_for_component(self, run_id: str, task_name: str) -> Any:
74-
return self.get(self.get_key(run_id, task_name))
90+
async def get_result_for_component(self, run_id: str, task_name: str) -> Any:
91+
return await self.get(self.get_key(run_id, task_name))
7592

7693

7794
class InMemoryStore(ResultStore):
@@ -80,14 +97,18 @@ class InMemoryStore(ResultStore):
8097

8198
def __init__(self) -> None:
8299
self._data: dict[str, Any] = {}
83-
84-
def add(self, key: str, value: Any, overwrite: bool = True) -> None:
85-
if (not overwrite) and key in self._data:
86-
raise KeyError(f"{key} already exists")
87-
self._data[key] = value
88-
89-
def get(self, key: str) -> Any:
90-
return self._data.get(key)
100+
self._lock = asyncio.Lock()
101+
"""This lock is used to prevent read while a write is ongoing and vice-versa."""
102+
103+
async def add(self, key: str, value: Any, overwrite: bool = True) -> None:
104+
async with self._lock:
105+
if (not overwrite) and key in self._data:
106+
raise KeyError(f"{key} already exists")
107+
self._data[key] = value
108+
109+
async def get(self, key: str) -> Any:
110+
async with self._lock:
111+
return self._data.get(key)
91112

92113
def all(self) -> dict[str, Any]:
93114
return self._data

tests/e2e/test_kg_builder_pipeline_e2e.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,13 @@ async def test_pipeline_builder_happy_path(
261261
assert res.run_id is not None
262262
assert res.result == {"writer": {"status": "SUCCESS"}}
263263
# check component's results
264-
chunks = kg_builder_pipeline.store.get_result_for_component(res.run_id, "splitter")
264+
chunks = await kg_builder_pipeline.store.get_result_for_component(
265+
res.run_id, "splitter"
266+
)
265267
assert len(chunks["chunks"]) == 3
266-
graph = kg_builder_pipeline.store.get_result_for_component(res.run_id, "extractor")
268+
graph = await kg_builder_pipeline.store.get_result_for_component(
269+
res.run_id, "extractor"
270+
)
267271
# 3 entities + 3 chunks + 1 document
268272
nodes = graph["nodes"]
269273
assert len(nodes) == 7
@@ -463,9 +467,13 @@ async def test_pipeline_builder_failing_chunk_do_not_raise(
463467
assert res.run_id is not None
464468
assert res.result == {"writer": {"status": "SUCCESS"}}
465469
# check component's results
466-
chunks = kg_builder_pipeline.store.get_result_for_component(res.run_id, "splitter")
470+
chunks = await kg_builder_pipeline.store.get_result_for_component(
471+
res.run_id, "splitter"
472+
)
467473
assert len(chunks["chunks"]) == 3
468-
graph = kg_builder_pipeline.store.get_result_for_component(res.run_id, "extractor")
474+
graph = await kg_builder_pipeline.store.get_result_for_component(
475+
res.run_id, "extractor"
476+
)
469477
# 3 entities + 3 chunks
470478
nodes = graph["nodes"]
471479
assert len(nodes) == 6

0 commit comments

Comments
 (0)