Skip to content

Commit eff3596

Browse files
tidy(app): rearrange proccessor
1 parent cef1585 commit eff3596

File tree

1 file changed

+68
-66
lines changed

1 file changed

+68
-66
lines changed

invokeai/app/services/session_processor/session_processor_default.py

Lines changed: 68 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131

3232
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
33+
"""Formats a stacktrace as a string"""
34+
3335
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
3436

3537

@@ -72,6 +74,54 @@ def run(self, queue_item: SessionQueueItem):
7274

7375
self._on_after_run_session(queue_item=queue_item)
7476

77+
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
78+
"""Run a single node in the graph"""
79+
try:
80+
# Any unhandled exception is an invocation error & will fail the graph
81+
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
82+
self._on_before_run_node(invocation, queue_item)
83+
84+
data = InvocationContextData(
85+
invocation=invocation,
86+
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
87+
queue_item=queue_item,
88+
)
89+
context = build_invocation_context(
90+
data=data,
91+
services=self._services,
92+
cancel_event=self._cancel_event,
93+
)
94+
95+
# Invoke the node
96+
outputs = invocation.invoke_internal(context=context, services=self._services)
97+
# Save outputs and history
98+
queue_item.session.complete(invocation.id, outputs)
99+
100+
self._on_after_run_node(invocation, queue_item, outputs)
101+
102+
except KeyboardInterrupt:
103+
# TODO(MM2): Create an event for this
104+
pass
105+
except CanceledException:
106+
# When the user cancels the graph, we first set the cancel event. The event is checked
107+
# between invocations, in this loop. Some invocations are long-running, and we need to
108+
# be able to cancel them mid-execution.
109+
#
110+
# For example, denoising is a long-running invocation with many steps. A step callback
111+
# is executed after each step. This step callback checks if the canceled event is set,
112+
# then raises a CanceledException to stop execution immediately.
113+
#
114+
# When we get a CanceledException, we don't need to do anything - just pass and let the
115+
# loop go to its next iteration, and the cancel event will be handled correctly.
116+
pass
117+
except Exception as e:
118+
# Must extract the exception traceback here to not lose its stacktrace when we change scope
119+
exc_type = type(e)
120+
exc_value = e
121+
exc_traceback = e.__traceback__
122+
assert exc_traceback is not None
123+
self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
124+
75125
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
76126
# If profiling is enabled, start the profiler
77127
if self._profiler is not None:
@@ -172,54 +222,6 @@ def _on_node_error(
172222
for callback in self._on_node_error_callbacks:
173223
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
174224

175-
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
176-
"""Run a single node in the graph"""
177-
try:
178-
# Any unhandled exception is an invocation error & will fail the graph
179-
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
180-
self._on_before_run_node(invocation, queue_item)
181-
182-
data = InvocationContextData(
183-
invocation=invocation,
184-
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
185-
queue_item=queue_item,
186-
)
187-
context = build_invocation_context(
188-
data=data,
189-
services=self._services,
190-
cancel_event=self._cancel_event,
191-
)
192-
193-
# Invoke the node
194-
outputs = invocation.invoke_internal(context=context, services=self._services)
195-
# Save outputs and history
196-
queue_item.session.complete(invocation.id, outputs)
197-
198-
self._on_after_run_node(invocation, queue_item, outputs)
199-
200-
except KeyboardInterrupt:
201-
# TODO(MM2): Create an event for this
202-
pass
203-
except CanceledException:
204-
# When the user cancels the graph, we first set the cancel event. The event is checked
205-
# between invocations, in this loop. Some invocations are long-running, and we need to
206-
# be able to cancel them mid-execution.
207-
#
208-
# For example, denoising is a long-running invocation with many steps. A step callback
209-
# is executed after each step. This step callback checks if the canceled event is set,
210-
# then raises a CanceledException to stop execution immediately.
211-
#
212-
# When we get a CanceledException, we don't need to do anything - just pass and let the
213-
# loop go to its next iteration, and the cancel event will be handled correctly.
214-
pass
215-
except Exception as e:
216-
# Must extract the exception traceback here to not lose its stacktrace when we change scope
217-
exc_type = type(e)
218-
exc_value = e
219-
exc_traceback = e.__traceback__
220-
assert exc_traceback is not None
221-
self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
222-
223225

224226
class DefaultSessionProcessor(SessionProcessorBase):
225227
def __init__(
@@ -236,24 +238,6 @@ def __init__(
236238
self._thread_limit = thread_limit
237239
self._polling_interval = polling_interval
238240

239-
def _on_non_fatal_processor_error(
240-
self,
241-
queue_item: Optional[SessionQueueItem],
242-
exc_type: type,
243-
exc_value: BaseException,
244-
exc_traceback: TracebackType,
245-
) -> None:
246-
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
247-
# Non-fatal error in processor
248-
self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}")
249-
# Cancel the queue item
250-
if queue_item is not None:
251-
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
252-
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
253-
254-
for callback in self._on_non_fatal_processor_error_callbacks:
255-
callback(exc_type, exc_value, exc_traceback, queue_item)
256-
257241
def start(self, invoker: Invoker) -> None:
258242
self._invoker: Invoker = invoker
259243
self._queue_item: Optional[SessionQueueItem] = None
@@ -396,3 +380,21 @@ def _process(
396380
poll_now_event.clear()
397381
self._queue_item = None
398382
self._thread_semaphore.release()
383+
384+
def _on_non_fatal_processor_error(
385+
self,
386+
queue_item: Optional[SessionQueueItem],
387+
exc_type: type,
388+
exc_value: BaseException,
389+
exc_traceback: TracebackType,
390+
) -> None:
391+
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
392+
# Non-fatal error in processor
393+
self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}")
394+
# Cancel the queue item
395+
if queue_item is not None:
396+
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
397+
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
398+
399+
for callback in self._on_non_fatal_processor_error_callbacks:
400+
callback(exc_type, exc_value, exc_traceback, queue_item)

0 commit comments

Comments
 (0)