Skip to content

Async pipeline improvements #123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
67d430c
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python
stellasia Jun 25, 2024
e965499
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python
stellasia Jun 26, 2024
ed0baa7
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python
stellasia Jul 3, 2024
ea232ff
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python
stellasia Jul 16, 2024
43c7b3c
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python in…
stellasia Jul 17, 2024
8367daa
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python
stellasia Aug 1, 2024
3c3c00e
Merge remote-tracking branch 'origin/main'
stellasia Aug 1, 2024
7182523
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python
stellasia Aug 6, 2024
212a5a3
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python
stellasia Aug 13, 2024
32364c6
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python
stellasia Aug 29, 2024
f481025
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python
stellasia Sep 3, 2024
56435bf
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python
stellasia Sep 3, 2024
dc2fe93
Add failing test
stellasia Sep 4, 2024
2fb6448
Define a "run_id" in Orchestrator - save results per run_id
stellasia Sep 4, 2024
8d48c5d
Make unit test work
stellasia Sep 4, 2024
5b6a7e3
Make intermediate results accessible from outside pipeline for invest…
stellasia Sep 4, 2024
84f9b7f
Remove unused imports
stellasia Sep 4, 2024
a140774
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python in…
stellasia Sep 4, 2024
f7d7d7d
Update examples and CHANGELOG
stellasia Sep 4, 2024
fbc8391
Cleaning: remove deprecated code
stellasia Sep 4, 2024
439c5ad
Fix ruff
stellasia Sep 4, 2024
2156c65
Fix examples
stellasia Sep 4, 2024
5184688
Fix examples again
stellasia Sep 4, 2024
233a94e
Move status to store
stellasia Sep 4, 2024
11a47b7
PR reviews
stellasia Sep 6, 2024
16c8ff1
Removing useless status assignment
stellasia Sep 6, 2024
d7baf2a
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python in…
stellasia Sep 6, 2024
78c10ca
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python in…
stellasia Sep 6, 2024
9b50d4b
Remove unused import
stellasia Sep 6, 2024
59390fd
Move status to store
stellasia Sep 4, 2024
28675da
Merge remote-tracking branch 'origin/async-pipeline-improvements' int…
stellasia Sep 6, 2024
e957db5
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python in…
stellasia Sep 8, 2024
153d73e
Return RunStatus from method
stellasia Sep 8, 2024
f44fc9b
Fix bad merge
stellasia Sep 8, 2024
144f325
Fix comments
stellasia Sep 8, 2024
5fc686b
Deal with None statuses in the method dedicated to fetching status - …
stellasia Sep 10, 2024
fcfdf23
Fix error message
stellasia Sep 10, 2024
3a8936d
Update error message
stellasia Sep 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 48 additions & 57 deletions src/neo4j_genai/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@

class RunStatus(enum.Enum):
UNKNOWN = "UNKNOWN"
SCHEDULED = "SCHEDULED"
WAITING = "WAITING"
RUNNING = "RUNNING"
SKIP = "SKIP"
DONE = "DONE"


