@@ -38,17 +38,17 @@ class DefaultSessionRunner(SessionRunnerBase):
38
38
39
39
def __init__ (
40
40
self ,
41
- on_before_run_session : Optional [OnBeforeRunSession ] = None ,
42
- on_before_run_node : Optional [OnBeforeRunNode ] = None ,
43
- on_after_run_node : Optional [OnAfterRunNode ] = None ,
44
- on_node_error : Optional [OnNodeError ] = None ,
45
- on_after_run_session : Optional [OnAfterRunSession ] = None ,
41
+ on_before_run_session_callbacks : Optional [list [ OnBeforeRunSession ] ] = None ,
42
+ on_before_run_node_callbacks : Optional [list [ OnBeforeRunNode ] ] = None ,
43
+ on_after_run_node_callbacks : Optional [list [ OnAfterRunNode ] ] = None ,
44
+ on_node_error_callbacks : Optional [list [ OnNodeError ] ] = None ,
45
+ on_after_run_session_callbacks : Optional [list [ OnAfterRunSession ] ] = None ,
46
46
):
47
- self .on_before_run_session = on_before_run_session
48
- self .on_before_run_node = on_before_run_node
49
- self .on_after_run_node = on_after_run_node
50
- self .on_node_error = on_node_error
51
- self .on_after_run_session = on_after_run_session
47
+ self ._on_before_run_session_callbacks = on_before_run_session_callbacks or []
48
+ self ._on_before_run_node_callbacks = on_before_run_node_callbacks or []
49
+ self ._on_after_run_node_callbacks = on_after_run_node_callbacks or []
50
+ self ._on_node_error_callbacks = on_node_error_callbacks or []
51
+ self ._on_after_run_session_callbacks = on_after_run_session_callbacks or []
52
52
53
53
def start (self , services : InvocationServices , cancel_event : ThreadEvent , profiler : Optional [Profiler ] = None ):
54
54
"""Start the session runner"""
@@ -77,8 +77,8 @@ def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
77
77
if self ._profiler is not None :
78
78
self ._profiler .start (profile_id = queue_item .session_id )
79
79
80
- if self .on_before_run_session :
81
- self . on_before_run_session (queue_item = queue_item )
80
+ for callback in self ._on_before_run_session_callbacks :
81
+ callback (queue_item = queue_item )
82
82
83
83
def _on_after_run_session (self , queue_item : SessionQueueItem ) -> None :
84
84
# If we are profiling, stop the profiler and dump the profile & stats
@@ -103,8 +103,8 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
103
103
self ._services .performance_statistics .log_stats (queue_item .session .id )
104
104
self ._services .performance_statistics .reset_stats ()
105
105
106
- if self .on_after_run_session :
107
- self . on_after_run_session (queue_item )
106
+ for callback in self ._on_after_run_session_callbacks :
107
+ callback (queue_item )
108
108
109
109
def _on_before_run_node (self , invocation : BaseInvocation , queue_item : SessionQueueItem ):
110
110
"""Run before a node is executed"""
@@ -117,9 +117,9 @@ def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQue
117
117
node = invocation .model_dump (),
118
118
source_node_id = queue_item .session .prepared_source_mapping [invocation .id ],
119
119
)
120
- # And run lifecycle callbacks
121
- if self . on_before_run_node is not None :
122
- self . on_before_run_node (invocation , queue_item )
120
+
121
+ for callback in self . _on_before_run_node_callbacks :
122
+ callback (invocation , queue_item )
123
123
124
124
def _on_after_run_node (
125
125
self , invocation : BaseInvocation , queue_item : SessionQueueItem , outputs : BaseInvocationOutput
@@ -135,9 +135,9 @@ def _on_after_run_node(
135
135
source_node_id = queue_item .session .prepared_source_mapping [invocation .id ],
136
136
result = outputs .model_dump (),
137
137
)
138
- # And run lifecycle callbacks
139
- if self . on_after_run_node is not None :
140
- self . on_after_run_node (invocation , queue_item , outputs )
138
+
139
+ for callback in self . _on_after_run_node_callbacks :
140
+ callback (invocation , queue_item , outputs )
141
141
142
142
def _on_node_error (
143
143
self ,
@@ -169,8 +169,8 @@ def _on_node_error(
169
169
project_id = None ,
170
170
)
171
171
172
- if self . on_node_error is not None :
173
- self . on_node_error (invocation , queue_item , exc_type , exc_value , exc_traceback )
172
+ for callback in self . _on_node_error_callbacks :
173
+ callback (invocation , queue_item , exc_type , exc_value , exc_traceback )
174
174
175
175
def run_node (self , invocation : BaseInvocation , queue_item : SessionQueueItem ):
176
176
"""Run a single node in the graph"""
@@ -213,6 +213,7 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
213
213
# loop go to its next iteration, and the cancel event will be handled correctly.
214
214
pass
215
215
except Exception as e :
216
+ # Must extract the exception traceback here to not lose its stacktrace when we change scope
216
217
exc_type = type (e )
217
218
exc_value = e
218
219
exc_traceback = e .__traceback__
@@ -224,14 +225,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
224
225
def __init__ (
225
226
self ,
226
227
session_runner : Optional [SessionRunnerBase ] = None ,
227
- on_non_fatal_processor_error : Optional [OnNonFatalProcessorError ] = None ,
228
+ on_non_fatal_processor_error_callbacks : Optional [list [ OnNonFatalProcessorError ] ] = None ,
228
229
thread_limit : int = 1 ,
229
230
polling_interval : int = 1 ,
230
231
) -> None :
231
232
super ().__init__ ()
232
233
233
234
self .session_runner = session_runner if session_runner else DefaultSessionRunner ()
234
- self .on_non_fatal_processor_error = on_non_fatal_processor_error
235
+ self ._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or []
235
236
self ._thread_limit = thread_limit
236
237
self ._polling_interval = polling_interval
237
238
@@ -249,8 +250,8 @@ def _on_non_fatal_processor_error(
249
250
if queue_item is not None :
250
251
self ._invoker .services .session_queue .cancel_queue_item (queue_item .item_id , error = stacktrace )
251
252
252
- if self .on_non_fatal_processor_error :
253
- self . on_non_fatal_processor_error (exc_type , exc_value , exc_traceback , queue_item )
253
+ for callback in self ._on_non_fatal_processor_error_callbacks :
254
+ callback (exc_type , exc_value , exc_traceback , queue_item )
254
255
255
256
def start (self , invoker : Invoker ) -> None :
256
257
self ._invoker : Invoker = invoker
@@ -376,6 +377,7 @@ def _process(
376
377
self .session_runner .run (queue_item = self ._queue_item )
377
378
378
379
except Exception as e :
380
+ # Must extract the exception traceback here to not lose its stacktrace when we change scope
379
381
exc_type = type (e )
380
382
exc_value = e
381
383
exc_traceback = e .__traceback__
0 commit comments