Move status to store #3

Closed · wants to merge 10 commits
20 changes: 17 additions & 3 deletions docs/source/user_guide_pipeline.rst
@@ -60,8 +60,8 @@ Here's how to create a simple pipeline and propagate results from one component
pipe.add_component(ComponentAdd(), "a")
pipe.add_component(ComponentAdd(), "b")

pipe.connect("a", "b", {"number2": "a.result"})
asyncio.run(pipe.run({"a": {"number1": 10, "number2": 1}, "b": {"number1": 4}))
pipe.connect("a", "b", input_config={"number2": "a.result"})
asyncio.run(pipe.run({"a": {"number1": 10, "number2": 1}, "b": {"number1": 4}}))
# result: 10+1+4 = 15

1. First, a pipeline is created, and two components named "a" and "b" are added to it.
@@ -79,6 +79,20 @@ The data flow is illustrated in the diagram below:
Component "b" -> 15
4 -------------------------/

.. warning::
.. warning:: Cyclic graph

Cycles are not allowed in a Pipeline.


.. warning:: Ignored user inputs

If inputs are provided both by the user in the `pipeline.run` method and as
`input_config` in a `connect` method, the user input will be ignored. Take for
instance the following pipeline, adapted from the previous one:

.. code:: python

pipe.connect("a", "b", input_config={"number2": "a.result"})
asyncio.run(pipe.run({"a": {"number1": 10, "number2": 1}, "b": {"number1": 4, "number2": 42}}))

The result will still be **15** because the user input `"number2": 42` is ignored.
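
For completeness, a runnable version of this example is sketched below. The
`AddResult` model and `ComponentAdd` definition are assumptions here, written
to follow the `Component`/`DataModel` pattern used in `examples/pipeline/rag.py`
further down; the real definitions live earlier in the user guide.

.. code:: python

    import asyncio

    from neo4j_genai.experimental.pipeline import Component, Pipeline
    from neo4j_genai.experimental.pipeline.component import DataModel


    class AddResult(DataModel):
        result: int


    class ComponentAdd(Component):
        async def run(self, number1: int, number2: int) -> AddResult:
            return AddResult(result=number1 + number2)


    pipe = Pipeline()
    pipe.add_component(ComponentAdd(), "a")
    pipe.add_component(ComponentAdd(), "b")
    # "b" reuses the output of "a" as its "number2" input
    pipe.connect("a", "b", input_config={"number2": "a.result"})
    asyncio.run(pipe.run({"a": {"number1": 10, "number2": 1}, "b": {"number1": 4}}))
    # component "a" computes 10 + 1, component "b" computes 11 + 4 = 15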
27 changes: 15 additions & 12 deletions examples/pipeline/rag.py
@@ -27,6 +27,7 @@
from neo4j_genai.embeddings.openai import OpenAIEmbeddings
from neo4j_genai.experimental.pipeline import Component, Pipeline
from neo4j_genai.experimental.pipeline.component import DataModel
from neo4j_genai.experimental.pipeline.pipeline import PipelineResult
from neo4j_genai.experimental.pipeline.types import (
ComponentConfig,
ConnectionConfig,
@@ -38,35 +39,37 @@
from neo4j_genai.retrievers.base import Retriever


class StringDataModel(DataModel):
result: str
class ComponentResultDataModel(DataModel):
"""A simple DataModel with a single text field"""

text: str


class RetrieverComponent(Component):
def __init__(self, retriever: Retriever) -> None:
self.retriever = retriever

async def run(self, query: str) -> StringDataModel:
async def run(self, query: str) -> ComponentResultDataModel:
res = self.retriever.search(query_text=query)
return StringDataModel(result="\n".join(c.content for c in res.items))
return ComponentResultDataModel(text="\n".join(c.content for c in res.items))


class PromptTemplateComponent(Component):
def __init__(self, prompt: PromptTemplate) -> None:
self.prompt = prompt

async def run(self, query: str, context: List[str]) -> StringDataModel:
async def run(self, query: str, context: List[str]) -> ComponentResultDataModel:
prompt = self.prompt.format(query, context, examples="")
return StringDataModel(result=prompt)
return ComponentResultDataModel(text=prompt)


class LLMComponent(Component):
def __init__(self, llm: LLMInterface) -> None:
self.llm = llm

async def run(self, prompt: str) -> StringDataModel:
async def run(self, prompt: str) -> ComponentResultDataModel:
llm_response = self.llm.invoke(prompt)
return StringDataModel(result=llm_response.content)
return ComponentResultDataModel(text=llm_response.content)


if __name__ == "__main__":
@@ -97,21 +100,21 @@ async def run(self, prompt: str) -> StringDataModel:
ConnectionConfig(
start="retrieve",
end="augment",
input_config={"context": "retrieve.result"},
input_config={"context": "retrieve.text"},
),
ConnectionConfig(
start="augment",
end="generate",
input_config={"prompt": "augment.result"},
input_config={"prompt": "augment.text"},
),
],
)
)

query = "A movie about the US presidency"
result = asyncio.run(
pipe_output: PipelineResult = asyncio.run(
pipe.run({"retrieve": {"query": query}, "augment": {"query": query}})
)
print(result.result["generate"]["result"])
print(pipe_output.result["generate"]["text"])

driver.close()
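
As the added `PipelineResult` annotation makes explicit, `pipe.run` returns a
`PipelineResult`, and its `result` field holds only the outputs of the leaf
components of the graph (here, "generate"), keyed by component name. A quick,
hypothetical inspection of the object from the example:

    # pipe_output is the PipelineResult returned by pipe.run(...) above
    print(pipe_output.run_id)                      # unique ID for this pipeline run
    print(pipe_output.result)                      # e.g. {"generate": {"text": "..."}}
    print(pipe_output.result["generate"]["text"])  # the generated answer itself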
110 changes: 55 additions & 55 deletions src/neo4j_genai/experimental/pipeline/pipeline.py
@@ -47,7 +47,7 @@


class RunStatus(enum.Enum):
UNKNOWN = "UNKNOWN"
UNKNOWN = None
SCHEDULED = "SCHEDULED"
WAITING = "WAITING"
RUNNING = "RUNNING"
@@ -76,37 +76,6 @@ def __init__(self, name: str, component: Component):
"""
super().__init__(name, {})
self.component = component
self.status: dict[str, RunStatus] = {}
self._lock = asyncio.Lock()
"""This lock is used to make sure we're not trying
to update the status in //. This should prevent the task to
be executed multiple times because the status was not known
by the orchestrator.
"""

async def set_status(self, run_id: str, status: RunStatus) -> None:
"""Set a new status

Args:
run_id (str): Unique ID for the current pipeline run
status (RunStatus): New status

Raises:
PipelineStatusUpdateError if the new status is not
compatible with the current one.
"""
async with self._lock:
current_status = self.status.get(run_id)
if status == current_status:
raise PipelineStatusUpdateError()
if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
# can't go back to RUNNING from DONE
raise PipelineStatusUpdateError()
self.status[run_id] = status

async def read_status(self, run_id: str) -> RunStatus:
async with self._lock:
return self.status.get(run_id, RunStatus.UNKNOWN)

async def execute(self, **kwargs: Any) -> RunResult | None:
"""Execute the task
Expand All @@ -130,8 +99,6 @@ async def run(self, inputs: dict[str, Any]) -> RunResult | None:
"""Main method to execute the task."""
logger.debug(f"TASK START {self.name=} {inputs=}")
res = await self.execute(**inputs)
if res is None:
return None
logger.debug(f"TASK RESULT {self.name=} {res=}")
return res

@@ -165,19 +132,45 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
None
"""
input_config = await self.get_input_config_for_task(task)
inputs = self.get_component_inputs(task.name, input_config, data)
inputs = await self.get_component_inputs(task.name, input_config, data)
try:
await task.set_status(self.run_id, RunStatus.RUNNING)
await self.set_task_status(task.name, RunStatus.RUNNING)
except PipelineStatusUpdateError:
logger.info(
f"Component {task.name} already running or done {task.status.get(self.run_id)}"
)
logger.info(f"Component {task.name} already running or done")
return None
res = await task.run(inputs)
await task.set_status(self.run_id, RunStatus.DONE)
await self.set_task_status(task.name, RunStatus.DONE)
if res:
await self.on_task_complete(data=data, task=task, result=res)

async def set_task_status(self, task_name: str, status: RunStatus) -> None:
"""Set a new status

Args:
task_name (str): Name of the component
status (RunStatus): New status

Raises:
PipelineStatusUpdateError if the new status is not
compatible with the current one.
"""
# Make the method async-safe against this kind of call sequence:
# 1: get status => UNKNOWN
# 2: get status => UNKNOWN
# 1: set status => RUNNING
# 2: set status => RUNNING
# that would cause two tasks to be started instead of one
async with asyncio.Lock():
current_status = RunStatus(await self.get_status_for_component(task_name))
if status == current_status:
raise PipelineStatusUpdateError()
if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
# can't go back to RUNNING from DONE
raise PipelineStatusUpdateError()
return await self.pipeline.store.add_status_for_component(
self.run_id, task_name, status.value
)

async def on_task_complete(
self, data: dict[str, Any], task: TaskPipelineNode, result: RunResult
) -> None:
@@ -189,7 +182,9 @@ async def on_task_complete(
res_to_save = None
if result.result:
res_to_save = result.result.model_dump()
self.add_result_for_component(task.name, res_to_save, is_final=task.is_leaf())
await self.add_result_for_component(
task.name, res_to_save, is_final=task.is_leaf()
)
# then get the next tasks to be executed
# and run them in parallel
await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)])
@@ -202,8 +197,7 @@ async def check_dependencies_complete(self, task: TaskPipelineNode) -> None:
"""
dependencies = self.pipeline.previous_edges(task.name)
for d in dependencies:
start_node = self.pipeline.get_node_by_name(d.start)
d_status = await start_node.read_status(self.run_id)
d_status = RunStatus(await self.get_status_for_component(d.start))
if d_status != RunStatus.DONE:
logger.warning(
f"Missing dependency {d.start} for {task.name} (status: {d_status})"
@@ -225,7 +219,9 @@ async def next(
for next_edge in possible_next:
next_node = self.pipeline.get_node_by_name(next_edge.end)
# check status
next_node_status = await next_node.read_status(self.run_id)
next_node_status = RunStatus(
await self.get_status_for_component(next_node.name)
)
if next_node_status in [RunStatus.RUNNING, RunStatus.DONE]:
# already running
continue
@@ -253,8 +249,9 @@ async def get_input_config_for_task(self, task: TaskPipelineNode) -> dict[str, s
# make sure dependencies are satisfied
# and save the input defs that need to be propagated from parent components
for prev_edge in self.pipeline.previous_edges(task.name):
prev_node = self.pipeline.get_node_by_name(prev_edge.start)
prev_status = await prev_node.read_status(self.run_id)
prev_status = RunStatus(
await self.get_status_for_component(prev_edge.start)
)
if prev_status != RunStatus.DONE:
logger.critical(f"Missing dependency {prev_edge.start}")
raise PipelineMissingDependencyError(f"{prev_edge.start} not ready")
Expand All @@ -263,7 +260,7 @@ async def get_input_config_for_task(self, task: TaskPipelineNode) -> dict[str, s
input_config.update(**prev_edge_data)
return input_config

def get_component_inputs(
async def get_component_inputs(
self,
component_name: str,
input_config: dict[str, Any],
@@ -289,7 +286,7 @@ def get_component_inputs(
# component as input
component = mapping
output_param = None
component_result = self.get_results_for_component(component)
component_result = await self.get_results_for_component(component)
if output_param is not None:
value = component_result.get(output_param)
else:
@@ -301,25 +298,28 @@ def get_component_inputs(
component_inputs[parameter] = value
return component_inputs

def add_result_for_component(
async def add_result_for_component(
self, name: str, result: dict[str, Any] | None, is_final: bool = False
) -> None:
"""This is where we save the results in the result store and, optionally,
in the final result store.
"""
self.pipeline.store.add_result_for_component(self.run_id, name, result)
await self.pipeline.store.add_result_for_component(self.run_id, name, result)
if is_final:
# The pipeline only returns the results
# of the leaf nodes
# TODO: make this configurable in the future.
existing_results = self.pipeline.final_results.get(self.run_id) or {}
existing_results = await self.pipeline.final_results.get(self.run_id) or {}
existing_results[name] = result
self.pipeline.final_results.add(
await self.pipeline.final_results.add(
self.run_id, existing_results, overwrite=True
)

def get_results_for_component(self, name: str) -> Any:
return self.pipeline.store.get_result_for_component(self.run_id, name)
async def get_results_for_component(self, name: str) -> Any:
return await self.pipeline.store.get_result_for_component(self.run_id, name)

async def get_status_for_component(self, name: str) -> Any:
return await self.pipeline.store.get_status_for_component(self.run_id, name)

async def run(self, data: dict[str, Any]) -> None:
"""Run the pipline, starting from the root nodes
@@ -502,5 +502,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
)
return PipelineResult(
run_id=orchestrator.run_id,
result=self.final_results.get(orchestrator.run_id),
result=await self.final_results.get(orchestrator.run_id),
)
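
As the PR title says, per-component statuses now live in the pipeline's store
instead of on the task nodes: `set_task_status` persists `status.value` and
readers rehydrate with `RunStatus(...)`, which is also why `UNKNOWN` changes
from `"UNKNOWN"` to `None` (a missing store entry maps back to
`RunStatus.UNKNOWN`). A minimal in-memory store satisfying the four async
calls made by the orchestrator might look like the sketch below; the class
name and internals are hypothetical, only the method signatures come from the
diff.

    import asyncio
    from typing import Any


    class InMemoryPipelineStore:
        """Hypothetical store: only the method names, signatures and
        async-ness are taken from the orchestrator calls above."""

        def __init__(self) -> None:
            self._data: dict[str, Any] = {}
            # A single lock shared by all calls: unlike a lock created
            # inside a method, it can actually serialize concurrent access.
            self._lock = asyncio.Lock()

        @staticmethod
        def _key(run_id: str, component_name: str, suffix: str = "") -> str:
            return f"{run_id}:{component_name}{suffix}"

        async def add_status_for_component(
            self, run_id: str, component_name: str, status: str | None
        ) -> None:
            async with self._lock:
                self._data[self._key(run_id, component_name, ":status")] = status

        async def get_status_for_component(
            self, run_id: str, component_name: str
        ) -> str | None:
            async with self._lock:
                # None rehydrates to RunStatus.UNKNOWN on the caller side
                return self._data.get(self._key(run_id, component_name, ":status"))

        async def add_result_for_component(
            self, run_id: str, component_name: str, result: dict[str, Any] | None
        ) -> None:
            async with self._lock:
                self._data[self._key(run_id, component_name)] = result

        async def get_result_for_component(
            self, run_id: str, component_name: str
        ) -> dict[str, Any] | None:
            async with self._lock:
                return self._data.get(self._key(run_id, component_name))

One design point worth flagging: in `set_task_status`, `async with
asyncio.Lock():` creates a fresh lock on every call, so by itself it cannot
serialize the get-status/set-status sequence across concurrent tasks; the
mutual exclusion has to come from a lock that outlives the call, for example
one owned by the orchestrator or, as sketched here, by the store.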