15
15
from __future__ import annotations
16
16
17
17
import asyncio
18
+ import enum
18
19
import logging
19
20
import uuid
20
21
import warnings
22
+ import datetime
23
+
24
+ from pydantic import BaseModel , Field
21
25
from typing import TYPE_CHECKING , Any , AsyncGenerator
22
26
23
27
from neo4j_graphrag .experimental .pipeline .exceptions import (
35
39
logger = logging .getLogger (__name__ )
36
40
37
41
42
class ResultType(enum.Enum):
    """Kind of event wrapped in a :class:`Result` during a step-by-step pipeline run."""

    # Intermediate progress for a still-running task — not emitted by the
    # code visible in this chunk; TODO confirm where checkpoints are produced.
    TASK_CHECKPOINT = "TASK_CHECKPOINT"
    # A single task completed and produced its result.
    TASK_FINISHED = "TASK_FINISHED"
    # The whole pipeline run is complete.
    PIPELINE_FINISHED = "PIPELINE_FINISHED"
46
+
47
+
48
class Result(BaseModel):
    """Envelope yielded by the orchestrator's step-by-step run methods."""

    # Which kind of event this result represents (see ResultType).
    result_type: ResultType
    # Event payload; for TASK_FINISHED this is the task's RunResult.result.
    data: Any
    # Timezone-aware UTC timestamp of when this Result was created.
    timestamp: datetime.datetime = Field(
        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
    )
54
+
55
+
38
56
class Orchestrator :
39
57
"""Orchestrate a pipeline.
40
58
@@ -53,17 +71,7 @@ def __init__(self, pipeline: "Pipeline"):
53
71
self .event_notifier = EventNotifier (pipeline .callback )
54
72
self .run_id = str (uuid .uuid4 ())
55
73
56
- async def run_task (self , task : TaskPipelineNode , data : dict [str , Any ]) -> None :
57
- """Get inputs and run a specific task. Once the task is done,
58
- calls the on_task_complete method.
59
-
60
- Args:
61
- task (TaskPipelineNode): The task to be run
62
- data (dict[str, Any]): The pipeline input data
63
-
64
- Returns:
65
- None
66
- """
74
+ async def run_single_task (self , task : TaskPipelineNode , data : dict [str , Any ]) -> Result :
67
75
param_mapping = self .get_input_config_for_task (task )
68
76
inputs = await self .get_component_inputs (task .name , param_mapping , data )
69
77
try :
@@ -72,13 +80,31 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
72
80
logger .debug (
73
81
f"ORCHESTRATOR: TASK ABORTED: { task .name } is already running or done, aborting"
74
82
)
75
- return None
83
+ raise StopAsyncIteration ()
76
84
await self .event_notifier .notify_task_started (self .run_id , task .name , inputs )
77
85
res = await task .run (inputs )
78
86
await self .set_task_status (task .name , RunStatus .DONE )
79
87
await self .event_notifier .notify_task_finished (self .run_id , task .name , res )
80
88
if res :
81
- await self .on_task_complete (data = data , task = task , result = res )
89
+ await self .save_results (task = task , result = res )
90
+ return Result (result_type = ResultType .TASK_FINISHED , data = res .result )
91
+
92
+ async def run_task (self , task : TaskPipelineNode , data : dict [str , Any ]) -> AsyncGenerator [Result , None ]:
93
+ """Get inputs and run a specific task. Once the task is done,
94
+ calls the on_task_complete method.
95
+
96
+ Args:
97
+ task (TaskPipelineNode): The task to be run
98
+ data (dict[str, Any]): The pipeline input data
99
+
100
+ Returns:
101
+ None
102
+ """
103
+ yield await self .run_single_task (task , data )
104
+ # then get the next tasks to be executed and run them
105
+ async for n in self .next (task ):
106
+ async for res in self .run_task (n , data ):
107
+ yield res
82
108
83
109
async def set_task_status (self , task_name : str , status : RunStatus ) -> None :
84
110
"""Set a new status
@@ -102,8 +128,8 @@ async def set_task_status(self, task_name: str, status: RunStatus) -> None:
102
128
self .run_id , task_name , status .value
103
129
)
104
130
105
- async def on_task_complete (
106
- self , data : dict [ str , Any ], task : TaskPipelineNode , result : RunResult
131
+ async def save_results (
132
+ self , task : TaskPipelineNode , result : RunResult
107
133
) -> None :
108
134
"""When a given task is complete, it will call this method
109
135
to find the next tasks to run.
@@ -115,9 +141,6 @@ async def on_task_complete(
115
141
await self .add_result_for_component (
116
142
task .name , res_to_save , is_final = task .is_leaf ()
117
143
)
118
- # then get the next tasks to be executed
119
- # and run them in //
120
- await asyncio .gather (* [self .run_task (n , data ) async for n in self .next (task )])
121
144
122
145
async def check_dependencies_complete (self , task : TaskPipelineNode ) -> None :
123
146
"""Check that all parent tasks are complete.
@@ -257,3 +280,14 @@ async def run(self, data: dict[str, Any]) -> None:
257
280
await self .event_notifier .notify_pipeline_finished (
258
281
self .run_id , await self .pipeline .get_final_results (self .run_id )
259
282
)
283
+
284
+ async def run_step_by_step (self , data : dict [str , Any ]) -> AsyncGenerator [Result , None ]:
285
+ """Run the pipline, starting from the root nodes
286
+ (node without any parent). Then the callback on_task_complete
287
+ will handle the task dependencies.
288
+ """
289
+ for root in self .pipeline .roots ():
290
+ async for res in self .run_task (root , data ):
291
+ yield res
292
+ # tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
293
+ # await asyncio.gather(*tasks)
0 commit comments