Skip to content

Commit 2dd3a85

Browse files
feat(processor): update enriched errors & fail_queue_item()
1 parent a8492bd commit 2dd3a85

File tree

2 files changed

+83
-53
lines changed

2 files changed

+83
-53
lines changed

invokeai/app/services/session_processor/session_processor_base.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from abc import ABC, abstractmethod
22
from threading import Event
3-
from types import TracebackType
43
from typing import Optional, Protocol
54

65
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
@@ -71,9 +70,9 @@ def __call__(
7170
self,
7271
invocation: BaseInvocation,
7372
queue_item: SessionQueueItem,
74-
exc_type: type,
75-
exc_value: BaseException,
76-
exc_traceback: TracebackType,
73+
error_type: str,
74+
error_message: str,
75+
error_traceback: str,
7776
) -> bool: ...
7877

7978

@@ -88,8 +87,8 @@ def __call__(self, queue_item: SessionQueueItem) -> bool: ...
8887
class OnNonFatalProcessorError(Protocol):
8988
def __call__(
9089
self,
91-
exc_type: type,
92-
exc_value: BaseException,
93-
exc_traceback: TracebackType,
94-
queue_item: Optional[SessionQueueItem] = None,
90+
queue_item: Optional[SessionQueueItem],
91+
error_type: str,
92+
error_message: str,
93+
error_traceback: str,
9594
) -> bool: ...

invokeai/app/services/session_processor/session_processor_default.py

Lines changed: 76 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from contextlib import suppress
33
from threading import BoundedSemaphore, Thread
44
from threading import Event as ThreadEvent
5-
from types import TracebackType
65
from typing import Optional
76

87
from fastapi_events.handlers.local import local_handler
@@ -30,12 +29,6 @@
3029
from .session_processor_common import SessionProcessorStatus
3130

3231

33-
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
34-
"""Formats a stacktrace as a string"""
35-
36-
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
37-
38-
3932
class DefaultSessionRunner(SessionRunnerBase):
4033
"""Processes a single session's invocations"""
4134

@@ -71,10 +64,16 @@ def run(self, queue_item: SessionQueueItem):
7164
invocation = queue_item.session.next()
7265
# Anything other than a `NodeInputError` is handled as a processor error
7366
except NodeInputError as e:
74-
# Must extract the exception traceback here to not lose its stacktrace when we change scope
75-
traceback = e.__traceback__
76-
assert traceback is not None
77-
self._on_node_error(e.node, queue_item, type(e), e, traceback)
67+
error_type = e.__class__.__name__
68+
error_message = str(e)
69+
error_traceback = traceback.format_exc()
70+
self._on_node_error(
71+
invocation=e.node,
72+
queue_item=queue_item,
73+
error_type=error_type,
74+
error_message=error_message,
75+
error_traceback=error_traceback,
76+
)
7877
break
7978

8079
if invocation is None or self._cancel_event.is_set():
@@ -126,10 +125,16 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
126125
# loop go to its next iteration, and the cancel event will be handled correctly.
127126
pass
128127
except Exception as e:
129-
# Must extract the exception traceback here to not lose its stacktrace when we change scope
130-
exc_traceback = e.__traceback__
131-
assert exc_traceback is not None
132-
self._on_node_error(invocation, queue_item, type(e), e, exc_traceback)
128+
error_type = e.__class__.__name__
129+
error_message = str(e)
130+
error_traceback = traceback.format_exc()
131+
self._on_node_error(
132+
invocation=invocation,
133+
queue_item=queue_item,
134+
error_type=error_type,
135+
error_message=error_message,
136+
error_traceback=error_traceback,
137+
)
133138

134139
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
135140
# If profiling is enabled, start the profiler
@@ -166,7 +171,7 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
166171
self._services.performance_statistics.reset_stats()
167172

168173
for callback in self._on_after_run_session_callbacks:
169-
callback(queue_item)
174+
callback(queue_item=queue_item)
170175

171176
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
172177
"""Run before a node is executed"""
@@ -181,7 +186,7 @@ def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQue
181186
)
182187

183188
for callback in self._on_before_run_node_callbacks:
184-
callback(invocation, queue_item)
189+
callback(invocation=invocation, queue_item=queue_item)
185190

