Skip to content

Commit 7eb84c9

Browse files
feat(app): support multiple processor lifecycle callbacks
1 parent c117ffd commit 7eb84c9

File tree

2 files changed

+34
-32
lines changed

2 files changed

+34
-32
lines changed

invokeai/app/api/dependencies.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,13 @@ def on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item=
130130

131131
session_processor = DefaultSessionProcessor(
132132
DefaultSessionRunner(
133-
on_before_run_session=on_before_run_session,
134-
on_before_run_node=on_before_run_node,
135-
on_after_run_node=on_after_run_node,
136-
on_node_error=on_node_error,
137-
on_after_run_session=on_after_run_session,
133+
on_before_run_session_callbacks=[on_before_run_session],
134+
on_before_run_node_callbacks=[on_before_run_node],
135+
on_after_run_node_callbacks=[on_after_run_node],
136+
on_node_error_callbacks=[on_node_error],
137+
on_after_run_session_callbacks=[on_after_run_session],
138138
),
139-
on_non_fatal_processor_error,
139+
on_non_fatal_processor_error_callbacks=[on_non_fatal_processor_error],
140140
)
141141
session_queue = SqliteSessionQueue(db=db)
142142
urls = LocalUrlService()

invokeai/app/services/session_processor/session_processor_default.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@ class DefaultSessionRunner(SessionRunnerBase):
3838

3939
def __init__(
4040
self,
41-
on_before_run_session: Optional[OnBeforeRunSession] = None,
42-
on_before_run_node: Optional[OnBeforeRunNode] = None,
43-
on_after_run_node: Optional[OnAfterRunNode] = None,
44-
on_node_error: Optional[OnNodeError] = None,
45-
on_after_run_session: Optional[OnAfterRunSession] = None,
41+
on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None,
42+
on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None,
43+
on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None,
44+
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
45+
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
4646
):
47-
self.on_before_run_session = on_before_run_session
48-
self.on_before_run_node = on_before_run_node
49-
self.on_after_run_node = on_after_run_node
50-
self.on_node_error = on_node_error
51-
self.on_after_run_session = on_after_run_session
47+
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
48+
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
49+
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
50+
self._on_node_error_callbacks = on_node_error_callbacks or []
51+
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
5252

5353
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
5454
"""Start the session runner"""
@@ -77,8 +77,8 @@ def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
7777
if self._profiler is not None:
7878
self._profiler.start(profile_id=queue_item.session_id)
7979

80-
if self.on_before_run_session:
81-
self.on_before_run_session(queue_item=queue_item)
80+
for callback in self._on_before_run_session_callbacks:
81+
callback(queue_item=queue_item)
8282

8383
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
8484
# If we are profiling, stop the profiler and dump the profile & stats
@@ -103,8 +103,8 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
103103
self._services.performance_statistics.log_stats(queue_item.session.id)
104104
self._services.performance_statistics.reset_stats()
105105

106-
if self.on_after_run_session:
107-
self.on_after_run_session(queue_item)
106+
for callback in self._on_after_run_session_callbacks:
107+
callback(queue_item)
108108

109109
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
110110
"""Run before a node is executed"""
@@ -117,9 +117,9 @@ def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQue
117117
node=invocation.model_dump(),
118118
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
119119
)
120-
# And run lifecycle callbacks
121-
if self.on_before_run_node is not None:
122-
self.on_before_run_node(invocation, queue_item)
120+
121+
for callback in self._on_before_run_node_callbacks:
122+
callback(invocation, queue_item)
123123

124124
def _on_after_run_node(
125125
self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput
@@ -135,9 +135,9 @@ def _on_after_run_node(
135135
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
136136
result=outputs.model_dump(),
137137
)
138-
# And run lifecycle callbacks
139-
if self.on_after_run_node is not None:
140-
self.on_after_run_node(invocation, queue_item, outputs)
138+
139+
for callback in self._on_after_run_node_callbacks:
140+
callback(invocation, queue_item, outputs)
141141

142142
def _on_node_error(
143143
self,
@@ -169,8 +169,8 @@ def _on_node_error(
169169
project_id=None,
170170
)
171171

172-
if self.on_node_error is not None:
173-
self.on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
172+
for callback in self._on_node_error_callbacks:
173+
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
174174

175175
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
176176
"""Run a single node in the graph"""
@@ -213,6 +213,7 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
213213
# loop go to its next iteration, and the cancel event will be handled correctly.
214214
pass
215215
except Exception as e:
216+
# Must extract the exception traceback here to not lose its stacktrace when we change scope
216217
exc_type = type(e)
217218
exc_value = e
218219
exc_traceback = e.__traceback__
@@ -224,14 +225,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
224225
def __init__(
225226
self,
226227
session_runner: Optional[SessionRunnerBase] = None,
227-
on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
228+
on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None,
228229
thread_limit: int = 1,
229230
polling_interval: int = 1,
230231
) -> None:
231232
super().__init__()
232233

233234
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
234-
self.on_non_fatal_processor_error = on_non_fatal_processor_error
235+
self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or []
235236
self._thread_limit = thread_limit
236237
self._polling_interval = polling_interval
237238

@@ -249,8 +250,8 @@ def _on_non_fatal_processor_error(
249250
if queue_item is not None:
250251
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
251252

252-
if self.on_non_fatal_processor_error:
253-
self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item)
253+
for callback in self._on_non_fatal_processor_error_callbacks:
254+
callback(exc_type, exc_value, exc_traceback, queue_item)
254255

255256
def start(self, invoker: Invoker) -> None:
256257
self._invoker: Invoker = invoker
@@ -376,6 +377,7 @@ def _process(
376377
self.session_runner.run(queue_item=self._queue_item)
377378

378379
except Exception as e:
380+
# Must extract the exception traceback here to not lose its stacktrace when we change scope
379381
exc_type = type(e)
380382
exc_value = e
381383
exc_traceback = e.__traceback__

0 commit comments

Comments
 (0)