Skip to content

Commit e365d35

Browse files
docs(processor): update docstrings, comments
1 parent aa329ea commit e365d35

File tree

2 files changed

+97
-20
lines changed

2 files changed

+97
-20
lines changed

invokeai/app/services/session_processor/session_processor_base.py

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,33 @@ class SessionRunnerBase(ABC):
1616

1717
@abstractmethod
1818
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
19-
"""Starts the session runner"""
19+
"""Starts the session runner.
20+
21+
Args:
22+
services: The invocation services.
23+
cancel_event: The cancel event.
24+
profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic session
25+
stats will be still be recorded and logged when profiling is disabled.
26+
"""
2027
pass
2128

2229
@abstractmethod
2330
def run(self, queue_item: SessionQueueItem) -> None:
24-
"""Runs the session"""
31+
"""Runs a session.
32+
33+
Args:
34+
queue_item: The session to run.
35+
"""
2536
pass
2637

2738
@abstractmethod
2839
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
29-
"""Runs an already prepared node on the session"""
40+
"""Run a single node in the graph.
41+
42+
Args:
43+
invocation: The invocation to run.
44+
queue_item: The session queue item.
45+
"""
3046
pass
3147

3248

@@ -56,13 +72,25 @@ def get_status(self) -> SessionProcessorStatus:
5672

5773

5874
class OnBeforeRunNode(Protocol):
59-
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool: ...
75+
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
76+
"""Callback to run before executing a node.
77+
78+
Args:
79+
invocation: The invocation that will be executed.
80+
queue_item: The session queue item.
81+
"""
82+
...
6083

6184

6285
class OnAfterRunNode(Protocol):
63-
def __call__(
64-
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
65-
) -> bool: ...
86+
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None:
87+
"""Callback to run before executing a node.
88+
89+
Args:
90+
invocation: The invocation that was executed.
91+
queue_item: The session queue item.
92+
"""
93+
...
6694

6795

6896
class OnNodeError(Protocol):
@@ -73,15 +101,37 @@ def __call__(
73101
error_type: str,
74102
error_message: str,
75103
error_traceback: str,
76-
) -> bool: ...
104+
) -> None:
105+
"""Callback to run when a node has an error.
106+
107+
Args:
108+
invocation: The invocation that errored.
109+
queue_item: The session queue item.
110+
error_type: The type of error, e.g. "ValueError".
111+
error_message: The error message, e.g. "Invalid value".
112+
error_traceback: The stringified error traceback.
113+
"""
114+
...
77115

78116

79117
class OnBeforeRunSession(Protocol):
80-
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
118+
def __call__(self, queue_item: SessionQueueItem) -> None:
119+
"""Callback to run before executing a session.
120+
121+
Args:
122+
queue_item: The session queue item.
123+
"""
124+
...
81125

82126

83127
class OnAfterRunSession(Protocol):
84-
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
128+
def __call__(self, queue_item: SessionQueueItem) -> None:
129+
"""Callback to run after executing a session.
130+
131+
Args:
132+
queue_item: The session queue item.
133+
"""
134+
...
85135

86136

87137
class OnNonFatalProcessorError(Protocol):
@@ -91,4 +141,13 @@ def __call__(
91141
error_type: str,
92142
error_message: str,
93143
error_traceback: str,
94-
) -> bool: ...
144+
) -> None:
145+
"""Callback to run when a non-fatal error occurs in the processor.
146+
147+
Args:
148+
queue_item: The session queue item, if one was being executed when the error occurred.
149+
error_type: The type of error, e.g. "ValueError".
150+
error_message: The error message, e.g. "Invalid value".
151+
error_traceback: The stringified error traceback.
152+
"""
153+
...

invokeai/app/services/session_processor/session_processor_default.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131

3232
class DefaultSessionRunner(SessionRunnerBase):
33-
"""Processes a single session's invocations"""
33+
"""Processes a single session's invocations."""
3434

