
Commit a662765

Expose context from the orchestrator to the components
1 parent 8668511 commit a662765

23 files changed: +422 -188 lines changed

examples/build_graph/simple_kg_builder_from_text.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 from neo4j_graphrag.embeddings import OpenAIEmbeddings
 from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
-from neo4j_graphrag.experimental.pipeline.types import (
+from neo4j_graphrag.experimental.pipeline.types.schema import (
     EntityInputType,
     RelationInputType,
 )
Lines changed: 108 additions & 0 deletions

@@ -0,0 +1,108 @@
+"""This example demonstrates how to use event callback to receive notifications
+about the component progress.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Any
+
+from neo4j_graphrag.experimental.pipeline import Pipeline, Component, DataModel
+from neo4j_graphrag.experimental.pipeline.notification import Event, EventType
+from neo4j_graphrag.experimental.pipeline.types.context import RunContext
+
+logger = logging.getLogger(__name__)
+logging.basicConfig()
+logger.setLevel(logging.INFO)
+
+
+class BatchComponentResult(DataModel):
+    result: list[int]
+
+
+class MultiplicationComponent(Component):
+    def __init__(self, f: int) -> None:
+        self.f = f
+
+    async def run(self, numbers: list[int]) -> BatchComponentResult:
+        return BatchComponentResult(result=[])
+
+    async def multiply_number(
+        self, context_: RunContext, number: int,
+    ) -> int:
+        await context_.notify(
+            message=f"Processing number {number}",
+            data={"number_processed": number},
+        )
+        return self.f * number
+
+    async def run_with_context(
+        self,
+        context_: RunContext,
+        numbers: list[int],
+        **kwargs: Any,
+    ) -> BatchComponentResult:
+        result = await asyncio.gather(
+            *[
+                self.multiply_number(
+                    context_,
+                    number,
+                )
+                for number in numbers
+            ]
+        )
+        return BatchComponentResult(result=result)
+
+
+async def event_handler(event: Event) -> None:
+    """Function can do anything about the event,
+    here we're just logging it if it's a pipeline-level event.
+    """
+    if event.event_type == EventType.TASK_PROGRESS:
+        logger.warning(event)
+    else:
+        logger.info(event)
+
+
+async def main() -> None:
+    """ """
+    pipe = Pipeline(
+        callback=event_handler,
+    )
+    # define the components
+    pipe.add_component(
+        MultiplicationComponent(f=2),
+        "multiply_by_2",
+    )
+    pipe.add_component(
+        MultiplicationComponent(f=10),
+        "multiply_by_10",
+    )
+    # define the execution order of component
+    # and how the output of previous components must be used
+    pipe.connect(
+        "multiply_by_2",
+        "multiply_by_10",
+        input_config={"numbers": "multiply_by_2.result"},
+    )
+    # user input:
+    pipe_inputs_1 = {
+        "multiply_by_2": {
+            "numbers": [1, 2, 5, 4],
+        },
+    }
+    pipe_inputs_2 = {
+        "multiply_by_2": {
+            "numbers": [3, 10, 1],
+        }
+    }
+    # run the pipeline
+    await asyncio.gather(
+        pipe.run(pipe_inputs_1),
+        pipe.run(pipe_inputs_2),
+    )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())

examples/customize/build_graph/pipeline/pipeline_with_notifications.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 )
 from neo4j_graphrag.experimental.pipeline import Pipeline
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
-from neo4j_graphrag.experimental.pipeline.types import Event
+from neo4j_graphrag.experimental.pipeline.notification import Event
 
 logger = logging.getLogger(__name__)
 logging.basicConfig()

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@
 
 from neo4j_graphrag.exceptions import SchemaValidationError
 from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
-from neo4j_graphrag.experimental.pipeline.types import (
+from neo4j_graphrag.experimental.pipeline.types.schema import (
     EntityInputType,
     RelationInputType,
 )

src/neo4j_graphrag/experimental/pipeline/component.py

Lines changed: 38 additions & 26 deletions

@@ -20,6 +20,7 @@
 
 from pydantic import BaseModel
 
+from neo4j_graphrag.experimental.pipeline.types.context import RunContext
 from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
 
 
@@ -35,34 +36,39 @@ def __new__(
     ) -> type:
         # extract required inputs and outputs from the run method signature
         run_method = attrs.get("run")
-        if run_method is not None:
-            sig = inspect.signature(run_method)
-            attrs["component_inputs"] = {
-                param.name: {
-                    "has_default": param.default != inspect.Parameter.empty,
-                    "annotation": param.annotation,
-                }
-                for param in sig.parameters.values()
-                if param.name not in ("self", "kwargs")
+        run_context_method = attrs.get("run_with_context")
+        run = run_context_method or run_method
+        if run is None:
+            raise RuntimeError(
+                f"Either 'run' or 'run_with_context' must be implemented in component: '{name}'"
+            )
+        sig = inspect.signature(run)
+        attrs["component_inputs"] = {
+            param.name: {
+                "has_default": param.default != inspect.Parameter.empty,
+                "annotation": param.annotation,
             }
-            # extract returned fields from the run method return type hint
-            return_model = get_type_hints(run_method).get("return")
-            if return_model is None:
-                raise PipelineDefinitionError(
-                    f"The run method return type must be annotated in {name}"
-                )
-            # the type hint must be a subclass of DataModel
-            if not issubclass(return_model, DataModel):
-                raise PipelineDefinitionError(
-                    f"The run method must return a subclass of DataModel in {name}"
-                )
-            attrs["component_outputs"] = {
-                f: {
-                    "has_default": field.is_required(),
-                    "annotation": field.annotation,
-                }
-                for f, field in return_model.model_fields.items()
+            for param in sig.parameters.values()
+            if param.name not in ("self", "kwargs", "context_")
+        }
+        # extract returned fields from the run method return type hint
+        return_model = get_type_hints(run).get("return")
+        if return_model is None:
+            raise PipelineDefinitionError(
+                f"The run method return type must be annotated in {name}"
+            )
+        # the type hint must be a subclass of DataModel
+        if not issubclass(return_model, DataModel):
+            raise PipelineDefinitionError(
+                f"The run method must return a subclass of DataModel in {name}"
+            )
+        attrs["component_outputs"] = {
+            f: {
+                "has_default": field.is_required(),
+                "annotation": field.annotation,
             }
+            for f, field in return_model.model_fields.items()
+        }
         return type.__new__(meta, name, bases, attrs)
 
 
@@ -80,3 +86,9 @@ class Component(abc.ABC, metaclass=ComponentMeta):
     @abc.abstractmethod
     async def run(self, *args: Any, **kwargs: Any) -> DataModel:
        pass
+
+    async def run_with_context(
+        self, context_: RunContext, *args: Any, **kwargs: Any
+    ) -> DataModel:
+        # default behavior to prevent a breaking change
+        return await self.run(*args, **kwargs)
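
For illustration only (not part of this commit): a minimal sketch of a component that opts into the new run_with_context hook. It assumes the RunContext.notify(message=..., data=...) call used in the example file added above; the class and field names here are made up.

from typing import Any

from neo4j_graphrag.experimental.pipeline import Component, DataModel
from neo4j_graphrag.experimental.pipeline.types.context import RunContext


class GreetingResult(DataModel):
    text: str


class GreetingComponent(Component):
    # plain run() stays as the fallback used when no context is injected
    async def run(self, name: str) -> GreetingResult:
        return GreetingResult(text=f"Hello, {name}!")

    # the orchestrator passes context_ here; notify() emits a TASK_PROGRESS event
    async def run_with_context(
        self, context_: RunContext, name: str, **kwargs: Any
    ) -> GreetingResult:
        await context_.notify(
            message=f"Greeting {name}",
            data={"name": name},
        )
        return await self.run(name)

Because run_with_context defaults to delegating to run(), existing components keep working unchanged.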

src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@
     ParamConfig,
 )
 from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
-from neo4j_graphrag.experimental.pipeline.types import (
+from neo4j_graphrag.experimental.pipeline.types.definitions import (
     ComponentDefinition,
     ConnectionDefinition,
     PipelineDefinition,

src/neo4j_graphrag/experimental/pipeline/config/runner.py

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@
 )
 from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
-from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition
+from neo4j_graphrag.experimental.pipeline.types.definitions import PipelineDefinition
 from neo4j_graphrag.utils.logging import prettify
 
 logger = logging.getLogger(__name__)

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 from neo4j_graphrag.experimental.pipeline.config.pipeline_config import (
     AbstractPipelineConfig,
 )
-from neo4j_graphrag.experimental.pipeline.types import ComponentDefinition
+from neo4j_graphrag.experimental.pipeline.types.definitions import ComponentDefinition
 
 logger = logging.getLogger(__name__)
 
src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 2 additions & 2 deletions

@@ -44,8 +44,8 @@
 )
 from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
 from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
-from neo4j_graphrag.experimental.pipeline.types import (
-    ConnectionDefinition,
+from neo4j_graphrag.experimental.pipeline.types.definitions import ConnectionDefinition
+from neo4j_graphrag.experimental.pipeline.types.schema import (
     EntityInputType,
     RelationInputType,
 )

src/neo4j_graphrag/experimental/pipeline/kg_builder.py

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@
 )
 from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
-from neo4j_graphrag.experimental.pipeline.types import (
+from neo4j_graphrag.experimental.pipeline.types.schema import (
     EntityInputType,
     RelationInputType,
 )

src/neo4j_graphrag/experimental/pipeline/notification.py

Lines changed: 70 additions & 9 deletions

@@ -14,16 +14,61 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, Optional
+import datetime
+import enum
+from collections.abc import Awaitable
+from pydantic import BaseModel, Field
 
-from neo4j_graphrag.experimental.pipeline.types import (
-    Event,
-    EventCallbackProtocol,
-    EventType,
-    PipelineEvent,
-    RunResult,
-    TaskEvent,
-)
+from typing import Any, Optional, Protocol, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult
+
+
+class EventType(enum.Enum):
+    PIPELINE_STARTED = "PIPELINE_STARTED"
+    TASK_STARTED = "TASK_STARTED"
+    TASK_PROGRESS = "TASK_PROGRESS"
+    TASK_FINISHED = "TASK_FINISHED"
+    PIPELINE_FINISHED = "PIPELINE_FINISHED"
+
+    @property
+    def is_pipeline_event(self) -> bool:
+        return self in [EventType.PIPELINE_STARTED, EventType.PIPELINE_FINISHED]
+
+    @property
+    def is_task_event(self) -> bool:
+        return self in [
+            EventType.TASK_STARTED,
+            EventType.TASK_PROGRESS,
+            EventType.TASK_FINISHED,
+        ]
+
+
+class Event(BaseModel):
+    event_type: EventType
+    run_id: str
+    """Pipeline unique run_id, same as the one returned in PipelineResult after pipeline.run"""
+    timestamp: datetime.datetime = Field(
+        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
+    )
+    message: Optional[str] = None
+    """Optional information about the status"""
+    payload: Optional[dict[str, Any]] = None
+    """Input or output data depending on the type of event"""
+
+
+class PipelineEvent(Event):
+    pass
+
+
+class TaskEvent(Event):
+    task_name: str
+    """Name of the task as defined in pipeline.add_component"""
+
+
+class EventCallbackProtocol(Protocol):
+    def __call__(self, event: Event) -> Awaitable[None]: ...
 
 
 class EventNotifier:
@@ -87,3 +132,19 @@ async def notify_task_finished(
             else None,
         )
         await self.notify(event)
+
+    async def notify_task_progress(
+        self,
+        run_id: str,
+        task_name: str,
+        message: str,
+        data: dict[str, Any],
+    ) -> None:
+        event = TaskEvent(
+            event_type=EventType.TASK_PROGRESS,
+            run_id=run_id,
+            task_name=task_name,
+            message=message,
+            payload=data,
+        )
+        await self.notify(event)
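
For illustration only (not part of this commit): a minimal sketch of a callback matching EventCallbackProtocol that singles out the new TASK_PROGRESS event type. The Pipeline(callback=...) wiring mirrors the example files above; on_event is a made-up name.

from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.notification import Event, EventType


async def on_event(event: Event) -> None:
    # TASK_PROGRESS events are the ones emitted by components through RunContext.notify()
    if event.event_type == EventType.TASK_PROGRESS:
        print(f"[{event.run_id}] {event.message} payload={event.payload}")
    else:
        print(f"[{event.run_id}] {event.event_type.value}")


pipeline = Pipeline(callback=on_event)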
