30
30
31
31
32
32
class DefaultSessionRunner (SessionRunnerBase ):
33
- """Processes a single session's invocations"""
33
+ """Processes a single session's invocations. """
34
34
35
35
def __init__ (
36
36
self ,
@@ -40,21 +40,28 @@ def __init__(
40
40
on_node_error_callbacks : Optional [list [OnNodeError ]] = None ,
41
41
on_after_run_session_callbacks : Optional [list [OnAfterRunSession ]] = None ,
42
42
):
43
+ """
44
+ Args:
45
+ on_before_run_session_callbacks: Callbacks to run before the session starts.
46
+ on_before_run_node_callbacks: Callbacks to run before each node starts.
47
+ on_after_run_node_callbacks: Callbacks to run after each node completes.
48
+ on_node_error_callbacks: Callbacks to run when a node errors.
49
+ on_after_run_session_callbacks: Callbacks to run after the session completes.
50
+ """
51
+
43
52
self ._on_before_run_session_callbacks = on_before_run_session_callbacks or []
44
53
self ._on_before_run_node_callbacks = on_before_run_node_callbacks or []
45
54
self ._on_after_run_node_callbacks = on_after_run_node_callbacks or []
46
55
self ._on_node_error_callbacks = on_node_error_callbacks or []
47
56
self ._on_after_run_session_callbacks = on_after_run_session_callbacks or []
48
57
49
58
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
    """Attach the runtime collaborators this runner uses while running sessions.

    Nothing is executed here; the arguments are only stored on the instance
    so that the other methods of the runner can reach them.

    Args:
        services: The invocation services, stored as `self._services`.
        cancel_event: Event stored as `self._cancel_event`; `run` checks it
            with `is_set()` to decide whether to stop processing nodes.
        profiler: Optional profiler stored as `self._profiler`. When it is
            `None`, the session hooks skip profiling.
    """
    # The three assignments are independent of each other.
    self._profiler = profiler
    self._cancel_event = cancel_event
    self._services = services
54
62
55
63
def run (self , queue_item : SessionQueueItem ):
56
- """Run the graph"""
57
- # Exceptions raised outside `run_node` are handled by the processor.
64
+ # Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
58
65
59
66
self ._on_before_run_session (queue_item = queue_item )
60
67
@@ -78,14 +85,16 @@ def run(self, queue_item: SessionQueueItem):
78
85
79
86
if invocation is None or self ._cancel_event .is_set ():
80
87
break
88
+
81
89
self .run_node (invocation , queue_item )
90
+
91
+ # The session is complete if all invocations have been run or there is an error on the session.
82
92
if queue_item .session .is_complete () or self ._cancel_event .is_set ():
83
93
break
84
94
85
95
self ._on_after_run_session (queue_item = queue_item )
86
96
87
97
def run_node (self , invocation : BaseInvocation , queue_item : SessionQueueItem ):
88
- """Run a single node in the graph"""
89
98
try :
90
99
# Any unhandled exception in this scope is an invocation error & will fail the graph
91
100
with self ._services .performance_statistics .collect_stats (invocation , queue_item .session_id ):
@@ -110,7 +119,7 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
110
119
self ._on_after_run_node (invocation , queue_item , output )
111
120
112
121
except KeyboardInterrupt :
113
- # TODO(MM2 ): Create an event for this
122
+ # TODO(psyche ): This is expected to be caught in the main thread. Do we need to catch this here?
114
123
pass
115
124
except CanceledException :
116
125
# When the user cancels the graph, we first set the cancel event. The event is checked
@@ -137,6 +146,8 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
137
146
)
138
147
139
148
def _on_before_run_session (self , queue_item : SessionQueueItem ) -> None :
149
+ """Run before a session is executed"""
150
+
140
151
# If profiling is enabled, start the profiler
141
152
if self ._profiler is not None :
142
153
self ._profiler .start (profile_id = queue_item .session_id )
@@ -145,6 +156,8 @@ def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
145
156
callback (queue_item = queue_item )
146
157
147
158
def _on_after_run_session (self , queue_item : SessionQueueItem ) -> None :
159
+ """Run after a session is executed"""
160
+
148
161
# If we are profiling, stop the profiler and dump the profile & stats
149
162
if self ._profiler is not None :
150
163
profile_path = self ._profiler .stop ()
@@ -156,7 +169,8 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
156
169
# Update the queue item with the completed session
157
170
self ._services .session_queue .set_queue_item_session (queue_item .item_id , queue_item .session )
158
171
159
- # Send complete event
172
+ # TODO(psyche): This feels jumbled - we should review separation of concerns here.
173
+ # Send complete event. The events service will receive this and update the queue item's status.
160
174
self ._services .events .emit_graph_execution_complete (
161
175
queue_batch_id = queue_item .batch_id ,
162
176
queue_item_id = queue_item .item_id ,
@@ -175,6 +189,7 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
175
189
176
190
def _on_before_run_node (self , invocation : BaseInvocation , queue_item : SessionQueueItem ):
177
191
"""Run before a node is executed"""
192
+
178
193
# Send starting event
179
194
self ._services .events .emit_invocation_started (
180
195
queue_batch_id = queue_item .batch_id ,
@@ -192,6 +207,7 @@ def _on_after_run_node(
192
207
self , invocation : BaseInvocation , queue_item : SessionQueueItem , output : BaseInvocationOutput
193
208
):
194
209
"""Run after a node is executed"""
210
+
195
211
# Send complete event on successful runs
196
212
self ._services .events .emit_invocation_complete (
197
213
queue_batch_id = queue_item .batch_id ,
@@ -214,6 +230,8 @@ def _on_node_error(
214
230
error_message : str ,
215
231
error_traceback : str ,
216
232
):
233
+ """Run when a node errors"""
234
+
217
235
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
218
236
node_error = f"{ error_type } : { error_message } "
219
237
queue_item .session .set_node_error (invocation .id , node_error )
@@ -356,17 +374,17 @@ def _process(
356
374
resume_event : ThreadEvent ,
357
375
cancel_event : ThreadEvent ,
358
376
):
359
- # Outermost processor try block; any unhandled exception is a fatal processor error
360
377
try :
378
+ # Any unhandled exception in this block is a fatal processor error and will stop the processor.
361
379
self ._thread_semaphore .acquire ()
362
380
stop_event .clear ()
363
381
resume_event .set ()
364
382
cancel_event .clear ()
365
383
366
384
while not stop_event .is_set ():
367
385
poll_now_event .clear ()
368
- # Middle processor try block; any unhandled exception is a non-fatal processor error
369
386
try :
387
+ # Any unhandled exception in this block is a nonfatal processor error and will be handled.
370
388
# If we are paused, wait for resume event
371
389
resume_event .wait ()
372
390
0 commit comments