3535
def __init__(
3636
self,
@@ -40,21 +40,28 @@ def __init__(
4040
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
4141
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
4242
):
43+
"""
44+
Args:
45+
on_before_run_session_callbacks: Callbacks to run before the session starts.
46+
on_before_run_node_callbacks: Callbacks to run before each node starts.
47+
on_after_run_node_callbacks: Callbacks to run after each node completes.
48+
on_node_error_callbacks: Callbacks to run when a node errors.
49+
on_after_run_session_callbacks: Callbacks to run after the session completes.
50+
"""
51+
4352
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
4453
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
4554
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
4655
self._on_node_error_callbacks = on_node_error_callbacks or []
4756
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
4857

4958
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
50-
"""Start the session runner"""
5159
self._services = services
5260
self._cancel_event = cancel_event
5361
self._profiler = profiler
5462

5563
def run(self, queue_item: SessionQueueItem):
56-
"""Run the graph"""
57-
# Exceptions raised outside `run_node` are handled by the processor.
64+
# Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
5865

5966
self._on_before_run_session(queue_item=queue_item)
6067

@@ -78,14 +85,16 @@ def run(self, queue_item: SessionQueueItem):
7885

7986
if invocation is None or self._cancel_event.is_set():
8087
break
88+
8189
self.run_node(invocation, queue_item)
90+
91+
# The session is complete if all invocations have been run or there is an error on the session.
8292
if queue_item.session.is_complete() or self._cancel_event.is_set():
8393
break
8494

8595
self._on_after_run_session(queue_item=queue_item)
8696

8797
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
88-
"""Run a single node in the graph"""
8998
try:
9099
# Any unhandled exception in this scope is an invocation error & will fail the graph
91100
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
@@ -110,7 +119,7 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
110119
self._on_after_run_node(invocation, queue_item, output)
111120

112121
except KeyboardInterrupt:
113-
# TODO(MM2): Create an event for this
122+
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
114123
pass
115124
except CanceledException:
116125
# When the user cancels the graph, we first set the cancel event. The event is checked
@@ -137,6 +146,8 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
137146
)
138147

139148
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
149+
"""Run before a session is executed"""
150+
140151
# If profiling is enabled, start the profiler
141152
if self._profiler is not None:
142153
self._profiler.start(profile_id=queue_item.session_id)
@@ -145,6 +156,8 @@ def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
145156
callback(queue_item=queue_item)
146157

147158
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
159+
"""Run after a session is executed"""
160+
148161
# If we are profiling, stop the profiler and dump the profile & stats
149162
if self._profiler is not None:
150163
profile_path = self._profiler.stop()
@@ -156,7 +169,8 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
156169
# Update the queue item with the completed session
157170
self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
158171

159-
# Send complete event
172+
# TODO(psyche): This feels jumbled - we should review separation of concerns here.
173+
# Send complete event. The events service will receive this and update the queue item's status.
160174
self._services.events.emit_graph_execution_complete(
161175
queue_batch_id=queue_item.batch_id,
162176
queue_item_id=queue_item.item_id,
@@ -175,6 +189,7 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
175189

176190
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
177191
"""Run before a node is executed"""
192+
178193
# Send starting event
179194
self._services.events.emit_invocation_started(
180195
queue_batch_id=queue_item.batch_id,
@@ -192,6 +207,7 @@ def _on_after_run_node(
192207
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
193208
):
194209
"""Run after a node is executed"""
210+
195211
# Send complete event on successful runs
196212
self._services.events.emit_invocation_complete(
197213
queue_batch_id=queue_item.batch_id,
@@ -214,6 +230,8 @@ def _on_node_error(
214230
error_message: str,
215231
error_traceback: str,
216232
):
233+
"""Run when a node errors"""
234+
217235
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
218236
node_error = f"{error_type}: {error_message}"
219237
queue_item.session.set_node_error(invocation.id, node_error)
@@ -356,17 +374,17 @@ def _process(
356374
resume_event: ThreadEvent,
357375
cancel_event: ThreadEvent,
358376
):
359-
# Outermost processor try block; any unhandled exception is a fatal processor error
360377
try:
378+
# Any unhandled exception in this block is a fatal processor error and will stop the processor.
361379
self._thread_semaphore.acquire()
362380
stop_event.clear()
363381
resume_event.set()
364382
cancel_event.clear()
365383

366384
while not stop_event.is_set():
367385
poll_now_event.clear()
368-
# Middle processor try block; any unhandled exception is a non-fatal processor error
369386
try:
387+
# Any unhandled exception in this block is a nonfatal processor error and will be handled.
370388
# If we are paused, wait for resume event
371389
resume_event.wait()
372390

0 commit comments

Comments
 (0)