Skip to content

Expose context from the orchestrator to the components #301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 20, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Added optional schema enforcement as a validation layer after entity and relation extraction.
- Introduced a linear hybrid search ranker for HybridRetriever and HybridCypherRetriever, allowing customizable ranking with an `alpha` parameter.
- Introduced SearchQueryParseError for handling invalid Lucene query strings in HybridRetriever and HybridCypherRetriever.
- Components can now be called with the `run_with_context` method that gets an extra `context_` argument containing information about the pipeline it's run from: the `run_id`, `task_name` and a `notify` function that can be used to send `TASK_PROGRESS` events to the same callback as the pipeline events.

### Fixed

Expand Down
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ are listed in [the last section of this file](#customize).
- [Export lexical graph creation into another pipeline](./customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py)
- [Build pipeline from config file](customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py)
- [Add event listener to get notification about Pipeline progress](./customize/build_graph/pipeline/pipeline_with_notifications.py)
- [Use component context to send notifications about Component progress](./customize/build_graph/pipeline/pipeline_with_component_notifications.py)


#### Components
Expand Down
2 changes: 1 addition & 1 deletion examples/build_graph/simple_kg_builder_from_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from neo4j_graphrag.embeddings import OpenAIEmbeddings
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.experimental.pipeline.types import (
from neo4j_graphrag.experimental.pipeline.types.schema import (
EntityInputType,
RelationInputType,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""This example demonstrates how to use event callback to receive notifications
about the component progress.
"""

from __future__ import annotations

import asyncio
import logging
from typing import Any

from neo4j_graphrag.experimental.pipeline import Pipeline, Component, DataModel
from neo4j_graphrag.experimental.pipeline.notification import Event, EventType
from neo4j_graphrag.experimental.pipeline.types.context import RunContext

logger = logging.getLogger(__name__)
logging.basicConfig()
logger.setLevel(logging.INFO)


class MultiplyComponentResult(DataModel):
    """Output model of MultiplicationComponent: the list of multiplied numbers."""

    result: list[int]


class MultiplicationComponent(Component):
    """Multiply each input number by a fixed factor.

    Demonstrates `run_with_context`: every processed number is reported
    through the pipeline context as a TASK_PROGRESS notification.
    """

    def __init__(self, f: int) -> None:
        # The constant factor applied to every input number.
        self.f = f

    async def run(self, numbers: list[int]) -> MultiplyComponentResult:
        # Context-less entry point: returns an empty result (the real work
        # happens in run_with_context, which the pipeline calls).
        return MultiplyComponentResult(result=[])

    async def multiply_number(
        self,
        context_: RunContext,
        number: int,
    ) -> int:
        """Notify progress for a single number, then return it multiplied."""
        await context_.notify(
            message=f"Processing number {number}",
            data={"number_processed": number},
        )
        return self.f * number

    async def run_with_context(
        self,
        context_: RunContext,
        numbers: list[int],
        **kwargs: Any,
    ) -> MultiplyComponentResult:
        """Multiply all numbers concurrently, emitting one progress event each."""
        pending = [self.multiply_number(context_, value) for value in numbers]
        products = await asyncio.gather(*pending)
        return MultiplyComponentResult(result=products)


async def event_handler(event: Event) -> None:
    """Log every received event.

    TASK_PROGRESS events are logged at warning level so they stand out;
    all other events are logged at info level.
    """
    is_progress = event.event_type == EventType.TASK_PROGRESS
    level = logging.WARNING if is_progress else logging.INFO
    logger.log(level, event)


async def main() -> None:
    """Build a two-component pipeline and run it twice concurrently."""
    pipe = Pipeline(
        callback=event_handler,
    )
    # Register the components under the names used in connections and inputs.
    for factor, component_name in ((2, "multiply_by_2"), (10, "multiply_by_10")):
        pipe.add_component(
            MultiplicationComponent(f=factor),
            component_name,
        )
    # Wire the graph: the output of multiply_by_2 feeds multiply_by_10.
    pipe.connect(
        "multiply_by_2",
        "multiply_by_10",
        input_config={"numbers": "multiply_by_2.result"},
    )
    # Two independent user inputs for two concurrent runs.
    first_inputs = {
        "multiply_by_2": {
            "numbers": [1, 2, 5, 4],
        },
    }
    second_inputs = {
        "multiply_by_2": {
            "numbers": [3, 10, 1],
        }
    }
    # Run both pipelines concurrently.
    await asyncio.gather(
        pipe.run(first_inputs),
        pipe.run(second_inputs),
    )


if __name__ == "__main__":
    # Script entry point: run the demo pipeline.
    asyncio.run(main())
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.experimental.pipeline.types import Event
from neo4j_graphrag.experimental.pipeline.notification import Event

logger = logging.getLogger(__name__)
logging.basicConfig()
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/experimental/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from neo4j_graphrag.exceptions import SchemaValidationError
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
from neo4j_graphrag.experimental.pipeline.types import (
from neo4j_graphrag.experimental.pipeline.types.schema import (
EntityInputType,
RelationInputType,
)
Expand Down
66 changes: 39 additions & 27 deletions src/neo4j_graphrag/experimental/pipeline/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from pydantic import BaseModel

from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.utils.validation import issubclass_safe

Expand All @@ -36,34 +37,39 @@ def __new__(
) -> type:
# extract required inputs and outputs from the run method signature
run_method = attrs.get("run")
if run_method is not None:
sig = inspect.signature(run_method)
attrs["component_inputs"] = {
param.name: {
"has_default": param.default != inspect.Parameter.empty,
"annotation": param.annotation,
}
for param in sig.parameters.values()
if param.name not in ("self", "kwargs")
run_context_method = attrs.get("run_with_context")
run = run_context_method or run_method
if run is None:
raise RuntimeError(
f"Either 'run' or 'run_with_context' must be implemented in component: '{name}'"
)
sig = inspect.signature(run)
attrs["component_inputs"] = {
param.name: {
"has_default": param.default != inspect.Parameter.empty,
"annotation": param.annotation,
}
# extract returned fields from the run method return type hint
return_model = get_type_hints(run_method).get("return")
if return_model is None:
raise PipelineDefinitionError(
f"The run method return type must be annotated in {name}"
)
# the type hint must be a subclass of DataModel
if not issubclass_safe(return_model, DataModel):
raise PipelineDefinitionError(
f"The run method must return a subclass of DataModel in {name}"
)
attrs["component_outputs"] = {
f: {
"has_default": field.is_required(),
"annotation": field.annotation,
}
for f, field in return_model.model_fields.items()
for param in sig.parameters.values()
if param.name not in ("self", "kwargs", "context_")
}
# extract returned fields from the run method return type hint
return_model = get_type_hints(run).get("return")
if return_model is None:
raise PipelineDefinitionError(
f"The run method return type must be annotated in {name}"
)
# the type hint must be a subclass of DataModel
if not issubclass_safe(return_model, DataModel):
raise PipelineDefinitionError(
f"The run method must return a subclass of DataModel in {name}"
)
attrs["component_outputs"] = {
f: {
"has_default": field.is_required(),
"annotation": field.annotation,
}
for f, field in return_model.model_fields.items()
}
return type.__new__(meta, name, bases, attrs)


Expand All @@ -76,8 +82,14 @@ class Component(abc.ABC, metaclass=ComponentMeta):
# added here for the type checker
# DO NOT CHANGE
component_inputs: dict[str, dict[str, str | bool]]
component_outputs: dict[str, dict[str, str | bool]]
component_outputs: dict[str, dict[str, str | bool | type]]

    @abc.abstractmethod
    async def run(self, *args: Any, **kwargs: Any) -> DataModel:
        """Run the component's logic and return its result model.

        Subclasses must implement this (or `run_with_context` — the
        metaclass accepts either).
        """
        pass

    async def run_with_context(
        self, context_: RunContext, *args: Any, **kwargs: Any
    ) -> DataModel:
        """Run the component with access to the pipeline run context.

        Components that do not override this method fall back to `run`,
        ignoring `context_`, so pre-existing components keep working.
        """
        # default behavior to prevent a breaking change
        return await self.run(*args, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ParamConfig,
)
from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
from neo4j_graphrag.experimental.pipeline.types import (
from neo4j_graphrag.experimental.pipeline.types.definitions import (
ComponentDefinition,
ConnectionDefinition,
PipelineDefinition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)
from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition
from neo4j_graphrag.experimental.pipeline.types.definitions import PipelineDefinition
from neo4j_graphrag.utils.logging import prettify

logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from neo4j_graphrag.experimental.pipeline.config.pipeline_config import (
AbstractPipelineConfig,
)
from neo4j_graphrag.experimental.pipeline.types import ComponentDefinition
from neo4j_graphrag.experimental.pipeline.types.definitions import ComponentDefinition

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
)
from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.experimental.pipeline.types import (
ConnectionDefinition,
from neo4j_graphrag.experimental.pipeline.types.definitions import ConnectionDefinition
from neo4j_graphrag.experimental.pipeline.types.schema import (
EntityInputType,
RelationInputType,
)
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/experimental/pipeline/kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.experimental.pipeline.types import (
from neo4j_graphrag.experimental.pipeline.types.schema import (
EntityInputType,
RelationInputType,
)
Expand Down
79 changes: 70 additions & 9 deletions src/neo4j_graphrag/experimental/pipeline/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,61 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, Optional
import datetime
import enum
from collections.abc import Awaitable
from pydantic import BaseModel, Field

from neo4j_graphrag.experimental.pipeline.types import (
Event,
EventCallbackProtocol,
EventType,
PipelineEvent,
RunResult,
TaskEvent,
)
from typing import Any, Optional, Protocol, TYPE_CHECKING

if TYPE_CHECKING:
from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult


class EventType(enum.Enum):
    """Lifecycle stages an event can describe, for the pipeline or its tasks."""

    PIPELINE_STARTED = "PIPELINE_STARTED"
    TASK_STARTED = "TASK_STARTED"
    TASK_PROGRESS = "TASK_PROGRESS"
    TASK_FINISHED = "TASK_FINISHED"
    PIPELINE_FINISHED = "PIPELINE_FINISHED"

    @property
    def is_pipeline_event(self) -> bool:
        """True for events scoped to the whole pipeline run."""
        return self in (EventType.PIPELINE_STARTED, EventType.PIPELINE_FINISHED)

    @property
    def is_task_event(self) -> bool:
        """True for events scoped to a single task (everything non-pipeline)."""
        return not self.is_pipeline_event


class Event(BaseModel):
    """Base payload delivered to the pipeline's event callback."""

    event_type: EventType
    run_id: str
    """Pipeline unique run_id, same as the one returned in PipelineResult after pipeline.run"""
    # Event creation time; timezone-aware UTC by default.
    timestamp: datetime.datetime = Field(
        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
    )
    message: Optional[str] = None
    """Optional information about the status"""
    payload: Optional[dict[str, Any]] = None
    """Input or output data depending on the type of event"""


class PipelineEvent(Event):
    """Pipeline-level event; carries no fields beyond the base Event."""

    pass


class TaskEvent(Event):
    """Event tied to a single task; adds the task's name to the base Event."""

    task_name: str
    """Name of the task as defined in pipeline.add_component"""


class EventCallbackProtocol(Protocol):
    """Structural type for the awaitable callback that receives events."""

    def __call__(self, event: Event) -> Awaitable[None]: ...


class EventNotifier:
Expand Down Expand Up @@ -87,3 +132,19 @@ async def notify_task_finished(
else None,
)
await self.notify(event)

async def notify_task_progress(
self,
run_id: str,
task_name: str,
message: str,
data: dict[str, Any],
) -> None:
event = TaskEvent(
event_type=EventType.TASK_PROGRESS,
run_id=run_id,
task_name=task_name,
message=message,
payload=data,
)
await self.notify(event)
Loading