Skip to content

Expose context from the orchestrator to the components #301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 20, 2025
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

### Added

- Added a `run_with_context` method to `Component`. This method has a `context_` parameter that contains information from the pipeline the component is being run from (e.g. the `run_id`)
- Added the `run_with_context` method to `Component`. This method includes a `context_` parameter, which provides information about the pipeline from which the component is executed (e.g., the `run_id`). It also enables the component to send events to the pipeline's callback function.


## 1.6.0
Expand Down
8 changes: 6 additions & 2 deletions docs/source/user_guide_pipeline.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ See :ref:`pipelineevent` and :ref:`taskevent` to see what is sent in each event
Send Events from Components
===========================

Components can send notifications about their progress using the `notify` function from
the `context_`:
Components can send progress notifications using the `notify` function from
`context_` by implementing the `run_from_context` method:

.. code:: python
Expand All @@ -200,3 +200,7 @@ the `context_`:
return IntResultModel(result = number1 + number2)
This will send an `TASK_PROGRESS` event to the pipeline callback.

.. note::

In a future release, the `context_` parameter will be added to the `run` method.
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ class MultiplicationComponent(Component):
def __init__(self, f: int) -> None:
self.f = f

async def run(self, numbers: list[int]) -> MultiplyComponentResult:
return MultiplyComponentResult(result=[])

async def multiply_number(
self,
context_: RunContext,
Expand All @@ -39,6 +36,8 @@ async def multiply_number(
)
return self.f * number

# implementing `run_with_context` to get access to
# the pipeline's RunContext:
async def run_with_context(
self,
context_: RunContext,
Expand Down
9 changes: 5 additions & 4 deletions src/neo4j_graphrag/experimental/pipeline/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,12 @@ class Component(metaclass=ComponentMeta):
component_outputs: dict[str, dict[str, str | bool | type]]

async def run(self, *args: Any, **kwargs: Any) -> DataModel:
"""This function is planned for deprecation in a future release.
"""Run the component and return its result.
Note: if `run_with_context` is implemented, this method will not be used.
"""
raise NotImplementedError(
"You must implement the `run` or `run_with_context` method. "
"`run` method will be marked for deprecation in a future release."
)

async def run_with_context(
Expand All @@ -102,8 +101,10 @@ async def run_with_context(
that can be used to send events from the component to
the pipeline callback.
For now, it defaults to calling the `run` method, but it
is meant to replace the `run` method in a future release.
This feature will be moved to the `run` method in a future
release.
It defaults to calling the `run` method to prevent any breaking change.
"""
# default behavior to prevent a breaking change
return await self.run(*args, **kwargs)
4 changes: 3 additions & 1 deletion src/neo4j_graphrag/experimental/pipeline/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
run_id=self.run_id,
task_name=task.name,
)
context = RunContext(run_id=self.run_id, task_name=task.name, notifier=notifier)
context = RunContext(
run_id=self.run_id, task_name=task.name, _notifier=notifier
)
res = await task.run(context, inputs)
await self.set_task_status(task.name, RunStatus.DONE)
await self.event_notifier.notify_task_finished(self.run_id, task.name, res)
Expand Down
6 changes: 3 additions & 3 deletions src/neo4j_graphrag/experimental/pipeline/types/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ class RunContext(BaseModel):

run_id: str
task_name: str
notifier: Optional[TaskProgressCallbackProtocol] = None
_notifier: Optional[TaskProgressCallbackProtocol] = None

model_config = ConfigDict(arbitrary_types_allowed=True)

async def notify(self, message: str, data: dict[str, Any]) -> None:
if self.notifier:
await self.notifier(message=message, data=data)
if self._notifier:
await self._notifier(message=message, data=data)
2 changes: 1 addition & 1 deletion tests/unit/experimental/pipeline/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async def test_component_run_with_context() -> None:
c = ComponentMultiplyWithContext()
notifier_mock = AsyncMock()
result = await c.run_with_context(
RunContext(run_id="run_id", task_name="task_name", notifier=notifier_mock),
RunContext(run_id="run_id", task_name="task_name", _notifier=notifier_mock),
number1=1,
number2=2,
)
Expand Down
Loading