Expand All @@ -76,37 +73,6 @@ def __init__(self, name: str, component: Component):
"""
super().__init__(name, {})
self.component = component
self.status: dict[str, RunStatus] = {}
self._lock = asyncio.Lock()
"""This lock is used to make sure we're not trying
to update the status in //. This should prevent the task to
be executed multiple times because the status was not known
by the orchestrator.
"""

async def set_status(self, run_id: str, status: RunStatus) -> None:
    """Set a new status

    Args:
        run_id (str): Unique ID for the current pipeline run
        status (RunStatus): New status

    Raises:
        PipelineStatusUpdateError if the new status is not
        compatible with the current one.
    """
    # the lock serializes concurrent transitions so two coroutines
    # cannot both observe the same current status and both proceed
    async with self._lock:
        current_status = self.status.get(run_id)
        # re-applying the current status means another coroutine already
        # performed this transition — reject it as a duplicate update
        if status == current_status:
            raise PipelineStatusUpdateError()
        if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
            # can't go back to RUNNING from DONE
            raise PipelineStatusUpdateError()
        self.status[run_id] = status

async def read_status(self, run_id: str) -> RunStatus:
    """Return the current status for `run_id`, defaulting to UNKNOWN.

    Taking the same lock as `set_status` ensures a read never
    interleaves with an in-progress status update.
    """
    async with self._lock:
        return self.status.get(run_id, RunStatus.UNKNOWN)

async def execute(self, **kwargs: Any) -> RunResult | None:
"""Execute the task
Expand Down Expand Up @@ -163,31 +129,52 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
None
"""
input_config = await self.get_input_config_for_task(task)
inputs = self.get_component_inputs(task.name, input_config, data)
inputs = await self.get_component_inputs(task.name, input_config, data)
try:
await task.set_status(self.run_id, RunStatus.RUNNING)
await self.set_task_status(task.name, RunStatus.RUNNING)
except PipelineStatusUpdateError:
logger.info(
f"Component {task.name} already running or done {task.status.get(self.run_id)}"
)
logger.info(f"Component {task.name} already running or done")
return None
res = await task.run(inputs)
await task.set_status(self.run_id, RunStatus.DONE)
await self.set_task_status(task.name, RunStatus.DONE)
if res:
await self.on_task_complete(data=data, task=task, result=res)

async def set_task_status(self, task_name: str, status: RunStatus) -> None:
    """Set a new status

    Args:
        task_name (str): Name of the component
        status (RunStatus): New status

    Raises:
        PipelineStatusUpdateError if the new status is not
        compatible with the current one.
    """
    # BUG FIX: the previous code did `async with asyncio.Lock():`, which
    # builds a brand-new lock on every call — a lock nobody else holds
    # never blocks, so two concurrent calls could both read the same
    # current status and both write. Use one lock shared by this
    # orchestrator instance. Lazy creation is race-free here because
    # there is no await point between the getattr check and the
    # assignment (asyncio is single-threaded).
    lock = getattr(self, "_status_lock", None)
    if lock is None:
        lock = self._status_lock = asyncio.Lock()
    # prevent the read-compare-write sequence from interleaving with
    # another concurrent status update for this orchestrator
    async with lock:
        current_status = await self.get_status_for_component(task_name)
        if status == current_status:
            raise PipelineStatusUpdateError(f"Status is already '{status}'")
        if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
            raise PipelineStatusUpdateError("Can't go from DONE to RUNNING")
        return await self.pipeline.store.add_status_for_component(
            self.run_id, task_name, status.value
        )

async def on_task_complete(
self, data: dict[str, Any], task: TaskPipelineNode, result: RunResult
) -> None:
"""When a given task is complete, it will call this method
to find the next tasks to run.
"""
# first call the method for the pipeline
# this is where the results can be saved
# first save this component results
res_to_save = None
if result.result:
res_to_save = result.result.model_dump()
self.add_result_for_component(task.name, res_to_save, is_final=task.is_leaf())
await self.add_result_for_component(
task.name, res_to_save, is_final=task.is_leaf()
)
# then get the next tasks to be executed
# and run them in //
await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)])
Expand All @@ -200,8 +187,7 @@ async def check_dependencies_complete(self, task: TaskPipelineNode) -> None:
"""
dependencies = self.pipeline.previous_edges(task.name)
for d in dependencies:
start_node = self.pipeline.get_node_by_name(d.start)
d_status = await start_node.read_status(self.run_id)
d_status = await self.get_status_for_component(d.start)
if d_status != RunStatus.DONE:
logger.warning(
f"Missing dependency {d.start} for {task.name} (status: {d_status})"
Expand All @@ -223,7 +209,7 @@ async def next(
for next_edge in possible_next:
next_node = self.pipeline.get_node_by_name(next_edge.end)
# check status
next_node_status = await next_node.read_status(self.run_id)
next_node_status = await self.get_status_for_component(next_node.name)
if next_node_status in [RunStatus.RUNNING, RunStatus.DONE]:
# already running
continue
Expand Down Expand Up @@ -251,8 +237,7 @@ async def get_input_config_for_task(self, task: TaskPipelineNode) -> dict[str, s
# make sure dependencies are satisfied
# and save the inputs defs that needs to be propagated from parent components
for prev_edge in self.pipeline.previous_edges(task.name):
prev_node = self.pipeline.get_node_by_name(prev_edge.start)
prev_status = await prev_node.read_status(self.run_id)
prev_status = await self.get_status_for_component(prev_edge.start)
if prev_status != RunStatus.DONE:
logger.critical(f"Missing dependency {prev_edge.start}")
raise PipelineMissingDependencyError(f"{prev_edge.start} not ready")
Expand All @@ -261,7 +246,7 @@ async def get_input_config_for_task(self, task: TaskPipelineNode) -> dict[str, s
input_config.update(**prev_edge_data)
return input_config

def get_component_inputs(
async def get_component_inputs(
self,
component_name: str,
input_config: dict[str, Any],
Expand All @@ -287,7 +272,7 @@ def get_component_inputs(
# component as input
component = mapping
output_param = None
component_result = self.get_results_for_component(component)
component_result = await self.get_results_for_component(component)
if output_param is not None:
value = component_result.get(output_param)
else:
Expand All @@ -299,25 +284,31 @@ def get_component_inputs(
component_inputs[parameter] = value
return component_inputs

def add_result_for_component(
async def add_result_for_component(
self, name: str, result: dict[str, Any] | None, is_final: bool = False
) -> None:
"""This is where we save the results in the result store and, optionally,
in the final result store.
"""
self.pipeline.store.add_result_for_component(self.run_id, name, result)
await self.pipeline.store.add_result_for_component(self.run_id, name, result)
if is_final:
# The pipeline only returns the results
# of the leaf nodes
# TODO: make this configurable in the future.
existing_results = self.pipeline.final_results.get(self.run_id) or {}
existing_results = await self.pipeline.final_results.get(self.run_id) or {}
existing_results[name] = result
self.pipeline.final_results.add(
await self.pipeline.final_results.add(
self.run_id, existing_results, overwrite=True
)

def get_results_for_component(self, name: str) -> Any:
return self.pipeline.store.get_result_for_component(self.run_id, name)
async def get_results_for_component(self, name: str) -> Any:
    """Fetch component `name`'s saved result for the current run from the store."""
    return await self.pipeline.store.get_result_for_component(self.run_id, name)

async def get_status_for_component(self, name: str) -> RunStatus:
    """Look up the stored status of component `name` for the current run.

    Returns:
        RunStatus: the stored status, or RunStatus.UNKNOWN when the
        store has no status recorded for this (run, component) pair yet.
    """
    raw_status = await self.pipeline.store.get_status_for_component(
        self.run_id, name
    )
    return RunStatus.UNKNOWN if raw_status is None else RunStatus(raw_status)

async def run(self, data: dict[str, Any]) -> None:
"""Run the pipline, starting from the root nodes
Expand Down Expand Up @@ -500,5 +491,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
)
return PipelineResult(
run_id=orchestrator.run_id,
result=self.final_results.get(orchestrator.run_id),
result=await self.final_results.get(orchestrator.run_id),
)
53 changes: 37 additions & 16 deletions src/neo4j_genai/experimental/pipeline/stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
from __future__ import annotations

import abc
import asyncio
from typing import Any


class Store(abc.ABC):
"""An interface to save component outputs"""

@abc.abstractmethod
def add(self, key: str, value: Any, overwrite: bool = True) -> None:
async def add(self, key: str, value: Any, overwrite: bool = True) -> None:
"""
Args:
key (str): The key to access the data.
Expand All @@ -41,7 +42,7 @@ def add(self, key: str, value: Any, overwrite: bool = True) -> None:
pass

@abc.abstractmethod
def get(self, key: str) -> Any:
async def get(self, key: str) -> Any:
"""Retrieve value for `key`.
If key not found, returns None.
"""
Expand All @@ -62,16 +63,32 @@ def empty(self) -> None:

class ResultStore(Store, abc.ABC):
@staticmethod
def get_key(run_id: str, task_name: str) -> str:
return f"{run_id}:{task_name}"
def get_key(run_id: str, task_name: str, suffix: str = "") -> str:
    """Build the store key for a task within a run.

    The key is "<run_id>:<task_name>", with ":<suffix>" appended when a
    non-empty suffix is given (e.g. "status" for status entries).
    """
    parts = [run_id, task_name]
    if suffix:
        parts.append(suffix)
    return ":".join(parts)

async def add_status_for_component(
    self,
    run_id: str,
    task_name: str,
    status: str,
) -> None:
    """Save `status` for the (run, task) pair under its dedicated "status" key.

    Always overwrites: status transitions replace the previous value.
    """
    status_key = self.get_key(run_id, task_name, "status")
    await self.add(status_key, status, overwrite=True)

async def get_status_for_component(self, run_id: str, task_name: str) -> Any:
    """Return the last status saved for this (run, task) pair, or None if unset."""
    status_key = self.get_key(run_id, task_name, "status")
    return await self.get(status_key)

def add_result_for_component(
async def add_result_for_component(
self, run_id: str, task_name: str, result: Any, overwrite: bool = False
) -> None:
self.add(self.get_key(run_id, task_name), result, overwrite=overwrite)
await self.add(self.get_key(run_id, task_name), result, overwrite=overwrite)

def get_result_for_component(self, run_id: str, task_name: str) -> Any:
return self.get(self.get_key(run_id, task_name))
async def get_result_for_component(self, run_id: str, task_name: str) -> Any:
return await self.get(self.get_key(run_id, task_name))


class InMemoryStore(ResultStore):
Expand All @@ -80,14 +97,18 @@ class InMemoryStore(ResultStore):

def __init__(self) -> None:
self._data: dict[str, Any] = {}

def add(self, key: str, value: Any, overwrite: bool = True) -> None:
if (not overwrite) and key in self._data:
raise KeyError(f"{key} already exists")
self._data[key] = value

def get(self, key: str) -> Any:
return self._data.get(key)
self._lock = asyncio.Lock()
"""This lock is used to prevent read while a write in ongoing and vice-versa."""

async def add(self, key: str, value: Any, overwrite: bool = True) -> None:
    """Store `value` under `key`.

    Raises:
        KeyError: when `key` already exists and `overwrite` is False.
    """
    # hold the lock so the existence check and the write are atomic
    # with respect to concurrent readers/writers
    async with self._lock:
        already_present = key in self._data
        if already_present and not overwrite:
            raise KeyError(f"{key} already exists")
        self._data[key] = value

async def get(self, key: str) -> Any:
    """Return the value stored under `key`, or None when absent."""
    async with self._lock:
        value = self._data.get(key)
    return value

def all(self) -> dict[str, Any]:
return self._data
Expand Down
16 changes: 12 additions & 4 deletions tests/e2e/test_kg_builder_pipeline_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,13 @@ async def test_pipeline_builder_happy_path(
assert res.run_id is not None
assert res.result == {"writer": {"status": "SUCCESS"}}
# check component's results
chunks = kg_builder_pipeline.store.get_result_for_component(res.run_id, "splitter")
chunks = await kg_builder_pipeline.store.get_result_for_component(
res.run_id, "splitter"
)
assert len(chunks["chunks"]) == 3
graph = kg_builder_pipeline.store.get_result_for_component(res.run_id, "extractor")
graph = await kg_builder_pipeline.store.get_result_for_component(
res.run_id, "extractor"
)
# 3 entities + 3 chunks + 1 document
nodes = graph["nodes"]
assert len(nodes) == 7
Expand Down Expand Up @@ -463,9 +467,13 @@ async def test_pipeline_builder_failing_chunk_do_not_raise(
assert res.run_id is not None
assert res.result == {"writer": {"status": "SUCCESS"}}
# check component's results
chunks = kg_builder_pipeline.store.get_result_for_component(res.run_id, "splitter")
chunks = await kg_builder_pipeline.store.get_result_for_component(
res.run_id, "splitter"
)
assert len(chunks["chunks"]) == 3
graph = kg_builder_pipeline.store.get_result_for_component(res.run_id, "extractor")
graph = await kg_builder_pipeline.store.get_result_for_component(
res.run_id, "extractor"
)
# 3 entities + 3 chunks
nodes = graph["nodes"]
assert len(nodes) == 6
Expand Down
Loading