48
48
49
49
class RunStatus(enum.Enum):
    """Lifecycle states a component can be in during a pipeline run."""

    # Default returned when the status store has no entry for a component.
    UNKNOWN = "UNKNOWN"
    RUNNING = "RUNNING"
    DONE = "DONE"
@@ -76,37 +73,6 @@ def __init__(self, name: str, component: Component):
76
73
"""
77
74
super ().__init__ (name , {})
78
75
self .component = component
79
- self .status : dict [str , RunStatus ] = {}
80
- self ._lock = asyncio .Lock ()
81
- """This lock is used to make sure we're not trying
82
- to update the status in //. This should prevent the task to
83
- be executed multiple times because the status was not known
84
- by the orchestrator.
85
- """
86
-
87
- async def set_status (self , run_id : str , status : RunStatus ) -> None :
88
- """Set a new status
89
-
90
- Args:
91
- run_id (str): Unique ID for the current pipeline run
92
- status (RunStatus): New status
93
-
94
- Raises:
95
- PipelineStatusUpdateError if the new status is not
96
- compatible with the current one.
97
- """
98
- async with self ._lock :
99
- current_status = self .status .get (run_id )
100
- if status == current_status :
101
- raise PipelineStatusUpdateError ()
102
- if status == RunStatus .RUNNING and current_status == RunStatus .DONE :
103
- # can't go back to RUNNING from DONE
104
- raise PipelineStatusUpdateError ()
105
- self .status [run_id ] = status
106
-
107
- async def read_status (self , run_id : str ) -> RunStatus :
108
- async with self ._lock :
109
- return self .status .get (run_id , RunStatus .UNKNOWN )
110
76
111
77
async def execute (self , ** kwargs : Any ) -> RunResult | None :
112
78
"""Execute the task
@@ -163,31 +129,52 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
163
129
None
164
130
"""
165
131
input_config = await self .get_input_config_for_task (task )
166
- inputs = self .get_component_inputs (task .name , input_config , data )
132
+ inputs = await self .get_component_inputs (task .name , input_config , data )
167
133
try :
168
- await task . set_status ( self . run_id , RunStatus .RUNNING )
134
+ await self . set_task_status ( task . name , RunStatus .RUNNING )
169
135
except PipelineStatusUpdateError :
170
- logger .info (
171
- f"Component { task .name } already running or done { task .status .get (self .run_id )} "
172
- )
136
+ logger .info (f"Component { task .name } already running or done" )
173
137
return None
174
138
res = await task .run (inputs )
175
- await task . set_status ( self . run_id , RunStatus .DONE )
139
+ await self . set_task_status ( task . name , RunStatus .DONE )
176
140
if res :
177
141
await self .on_task_complete (data = data , task = task , result = res )
178
142
143
async def set_task_status(self, task_name: str, status: RunStatus) -> None:
    """Set a new status for a component in the current run.

    Args:
        task_name (str): Name of the component
        status (RunStatus): New status

    Raises:
        PipelineStatusUpdateError: if the new status is not
            compatible with the current one (same status twice,
            or DONE -> RUNNING).
    """
    # BUG FIX: the previous code did `async with asyncio.Lock():`, which
    # creates a *new* lock on every call — it can never be contended, so the
    # read-check-write below was not protected at all and a task could be
    # started twice. Reuse a single per-instance lock instead. The lazy
    # creation is safe on the event loop because there is no await between
    # the getattr check and the assignment.
    lock = getattr(self, "_status_lock", None)
    if lock is None:
        lock = self._status_lock = asyncio.Lock()
    async with lock:
        current_status = await self.get_status_for_component(task_name)
        if status == current_status:
            raise PipelineStatusUpdateError(f"Status is already '{status}'")
        if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
            # a finished task must never be restarted within the same run
            raise PipelineStatusUpdateError("Can't go from DONE to RUNNING")
        return await self.pipeline.store.add_status_for_component(
            self.run_id, task_name, status.value
        )
179
165
async def on_task_complete(
    self, data: dict[str, Any], task: TaskPipelineNode, result: RunResult
) -> None:
    """When a given task is complete, it will call this method
    to find the next tasks to run.
    """
    # first save this component results
    payload = result.result.model_dump() if result.result else None
    await self.add_result_for_component(
        task.name, payload, is_final=task.is_leaf()
    )
    # then get the next tasks to be executed
    # and run them in //
    followers = [self.run_task(n, data) async for n in self.next(task)]
    await asyncio.gather(*followers)
@@ -200,8 +187,7 @@ async def check_dependencies_complete(self, task: TaskPipelineNode) -> None:
200
187
"""
201
188
dependencies = self .pipeline .previous_edges (task .name )
202
189
for d in dependencies :
203
- start_node = self .pipeline .get_node_by_name (d .start )
204
- d_status = await start_node .read_status (self .run_id )
190
+ d_status = await self .get_status_for_component (d .start )
205
191
if d_status != RunStatus .DONE :
206
192
logger .warning (
207
193
f"Missing dependency { d .start } for { task .name } (status: { d_status } )"
@@ -223,7 +209,7 @@ async def next(
223
209
for next_edge in possible_next :
224
210
next_node = self .pipeline .get_node_by_name (next_edge .end )
225
211
# check status
226
- next_node_status = await next_node . read_status ( self . run_id )
212
+ next_node_status = await self . get_status_for_component ( next_node . name )
227
213
if next_node_status in [RunStatus .RUNNING , RunStatus .DONE ]:
228
214
# already running
229
215
continue
@@ -251,8 +237,7 @@ async def get_input_config_for_task(self, task: TaskPipelineNode) -> dict[str, s
251
237
# make sure dependencies are satisfied
252
238
# and save the inputs defs that needs to be propagated from parent components
253
239
for prev_edge in self .pipeline .previous_edges (task .name ):
254
- prev_node = self .pipeline .get_node_by_name (prev_edge .start )
255
- prev_status = await prev_node .read_status (self .run_id )
240
+ prev_status = await self .get_status_for_component (prev_edge .start )
256
241
if prev_status != RunStatus .DONE :
257
242
logger .critical (f"Missing dependency { prev_edge .start } " )
258
243
raise PipelineMissingDependencyError (f"{ prev_edge .start } not ready" )
@@ -261,7 +246,7 @@ async def get_input_config_for_task(self, task: TaskPipelineNode) -> dict[str, s
261
246
input_config .update (** prev_edge_data )
262
247
return input_config
263
248
264
- def get_component_inputs (
249
+ async def get_component_inputs (
265
250
self ,
266
251
component_name : str ,
267
252
input_config : dict [str , Any ],
@@ -287,7 +272,7 @@ def get_component_inputs(
287
272
# component as input
288
273
component = mapping
289
274
output_param = None
290
- component_result = self .get_results_for_component (component )
275
+ component_result = await self .get_results_for_component (component )
291
276
if output_param is not None :
292
277
value = component_result .get (output_param )
293
278
else :
@@ -299,25 +284,31 @@ def get_component_inputs(
299
284
component_inputs [parameter ] = value
300
285
return component_inputs
301
286
302
- def add_result_for_component (
287
+ async def add_result_for_component (
303
288
self , name : str , result : dict [str , Any ] | None , is_final : bool = False
304
289
) -> None :
305
290
"""This is where we save the results in the result store and, optionally,
306
291
in the final result store.
307
292
"""
308
- self .pipeline .store .add_result_for_component (self .run_id , name , result )
293
+ await self .pipeline .store .add_result_for_component (self .run_id , name , result )
309
294
if is_final :
310
295
# The pipeline only returns the results
311
296
# of the leaf nodes
312
297
# TODO: make this configurable in the future.
313
- existing_results = self .pipeline .final_results .get (self .run_id ) or {}
298
+ existing_results = await self .pipeline .final_results .get (self .run_id ) or {}
314
299
existing_results [name ] = result
315
- self .pipeline .final_results .add (
300
+ await self .pipeline .final_results .add (
316
301
self .run_id , existing_results , overwrite = True
317
302
)
318
303
319
async def get_results_for_component(self, name: str) -> Any:
    """Fetch the stored result for component *name* within the current run."""
    pending = self.pipeline.store.get_result_for_component(self.run_id, name)
    return await pending
306
+
307
async def get_status_for_component(self, name: str) -> RunStatus:
    """Read a component's status from the store; UNKNOWN when no entry exists."""
    raw = await self.pipeline.store.get_status_for_component(self.run_id, name)
    return RunStatus.UNKNOWN if raw is None else RunStatus(raw)
321
312
322
313
async def run (self , data : dict [str , Any ]) -> None :
323
314
"""Run the pipline, starting from the root nodes
@@ -500,5 +491,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
500
491
)
501
492
return PipelineResult (
502
493
run_id = orchestrator .run_id ,
503
- result = self .final_results .get (orchestrator .run_id ),
494
+ result = await self .final_results .get (orchestrator .run_id ),
504
495
)
0 commit comments