186191
def _on_after_run_node(
187192
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
@@ -199,23 +204,23 @@ def _on_after_run_node(
199204
)
200205

201206
for callback in self._on_after_run_node_callbacks:
202-
callback(invocation, queue_item, output)
207+
callback(invocation=invocation, queue_item=queue_item, output=output)
203208

204209
def _on_node_error(
205210
self,
206211
invocation: BaseInvocation,
207212
queue_item: SessionQueueItem,
208-
exc_type: type,
209-
exc_value: BaseException,
210-
exc_traceback: TracebackType,
213+
error_type: str,
214+
error_message: str,
215+
error_traceback: str,
211216
):
212-
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
213-
214-
queue_item.session.set_node_error(invocation.id, stacktrace)
217+
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
218+
node_error = f"{error_type}: {error_message}"
219+
queue_item.session.set_node_error(invocation.id, node_error)
215220
self._services.logger.error(
216-
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {exc_type.__name__}"
221+
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}"
217222
)
218-
self._services.logger.error(stacktrace)
223+
self._services.logger.error(error_traceback)
219224

220225
# Send error event
221226
self._services.events.emit_invocation_error(
@@ -225,14 +230,21 @@ def _on_node_error(
225230
graph_execution_state_id=queue_item.session.id,
226231
node=invocation.model_dump(),
227232
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
228-
error_type=exc_type.__name__,
229-
error=stacktrace,
233+
error_type=error_type,
234+
error_message=error_message,
235+
error_traceback=error_traceback,
230236
user_id=getattr(queue_item, "user_id", None),
231237
project_id=getattr(queue_item, "project_id", None),
232238
)
233239

234240
for callback in self._on_node_error_callbacks:
235-
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
241+
callback(
242+
invocation=invocation,
243+
queue_item=queue_item,
244+
error_type=error_type,
245+
error_message=error_message,
246+
error_traceback=error_traceback,
247+
)
236248

237249

238250
class DefaultSessionProcessor(SessionProcessorBase):
@@ -374,16 +386,25 @@ def _process(
374386
self.session_runner.run(queue_item=self._queue_item)
375387

376388
except Exception as e:
377-
# Must extract the exception traceback here to not lose its stacktrace when we change scope
378-
exc_traceback = e.__traceback__
379-
assert exc_traceback is not None
380-
self._on_non_fatal_processor_error(self._queue_item, type(e), e, exc_traceback)
381-
# Immediately poll for next queue item
389+
error_type = e.__class__.__name__
390+
error_message = str(e)
391+
error_traceback = traceback.format_exc()
392+
self._on_non_fatal_processor_error(
393+
queue_item=self._queue_item,
394+
error_type=error_type,
395+
error_message=error_message,
396+
error_traceback=error_traceback,
397+
)
398+
# Wait for next polling interval or event to try again
382399
poll_now_event.wait(self._polling_interval)
383400
continue
384-
except Exception:
401+
except Exception as e:
385402
# Fatal error in processor, log and pass - we're done here
386-
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
403+
error_type = e.__class__.__name__
404+
error_message = str(e)
405+
error_traceback = traceback.format_exc()
406+
self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}")
407+
self._invoker.services.logger.error(error_traceback)
387408
pass
388409
finally:
389410
stop_event.clear()
@@ -394,19 +415,29 @@ def _process(
394415
def _on_non_fatal_processor_error(
395416
self,
396417
queue_item: Optional[SessionQueueItem],
397-
exc_type: type,
398-
exc_value: BaseException,
399-
exc_traceback: TracebackType,
418+
error_type: str,
419+
error_message: str,
420+
error_traceback: str,
400421
) -> None:
401-
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
402422
# Non-fatal error in processor
403-
self._invoker.services.logger.error(f"Non-fatal error in session processor: {exc_type.__name__}")
404-
self._invoker.services.logger.error(stacktrace)
423+
self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
424+
self._invoker.services.logger.error(error_traceback)
425+
405426
if queue_item is not None:
406427
# Update the queue item with the completed session
407428
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
408-
# And cancel the queue item with an error
409-
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
429+
# Fail the queue item
430+
self._invoker.services.session_queue.fail_queue_item(
431+
item_id=queue_item.item_id,
432+
error_type=error_type,
433+
error_message=error_message,
434+
error_traceback=error_traceback,
435+
)
410436

411437
for callback in self._on_non_fatal_processor_error_callbacks:
412-
callback(exc_type, exc_value, exc_traceback, queue_item)
438+
callback(
439+
queue_item=queue_item,
440+
error_type=error_type,
441+
error_message=error_message,
442+
error_traceback=error_traceback,
443+
)

0 commit comments

Comments
 (0)