Skip to content

Commit 25954ea

Browse files
feat(queue): session queue error handling
- Add handling for new error columns `error_type`, `error_message`, `error_traceback`. - Update queue item model to include the new data. The `error_traceback` field has an alias of `error` for backwards compatibility. - Add `fail_queue_item` method. This was previously handled by `cancel_queue_item`. Splitting this functionality makes failing a queue item a bit more explicit. We also don't need to handle multiple optional error args. -
1 parent 887b73a commit 25954ea

File tree

3 files changed

+70
-12
lines changed

3 files changed

+70
-12
lines changed

invokeai/app/services/session_queue/session_queue_base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,15 @@ def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
7474
pass
7575

7676
@abstractmethod
77-
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
77+
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
7878
"""Cancels a session queue item"""
7979
pass
8080

81+
@abstractmethod
82+
def fail_queue_item(self, item_id: int, error_type: str, error_message: str, error_traceback: str) -> SessionQueueItem:
83+
"""Fails a session queue item"""
84+
pass
85+
8186
@abstractmethod
8287
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
8388
"""Cancels all queue items with matching batch IDs"""

invokeai/app/services/session_queue/session_queue_common.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@
33
from itertools import chain, product
44
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
55

6-
from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator
6+
from pydantic import (
7+
AliasChoices,
8+
BaseModel,
9+
ConfigDict,
10+
Field,
11+
StrictStr,
12+
TypeAdapter,
13+
field_validator,
14+
model_validator,
15+
)
716
from pydantic_core import to_jsonable_python
817

918
from invokeai.app.invocations.baseinvocation import BaseInvocation
@@ -189,7 +198,13 @@ class SessionQueueItemWithoutGraph(BaseModel):
189198
session_id: str = Field(
190199
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
191200
)
192-
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
201+
error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored")
202+
error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored")
203+
error_traceback: Optional[str] = Field(
204+
default=None,
205+
description="The error traceback if this queue item errored",
206+
validation_alias=AliasChoices("error_traceback", "error"),
207+
)
193208
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
194209
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
195210
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")

invokeai/app/services/session_queue/session_queue_sqlite.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,18 @@ async def _handle_complete_event(self, event: FastAPIEvent) -> None:
8282
async def _handle_error_event(self, event: FastAPIEvent) -> None:
8383
try:
8484
item_id = event[1]["data"]["queue_item_id"]
85-
error = event[1]["data"]["error"]
85+
error_type = event[1]["data"]["error_type"]
86+
error_message = event[1]["data"]["error_message"]
87+
error_traceback = event[1]["data"]["error_traceback"]
8688
queue_item = self.get_queue_item(item_id)
8789
# always set to failed if have an error, even if previously the item was marked completed or canceled
88-
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
90+
queue_item = self._set_queue_item_status(
91+
item_id=queue_item.item_id,
92+
status="failed",
93+
error_type=error_type,
94+
error_message=error_message,
95+
error_traceback=error_traceback,
96+
)
8997
except SessionQueueItemNotFoundError:
9098
return
9199

@@ -272,17 +280,22 @@ def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
272280
return SessionQueueItem.queue_item_from_dict(dict(result))
273281

274282
def _set_queue_item_status(
275-
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
283+
self,
284+
item_id: int,
285+
status: QUEUE_ITEM_STATUS,
286+
error_type: Optional[str] = None,
287+
error_message: Optional[str] = None,
288+
error_traceback: Optional[str] = None,
276289
) -> SessionQueueItem:
277290
try:
278291
self.__lock.acquire()
279292
self.__cursor.execute(
280293
"""--sql
281294
UPDATE session_queue
282-
SET status = ?, error = ?
295+
SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
283296
WHERE item_id = ?
284297
""",
285-
(status, error, item_id),
298+
(status, error_type, error_message, error_traceback, item_id),
286299
)
287300
self.__conn.commit()
288301
except Exception:
@@ -425,11 +438,34 @@ def prune(self, queue_id: str) -> PruneResult:
425438
self.__lock.release()
426439
return PruneResult(deleted=count)
427440

428-
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
441+
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
429442
queue_item = self.get_queue_item(item_id)
430443
if queue_item.status not in ["canceled", "failed", "completed"]:
431-
status = "failed" if error is not None else "canceled"
432-
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here
444+
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
445+
self.__invoker.services.events.emit_session_canceled(
446+
queue_item_id=queue_item.item_id,
447+
queue_id=queue_item.queue_id,
448+
queue_batch_id=queue_item.batch_id,
449+
graph_execution_state_id=queue_item.session_id,
450+
)
451+
return queue_item
452+
453+
def fail_queue_item(
454+
self,
455+
item_id: int,
456+
error_type: str,
457+
error_message: str,
458+
error_traceback: str,
459+
) -> SessionQueueItem:
460+
queue_item = self.get_queue_item(item_id)
461+
if queue_item.status not in ["canceled", "failed", "completed"]:
462+
queue_item = self._set_queue_item_status(
463+
item_id=item_id,
464+
status="failed",
465+
error_type=error_type,
466+
error_message=error_message,
467+
error_traceback=error_traceback,
468+
)
433469
self.__invoker.services.events.emit_session_canceled(
434470
queue_item_id=queue_item.item_id,
435471
queue_id=queue_item.queue_id,
@@ -602,7 +638,9 @@ def list_queue_items(
602638
status,
603639
priority,
604640
field_values,
605-
error,
641+
error_type,
642+
error_message,
643+
error_traceback,
606644
created_at,
607645
updated_at,
608646
completed_at,

0 commit comments

Comments
 (0)