From e51a3025eae7a2e7bb9edea0ad46473f5121a699 Mon Sep 17 00:00:00 2001 From: brandonrising Date: Thu, 16 May 2024 13:30:04 -0400 Subject: [PATCH 01/34] Break apart session processor and the running of each session into separate classes --- .../session_processor_base.py | 28 ++ .../session_processor_default.py | 341 +++++++++++------- 2 files changed, 237 insertions(+), 132 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py index 485ef2f8c38..745430d201f 100644 --- a/invokeai/app/services/session_processor/session_processor_base.py +++ b/invokeai/app/services/session_processor/session_processor_base.py @@ -1,6 +1,34 @@ from abc import ABC, abstractmethod +from threading import Event +from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus +from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem + +class SessionRunnerBase(ABC): + """ + Base class for session runner. 
+ """ + + @abstractmethod + def start(self, services: InvocationServices, cancel_event: Event) -> None: + """Starts the session runner""" + pass + + @abstractmethod + def run(self, queue_item: SessionQueueItem) -> None: + """Runs the session""" + pass + + @abstractmethod + def complete(self, queue_item: SessionQueueItem) -> None: + """Completes the session""" + pass + + @abstractmethod + def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None: + """Runs an already prepared node on the session""" + pass class SessionProcessorBase(ABC): diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 2a0ebc31680..eeb8df4cadd 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -2,7 +2,7 @@ from contextlib import suppress from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent -from typing import Optional +from typing import Callable, Optional, Union from fastapi_events.handlers.local import local_handler from fastapi_events.typing import Event as FastAPIEvent @@ -16,15 +16,207 @@ from invokeai.app.util.profiler import Profiler from ..invoker import Invoker -from .session_processor_base import SessionProcessorBase +from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase from .session_processor_common import SessionProcessorStatus +class DefaultSessionRunner(SessionRunnerBase): + """Processes a single session's invocations""" + + def __init__( + self, + on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None, + on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None, + ): + self.on_before_run_node = on_before_run_node + self.on_after_run_node = on_after_run_node + + def start(self, services: InvocationServices, 
cancel_event: ThreadEvent): + """Start the session runner""" + self.services = services + self.cancel_event = cancel_event + + def next_invocation( + self, previous_invocation: Optional[BaseInvocation], queue_item: SessionQueueItem, cancel_event: ThreadEvent + ) -> Optional[BaseInvocation]: + invocation = None + if not (queue_item.session.is_complete() or cancel_event.is_set()): + try: + invocation = queue_item.session.next() + except Exception as exc: + self.services.logger.error("ERROR: %s" % exc, exc_info=True) + + node_error = str(exc) + + # Save error + if previous_invocation is not None: + queue_item.session.set_node_error(previous_invocation.id, node_error) + + # Send error event + self.services.events.emit_invocation_error( + queue_batch_id=queue_item.batch_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session.id, + node=previous_invocation.model_dump() if previous_invocation else {}, + source_node_id=queue_item.session.prepared_source_mapping[previous_invocation.id] + if previous_invocation + else "", + error_type=exc.__class__.__name__, + error=node_error, + user_id=None, + project_id=None, + ) + + if queue_item.session.is_complete() or cancel_event.is_set(): + # Set the invocation to None to prepare for the next session + invocation = None + return invocation + + def run(self, queue_item: SessionQueueItem): + """Run the graph""" + if not queue_item.session: + raise ValueError("Queue item has no session") + invocation = None + # Loop over invocations until the session is complete or canceled + while self.next_invocation(invocation, queue_item, self.cancel_event) and not self.cancel_event.is_set(): + # Prepare the next node + invocation = queue_item.session.next() + if invocation is None: + # If there are no more invocations, complete the graph + break + # Build invocation context (the node-facing API + self.run_node(invocation.id, queue_item) + self.complete(queue_item) + + def complete(self, 
queue_item: SessionQueueItem): + # Send complete event + self.services.events.emit_graph_execution_complete( + queue_batch_id=queue_item.batch_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session.id, + ) + # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor + # we don't care about that - suppress the error. + with suppress(GESStatsNotFoundError): + self.services.performance_statistics.log_stats(queue_item.session.id) + self.services.performance_statistics.reset_stats() + + def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): + """Run before a node is executed""" + # Send starting event + self.services.events.emit_invocation_started( + queue_batch_id=queue_item.batch_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session_id, + node=invocation.model_dump(), + source_node_id=queue_item.session.prepared_source_mapping[invocation.id], + ) + if self.on_before_run_node is not None: + self.on_before_run_node(invocation, queue_item) + + def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): + """Run after a node is executed""" + if self.on_after_run_node is not None: + self.on_after_run_node(invocation, queue_item) + + def run_node(self, node_id: str, queue_item: SessionQueueItem): + """Run a single node in the graph""" + # If this error raises a NodeNotFoundError that's handled by the processor + invocation = queue_item.session.execution_graph.get_node(node_id) + try: + self._on_before_run_node(invocation, queue_item) + data = InvocationContextData( + invocation=invocation, + source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id], + queue_item=queue_item, + ) + + # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph + with 
self.services.performance_statistics.collect_stats(invocation, queue_item.session_id): + context = build_invocation_context( + data=data, + services=self.services, + cancel_event=self.cancel_event, + ) + + # Invoke the node + outputs = invocation.invoke_internal(context=context, services=self.services) + + # Save outputs and history + queue_item.session.complete(invocation.id, outputs) + + self._on_after_run_node(invocation, queue_item) + # Send complete event on successful runs + self.services.events.emit_invocation_complete( + queue_batch_id=queue_item.batch_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session.id, + node=invocation.model_dump(), + source_node_id=data.source_invocation_id, + result=outputs.model_dump(), + ) + except KeyboardInterrupt: + # TODO(MM2): Create an event for this + pass + except CanceledException: + # When the user cancels the graph, we first set the cancel event. The event is checked + # between invocations, in this loop. Some invocations are long-running, and we need to + # be able to cancel them mid-execution. + # + # For example, denoising is a long-running invocation with many steps. A step callback + # is executed after each step. This step callback checks if the canceled event is set, + # then raises a CanceledException to stop execution immediately. + # + # When we get a CanceledException, we don't need to do anything - just pass and let the + # loop go to its next iteration, and the cancel event will be handled correctly. 
+ pass + except Exception as e: + error = traceback.format_exc() + + # Save error + queue_item.session.set_node_error(invocation.id, error) + self.services.logger.error( + f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}" + ) + self.services.logger.error(error) + + # Send error event + self.services.events.emit_invocation_error( + queue_batch_id=queue_item.session_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session.id, + node=invocation.model_dump(), + source_node_id=queue_item.session.prepared_source_mapping[invocation.id], + error_type=e.__class__.__name__, + error=error, + user_id=None, + project_id=None, + ) + + class DefaultSessionProcessor(SessionProcessorBase): - def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None: + def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None: + super().__init__() + self.session_runner = session_runner if session_runner else DefaultSessionRunner() + + def start( + self, + invoker: Invoker, + thread_limit: int = 1, + polling_interval: int = 1, + on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None, + on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None, + ) -> None: self._invoker: Invoker = invoker self._queue_item: Optional[SessionQueueItem] = None self._invocation: Optional[BaseInvocation] = None + self.on_before_run_session = on_before_run_session + self.on_after_run_session = on_after_run_session self._resume_event = ThreadEvent() self._stop_event = ThreadEvent() @@ -49,6 +241,7 @@ def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = else None ) + self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event) self._thread = Thread( name="session_processor", target=self._process, @@ -142,141 +335,25 @@ def _process( 
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") cancel_event.clear() + # If we have a on_before_run_session callback, call it + if self.on_before_run_session is not None: + self.on_before_run_session(self._queue_item) + # If profiling is enabled, start the profiler if self._profiler is not None: self._profiler.start(profile_id=self._queue_item.session_id) - # Prepare invocations and take the first - self._invocation = self._queue_item.session.next() - - # Loop over invocations until the session is complete or canceled - while self._invocation is not None and not cancel_event.is_set(): - # get the source node id to provide to clients (the prepared node id is not as useful) - source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id] - - # Send starting event - self._invoker.services.events.emit_invocation_started( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session_id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, + # Run the graph + self.session_runner.run(queue_item=self._queue_item) + + # If we are profiling, stop the profiler and dump the profile & stats + if self._profiler: + profile_path = self._profiler.stop() + stats_path = profile_path.with_suffix(".json") + self._invoker.services.performance_statistics.dump_stats( + graph_execution_state_id=self._queue_item.session.id, output_path=stats_path ) - # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph - try: - with self._invoker.services.performance_statistics.collect_stats( - self._invocation, self._queue_item.session.id - ): - # Build invocation context (the node-facing API) - data = InvocationContextData( - invocation=self._invocation, - source_invocation_id=source_invocation_id, - queue_item=self._queue_item, - ) - context = 
build_invocation_context( - data=data, - services=self._invoker.services, - cancel_event=self._cancel_event, - ) - - # Invoke the node - outputs = self._invocation.invoke_internal( - context=context, services=self._invoker.services - ) - - # Save outputs and history - self._queue_item.session.complete(self._invocation.id, outputs) - - # Send complete event - self._invoker.services.events.emit_invocation_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - result=outputs.model_dump(), - ) - - except KeyboardInterrupt: - # TODO(MM2): Create an event for this - pass - - except CanceledException: - # When the user cancels the graph, we first set the cancel event. The event is checked - # between invocations, in this loop. Some invocations are long-running, and we need to - # be able to cancel them mid-execution. - # - # For example, denoising is a long-running invocation with many steps. A step callback - # is executed after each step. This step callback checks if the canceled event is set, - # then raises a CanceledException to stop execution immediately. - # - # When we get a CanceledException, we don't need to do anything - just pass and let the - # loop go to its next iteration, and the cancel event will be handled correctly. 
- pass - - except Exception as e: - error = traceback.format_exc() - - # Save error - self._queue_item.session.set_node_error(self._invocation.id, error) - self._invoker.services.logger.error( - f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}" - ) - self._invoker.services.logger.error(error) - - # Send error event - self._invoker.services.events.emit_invocation_error( - queue_batch_id=self._queue_item.session_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - error_type=e.__class__.__name__, - error=error, - user_id=None, - project_id=None, - ) - pass - - # The session is complete if the all invocations are complete or there was an error - if self._queue_item.session.is_complete() or cancel_event.is_set(): - # Send complete event - self._invoker.services.session_queue.set_queue_item_session( - self._queue_item.item_id, self._queue_item.session - ) - self._invoker.services.events.emit_graph_execution_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - ) - # If we are profiling, stop the profiler and dump the profile & stats - if self._profiler: - profile_path = self._profiler.stop() - stats_path = profile_path.with_suffix(".json") - self._invoker.services.performance_statistics.dump_stats( - graph_execution_state_id=self._queue_item.session.id, output_path=stats_path - ) - # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor - # we don't care about that - suppress the error. 
- with suppress(GESStatsNotFoundError): - self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id) - self._invoker.services.performance_statistics.reset_stats() - - # Set the invocation to None to prepare for the next session - self._invocation = None - else: - # Prepare the next invocation - self._invocation = self._queue_item.session.next() - else: - # The queue was empty, wait for next polling interval or event to try again - self._invoker.services.logger.debug("Waiting for next polling interval or event") - poll_now_event.wait(self._polling_interval) - continue except Exception: # Non-fatal error in processor self._invoker.services.logger.error( From 82957bb826a4662bdb8b182970aa8840f7445a6c Mon Sep 17 00:00:00 2001 From: brandonrising Date: Thu, 16 May 2024 13:35:25 -0400 Subject: [PATCH 02/34] Run ruff --- .../app/services/session_processor/session_processor_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py index 745430d201f..7a67c3ab2c0 100644 --- a/invokeai/app/services/session_processor/session_processor_base.py +++ b/invokeai/app/services/session_processor/session_processor_base.py @@ -5,6 +5,7 @@ from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem + class SessionRunnerBase(ABC): """ Base class for session runner. 
From 8edc25d35a10c1426a68273ebbb0665701e45aa5 Mon Sep 17 00:00:00 2001 From: brandonrising Date: Thu, 16 May 2024 13:47:05 -0400 Subject: [PATCH 03/34] Fix next node calling logic --- .../session_processor/session_processor_default.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index eeb8df4cadd..7f5a107b83a 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -79,14 +79,10 @@ def run(self, queue_item: SessionQueueItem): raise ValueError("Queue item has no session") invocation = None # Loop over invocations until the session is complete or canceled - while self.next_invocation(invocation, queue_item, self.cancel_event) and not self.cancel_event.is_set(): - # Prepare the next node - invocation = queue_item.session.next() - if invocation is None: - # If there are no more invocations, complete the graph - break - # Build invocation context (the node-facing API + invocation = self.next_invocation(invocation, queue_item, self.cancel_event) + while invocation is not None and not self.cancel_event.is_set(): self.run_node(invocation.id, queue_item) + invocation = self.next_invocation(invocation, queue_item, self.cancel_event) self.complete(queue_item) def complete(self, queue_item: SessionQueueItem): From efb069dd71476c00a8ce31cd739c4ad461cf6b07 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 18:33:12 +1000 Subject: [PATCH 04/34] feat(app): iterate on processor split - Add `OnNodeError` and `OnNonFatalProcessorError` callbacks - Move all session/node callbacks to `SessionRunner` - this ensures we dump perf stats before resetting them and generally makes sense to me - Remove `complete` event from `SessionRunner`, it's 
essentially the same as `OnAfterRunSession` - Remove extraneous `next_invocation` block, which would treat a processor error as a node error - Simplify loops - Add some callbacks for testing, to be removed before merge --- invokeai/app/api/dependencies.py | 38 ++- .../session_processor_base.py | 8 +- .../session_processor_default.py | 284 +++++++++--------- 3 files changed, 188 insertions(+), 142 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 9a6c7416f69..1ff45be866f 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -29,7 +29,7 @@ from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_records import ModelRecordServiceSQL from ..services.names.names_default import SimpleNameService -from ..services.session_processor.session_processor_default import DefaultSessionProcessor +from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.urls.urls_default import LocalUrlService from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage @@ -103,7 +103,41 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger ) names = SimpleNameService() performance_statistics = InvocationStatsService() - session_processor = DefaultSessionProcessor() + + def on_before_run_session(queue_item): + print("BEFORE RUN SESSION", queue_item.item_id) + return True + + def on_before_run_node(invocation, queue_item): + print("BEFORE RUN NODE", invocation.id) + return True + + def on_after_run_node(invocation, queue_item, outputs): + print("AFTER RUN NODE", invocation.id) + return True + + def on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback): + print("NODE ERROR", invocation.id) + return True + + def on_after_run_session(queue_item): + 
print("AFTER RUN SESSION", queue_item.item_id) + return True + + def on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback): + print("NON FATAL PROCESSOR ERROR", exc_value) + return True + + session_processor = DefaultSessionProcessor( + DefaultSessionRunner( + on_before_run_session, + on_before_run_node, + on_after_run_node, + on_node_error, + on_after_run_session, + ), + on_non_fatal_processor_error, + ) session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() workflow_records = SqliteWorkflowRecordsStorage(db=db) diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py index 7a67c3ab2c0..71408475188 100644 --- a/invokeai/app/services/session_processor/session_processor_base.py +++ b/invokeai/app/services/session_processor/session_processor_base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from threading import Event +from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem @@ -22,12 +23,7 @@ def run(self, queue_item: SessionQueueItem) -> None: pass @abstractmethod - def complete(self, queue_item: SessionQueueItem) -> None: - """Completes the session""" - pass - - @abstractmethod - def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None: + def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None: """Runs an already prepared node on the session""" pass diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 7f5a107b83a..33274dd97b1 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ 
b/invokeai/app/services/session_processor/session_processor_default.py @@ -2,12 +2,13 @@ from contextlib import suppress from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent -from typing import Callable, Optional, Union +from types import TracebackType +from typing import Callable, Optional, TypeAlias from fastapi_events.handlers.local import local_handler from fastapi_events.typing import Event as FastAPIEvent -from invokeai.app.invocations.baseinvocation import BaseInvocation +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError from invokeai.app.services.session_processor.session_processor_common import CanceledException @@ -19,73 +20,71 @@ from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase from .session_processor_common import SessionProcessorStatus +OnBeforeRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem], bool] +OnAfterRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, BaseInvocationOutput], bool] +OnNodeError: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, type, BaseException, TracebackType], bool] +OnBeforeRunSession: TypeAlias = Callable[[SessionQueueItem], bool] +OnAfterRunSession: TypeAlias = Callable[[SessionQueueItem], bool] +OnNonFatalProcessorError: TypeAlias = Callable[[Optional[SessionQueueItem], type, BaseException, TracebackType], bool] + + +def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str: + return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) + class DefaultSessionRunner(SessionRunnerBase): """Processes a single session's invocations""" def __init__( self, - on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None, - 
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None, + on_before_run_session: Optional[OnBeforeRunSession] = None, + on_before_run_node: Optional[OnBeforeRunNode] = None, + on_after_run_node: Optional[OnAfterRunNode] = None, + on_node_error: Optional[OnNodeError] = None, + on_after_run_session: Optional[OnAfterRunSession] = None, ): + self.on_before_run_session = on_before_run_session self.on_before_run_node = on_before_run_node self.on_after_run_node = on_after_run_node + self.on_node_error = on_node_error + self.on_after_run_session = on_after_run_session def start(self, services: InvocationServices, cancel_event: ThreadEvent): """Start the session runner""" self.services = services self.cancel_event = cancel_event - def next_invocation( - self, previous_invocation: Optional[BaseInvocation], queue_item: SessionQueueItem, cancel_event: ThreadEvent - ) -> Optional[BaseInvocation]: - invocation = None - if not (queue_item.session.is_complete() or cancel_event.is_set()): - try: - invocation = queue_item.session.next() - except Exception as exc: - self.services.logger.error("ERROR: %s" % exc, exc_info=True) - - node_error = str(exc) - - # Save error - if previous_invocation is not None: - queue_item.session.set_node_error(previous_invocation.id, node_error) - - # Send error event - self.services.events.emit_invocation_error( - queue_batch_id=queue_item.batch_id, - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - graph_execution_state_id=queue_item.session.id, - node=previous_invocation.model_dump() if previous_invocation else {}, - source_node_id=queue_item.session.prepared_source_mapping[previous_invocation.id] - if previous_invocation - else "", - error_type=exc.__class__.__name__, - error=node_error, - user_id=None, - project_id=None, - ) - - if queue_item.session.is_complete() or cancel_event.is_set(): - # Set the invocation to None to prepare for the next session - invocation = None - return invocation - - 
def run(self, queue_item: SessionQueueItem): + def run(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None): """Run the graph""" - if not queue_item.session: - raise ValueError("Queue item has no session") - invocation = None # Loop over invocations until the session is complete or canceled - invocation = self.next_invocation(invocation, queue_item, self.cancel_event) - while invocation is not None and not self.cancel_event.is_set(): - self.run_node(invocation.id, queue_item) - invocation = self.next_invocation(invocation, queue_item, self.cancel_event) - self.complete(queue_item) - def complete(self, queue_item: SessionQueueItem): + self._on_before_run_session(queue_item=queue_item) + while True: + invocation = queue_item.session.next() + if invocation is None or self.cancel_event.is_set(): + break + self.run_node(invocation, queue_item) + if queue_item.session.is_complete() or self.cancel_event.is_set(): + break + self._on_after_run_session(queue_item=queue_item) + + def _on_before_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None: + # If profiling is enabled, start the profiler + if profiler is not None: + profiler.start(profile_id=queue_item.session_id) + + if self.on_before_run_session: + self.on_before_run_session(queue_item) + + def _on_after_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None: + # If we are profiling, stop the profiler and dump the profile & stats + if profiler: + profile_path = profiler.stop() + stats_path = profile_path.with_suffix(".json") + self.services.performance_statistics.dump_stats( + graph_execution_state_id=queue_item.session.id, output_path=stats_path + ) + # Send complete event self.services.events.emit_graph_execution_complete( queue_batch_id=queue_item.batch_id, @@ -93,12 +92,16 @@ def complete(self, queue_item: SessionQueueItem): queue_id=queue_item.queue_id, graph_execution_state_id=queue_item.session.id, ) + # We'll get 
a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor # we don't care about that - suppress the error. with suppress(GESStatsNotFoundError): self.services.performance_statistics.log_stats(queue_item.session.id) self.services.performance_statistics.reset_stats() + if self.on_after_run_session: + self.on_after_run_session(queue_item) + def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run before a node is executed""" # Send starting event @@ -110,28 +113,73 @@ def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQue node=invocation.model_dump(), source_node_id=queue_item.session.prepared_source_mapping[invocation.id], ) + # And run lifecycle callbacks if self.on_before_run_node is not None: self.on_before_run_node(invocation, queue_item) - def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): + def _on_after_run_node( + self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput + ): """Run after a node is executed""" + # Send complete event on successful runs + self.services.events.emit_invocation_complete( + queue_batch_id=queue_item.batch_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session.id, + node=invocation.model_dump(), + source_node_id=queue_item.session.prepared_source_mapping[invocation.id], + result=outputs.model_dump(), + ) + # And run lifecycle callbacks if self.on_after_run_node is not None: - self.on_after_run_node(invocation, queue_item) + self.on_after_run_node(invocation, queue_item, outputs) + + def _on_node_error( + self, + invocation: BaseInvocation, + queue_item: SessionQueueItem, + exc_type: type, + exc_value: BaseException, + exc_traceback: TracebackType, + ): + stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) + + queue_item.session.set_node_error(invocation.id, stacktrace) + 
self.services.logger.error( + f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_type}" + ) + self.services.logger.error(stacktrace) - def run_node(self, node_id: str, queue_item: SessionQueueItem): + # Send error event + self.services.events.emit_invocation_error( + queue_batch_id=queue_item.session_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session.id, + node=invocation.model_dump(), + source_node_id=queue_item.session.prepared_source_mapping[invocation.id], + error_type=exc_type.__name__, + error=stacktrace, + user_id=None, + project_id=None, + ) + + if self.on_node_error is not None: + self.on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback) + + def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run a single node in the graph""" - # If this error raises a NodeNotFoundError that's handled by the processor - invocation = queue_item.session.execution_graph.get_node(node_id) try: - self._on_before_run_node(invocation, queue_item) - data = InvocationContextData( - invocation=invocation, - source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id], - queue_item=queue_item, - ) - - # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph + # Any unhandled exception is an invocation error & will fail the graph with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id): + self._on_before_run_node(invocation, queue_item) + + data = InvocationContextData( + invocation=invocation, + source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id], + queue_item=queue_item, + ) context = build_invocation_context( data=data, services=self.services, @@ -140,21 +188,11 @@ def run_node(self, node_id: str, queue_item: SessionQueueItem): # Invoke the node outputs = 
invocation.invoke_internal(context=context, services=self.services) - # Save outputs and history queue_item.session.complete(invocation.id, outputs) - self._on_after_run_node(invocation, queue_item) - # Send complete event on successful runs - self.services.events.emit_invocation_complete( - queue_batch_id=queue_item.batch_id, - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - graph_execution_state_id=queue_item.session.id, - node=invocation.model_dump(), - source_node_id=data.source_invocation_id, - result=outputs.model_dump(), - ) + self._on_after_run_node(invocation, queue_item, outputs) + except KeyboardInterrupt: # TODO(MM2): Create an event for this pass @@ -171,48 +209,51 @@ def run_node(self, node_id: str, queue_item: SessionQueueItem): # loop go to its next iteration, and the cancel event will be handled correctly. pass except Exception as e: - error = traceback.format_exc() - - # Save error - queue_item.session.set_node_error(invocation.id, error) - self.services.logger.error( - f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}" - ) - self.services.logger.error(error) - - # Send error event - self.services.events.emit_invocation_error( - queue_batch_id=queue_item.session_id, - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - graph_execution_state_id=queue_item.session.id, - node=invocation.model_dump(), - source_node_id=queue_item.session.prepared_source_mapping[invocation.id], - error_type=e.__class__.__name__, - error=error, - user_id=None, - project_id=None, - ) + exc_type = type(e) + exc_value = e + exc_traceback = e.__traceback__ + assert exc_traceback is not None + self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback) class DefaultSessionProcessor(SessionProcessorBase): - def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None: + def __init__( + self, + session_runner: Optional[SessionRunnerBase] 
= None, + on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None, + ) -> None: super().__init__() + self.session_runner = session_runner if session_runner else DefaultSessionRunner() + self.on_non_fatal_processor_error = on_non_fatal_processor_error + + def _on_non_fatal_processor_error( + self, + queue_item: Optional[SessionQueueItem], + exc_type: type, + exc_value: BaseException, + exc_traceback: TracebackType, + ) -> None: + stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) + # Non-fatal error in processor + self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}") + # Cancel the queue item + if queue_item is not None: + self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) + self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) + + if self.on_non_fatal_processor_error: + self.on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback) def start( self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1, - on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None, - on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None, ) -> None: self._invoker: Invoker = invoker self._queue_item: Optional[SessionQueueItem] = None self._invocation: Optional[BaseInvocation] = None - self.on_before_run_session = on_before_run_session - self.on_after_run_session = on_after_run_session self._resume_event = ThreadEvent() self._stop_event = ThreadEvent() @@ -331,40 +372,15 @@ def _process( self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") cancel_event.clear() - # If we have a on_before_run_session callback, call it - if self.on_before_run_session is not None: - self.on_before_run_session(self._queue_item) - - # If profiling is enabled, start the profiler - if self._profiler is not None: - 
self._profiler.start(profile_id=self._queue_item.session_id) - # Run the graph self.session_runner.run(queue_item=self._queue_item) - # If we are profiling, stop the profiler and dump the profile & stats - if self._profiler: - profile_path = self._profiler.stop() - stats_path = profile_path.with_suffix(".json") - self._invoker.services.performance_statistics.dump_stats( - graph_execution_state_id=self._queue_item.session.id, output_path=stats_path - ) - - except Exception: - # Non-fatal error in processor - self._invoker.services.logger.error( - f"Non-fatal error in session processor:\n{traceback.format_exc()}" - ) - # Cancel the queue item - if self._queue_item is not None: - self._invoker.services.session_queue.set_queue_item_session( - self._queue_item.item_id, self._queue_item.session - ) - self._invoker.services.session_queue.cancel_queue_item( - self._queue_item.item_id, error=traceback.format_exc() - ) - # Reset the invocation to None to prepare for the next session - self._invocation = None + except Exception as e: + exc_type = type(e) + exc_value = e + exc_traceback = e.__traceback__ + assert exc_traceback is not None + self._on_non_fatal_processor_error(self._queue_item, exc_type, exc_value, exc_traceback) # Immediately poll for next queue item poll_now_event.wait(self._polling_interval) continue From f7c356d142808b60d982a53fce26f44eed7ed10e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 18:52:46 +1000 Subject: [PATCH 05/34] feat(app): iterate on processor split 2 - Use protocol to define callbacks, this allows them to have kwargs - Shuffle the profiler around a bit - Move `thread_limit` and `polling_interval` to `__init__`; `start` is called programmatically and will never get these args in practice --- invokeai/app/api/dependencies.py | 14 ++--- .../session_processor_base.py | 46 ++++++++++++++- .../session_processor_default.py | 57 ++++++++++--------- 3 files changed, 80 
insertions(+), 37 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 1ff45be866f..87df06d569f 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -112,7 +112,7 @@ def on_before_run_node(invocation, queue_item): print("BEFORE RUN NODE", invocation.id) return True - def on_after_run_node(invocation, queue_item, outputs): + def on_after_run_node(invocation, queue_item, output): print("AFTER RUN NODE", invocation.id) return True @@ -124,17 +124,17 @@ def on_after_run_session(queue_item): print("AFTER RUN SESSION", queue_item.item_id) return True - def on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback): + def on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item=None): print("NON FATAL PROCESSOR ERROR", exc_value) return True session_processor = DefaultSessionProcessor( DefaultSessionRunner( - on_before_run_session, - on_before_run_node, - on_after_run_node, - on_node_error, - on_after_run_session, + on_before_run_session=on_before_run_session, + on_before_run_node=on_before_run_node, + on_after_run_node=on_after_run_node, + on_node_error=on_node_error, + on_after_run_session=on_after_run_session, ), on_non_fatal_processor_error, ) diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py index 71408475188..bfae74e5fea 100644 --- a/invokeai/app/services/session_processor/session_processor_base.py +++ b/invokeai/app/services/session_processor/session_processor_base.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod from threading import Event +from types import TracebackType +from typing import Optional, Protocol -from invokeai.app.invocations.baseinvocation import BaseInvocation +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.invocation_services import InvocationServices from 
invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem +from invokeai.app.util.profiler import Profiler class SessionRunnerBase(ABC): @@ -13,7 +16,7 @@ class SessionRunnerBase(ABC): """ @abstractmethod - def start(self, services: InvocationServices, cancel_event: Event) -> None: + def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None: """Starts the session runner""" pass @@ -51,3 +54,42 @@ def pause(self) -> SessionProcessorStatus: def get_status(self) -> SessionProcessorStatus: """Gets the status of the session processor""" pass + + +class OnBeforeRunNode(Protocol): + def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool: ... + + +class OnAfterRunNode(Protocol): + def __call__( + self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput + ) -> bool: ... + + +class OnNodeError(Protocol): + def __call__( + self, + invocation: BaseInvocation, + queue_item: SessionQueueItem, + exc_type: type, + exc_value: BaseException, + exc_traceback: TracebackType, + ) -> bool: ... + + +class OnBeforeRunSession(Protocol): + def __call__(self, queue_item: SessionQueueItem) -> bool: ... + + +class OnAfterRunSession(Protocol): + def __call__(self, queue_item: SessionQueueItem) -> bool: ... + + +class OnNonFatalProcessorError(Protocol): + def __call__( + self, + exc_type: type, + exc_value: BaseException, + exc_traceback: TracebackType, + queue_item: Optional[SessionQueueItem] = None, + ) -> bool: ... 
diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 33274dd97b1..4172e45d179 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -3,7 +3,7 @@ from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent from types import TracebackType -from typing import Callable, Optional, TypeAlias +from typing import Optional from fastapi_events.handlers.local import local_handler from fastapi_events.typing import Event as FastAPIEvent @@ -11,6 +11,14 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError +from invokeai.app.services.session_processor.session_processor_base import ( + OnAfterRunNode, + OnAfterRunSession, + OnBeforeRunNode, + OnBeforeRunSession, + OnNodeError, + OnNonFatalProcessorError, +) from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context @@ -20,13 +28,6 @@ from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase from .session_processor_common import SessionProcessorStatus -OnBeforeRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem], bool] -OnAfterRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, BaseInvocationOutput], bool] -OnNodeError: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, type, BaseException, TracebackType], bool] -OnBeforeRunSession: TypeAlias = Callable[[SessionQueueItem], bool] -OnAfterRunSession: 
TypeAlias = Callable[[SessionQueueItem], bool] -OnNonFatalProcessorError: TypeAlias = Callable[[Optional[SessionQueueItem], type, BaseException, TracebackType], bool] - def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str: return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) @@ -49,16 +50,18 @@ def __init__( self.on_node_error = on_node_error self.on_after_run_session = on_after_run_session - def start(self, services: InvocationServices, cancel_event: ThreadEvent): + def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None): """Start the session runner""" self.services = services self.cancel_event = cancel_event + self.profiler = profiler - def run(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None): + def run(self, queue_item: SessionQueueItem): """Run the graph""" # Loop over invocations until the session is complete or canceled self._on_before_run_session(queue_item=queue_item) + while True: invocation = queue_item.session.next() if invocation is None or self.cancel_event.is_set(): @@ -66,20 +69,21 @@ def run(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) self.run_node(invocation, queue_item) if queue_item.session.is_complete() or self.cancel_event.is_set(): break + self._on_after_run_session(queue_item=queue_item) - def _on_before_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None: + def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: # If profiling is enabled, start the profiler - if profiler is not None: - profiler.start(profile_id=queue_item.session_id) + if self.profiler is not None: + self.profiler.start(profile_id=queue_item.session_id) if self.on_before_run_session: - self.on_before_run_session(queue_item) + self.on_before_run_session(queue_item=queue_item) - def _on_after_run_session(self, queue_item: SessionQueueItem, 
profiler: Optional[Profiler] = None) -> None: + def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: # If we are profiling, stop the profiler and dump the profile & stats - if profiler: - profile_path = profiler.stop() + if self.profiler is not None: + profile_path = self.profiler.stop() stats_path = profile_path.with_suffix(".json") self.services.performance_statistics.dump_stats( graph_execution_state_id=queue_item.session.id, output_path=stats_path @@ -221,11 +225,15 @@ def __init__( self, session_runner: Optional[SessionRunnerBase] = None, on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None, + thread_limit: int = 1, + polling_interval: int = 1, ) -> None: super().__init__() self.session_runner = session_runner if session_runner else DefaultSessionRunner() self.on_non_fatal_processor_error = on_non_fatal_processor_error + self._thread_limit = thread_limit + self._polling_interval = polling_interval def _on_non_fatal_processor_error( self, @@ -243,14 +251,9 @@ def _on_non_fatal_processor_error( self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) if self.on_non_fatal_processor_error: - self.on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback) + self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item) - def start( - self, - invoker: Invoker, - thread_limit: int = 1, - polling_interval: int = 1, - ) -> None: + def start(self, invoker: Invoker) -> None: self._invoker: Invoker = invoker self._queue_item: Optional[SessionQueueItem] = None self._invocation: Optional[BaseInvocation] = None @@ -262,9 +265,7 @@ def start( local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) - self._thread_limit = thread_limit - self._thread_semaphore = BoundedSemaphore(thread_limit) - self._polling_interval = polling_interval + self._thread_semaphore = BoundedSemaphore(self._thread_limit) # If profiling is enabled, create a 
profiler. The same profiler will be used for all sessions. Internally, # the profiler will create a new profile for each session. @@ -278,7 +279,7 @@ def start( else None ) - self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event) + self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler) self._thread = Thread( name="session_processor", target=self._process, From cb8e9e1c7b9bff56d677c879a17116f0fc2e2a1d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 18:55:33 +1000 Subject: [PATCH 06/34] feat(app): make things in session runner private --- .../session_processor_default.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 4172e45d179..6b4b84e0998 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -52,9 +52,9 @@ def __init__( def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None): """Start the session runner""" - self.services = services - self.cancel_event = cancel_event - self.profiler = profiler + self._services = services + self._cancel_event = cancel_event + self._profiler = profiler def run(self, queue_item: SessionQueueItem): """Run the graph""" @@ -64,33 +64,33 @@ def run(self, queue_item: SessionQueueItem): while True: invocation = queue_item.session.next() - if invocation is None or self.cancel_event.is_set(): + if invocation is None or self._cancel_event.is_set(): break self.run_node(invocation, queue_item) - if queue_item.session.is_complete() or self.cancel_event.is_set(): + if queue_item.session.is_complete() or self._cancel_event.is_set(): break 
self._on_after_run_session(queue_item=queue_item) def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: # If profiling is enabled, start the profiler - if self.profiler is not None: - self.profiler.start(profile_id=queue_item.session_id) + if self._profiler is not None: + self._profiler.start(profile_id=queue_item.session_id) if self.on_before_run_session: self.on_before_run_session(queue_item=queue_item) def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: # If we are profiling, stop the profiler and dump the profile & stats - if self.profiler is not None: - profile_path = self.profiler.stop() + if self._profiler is not None: + profile_path = self._profiler.stop() stats_path = profile_path.with_suffix(".json") - self.services.performance_statistics.dump_stats( + self._services.performance_statistics.dump_stats( graph_execution_state_id=queue_item.session.id, output_path=stats_path ) # Send complete event - self.services.events.emit_graph_execution_complete( + self._services.events.emit_graph_execution_complete( queue_batch_id=queue_item.batch_id, queue_item_id=queue_item.item_id, queue_id=queue_item.queue_id, @@ -100,8 +100,8 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor # we don't care about that - suppress the error. 
with suppress(GESStatsNotFoundError): - self.services.performance_statistics.log_stats(queue_item.session.id) - self.services.performance_statistics.reset_stats() + self._services.performance_statistics.log_stats(queue_item.session.id) + self._services.performance_statistics.reset_stats() if self.on_after_run_session: self.on_after_run_session(queue_item) @@ -109,7 +109,7 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run before a node is executed""" # Send starting event - self.services.events.emit_invocation_started( + self._services.events.emit_invocation_started( queue_batch_id=queue_item.batch_id, queue_item_id=queue_item.item_id, queue_id=queue_item.queue_id, @@ -126,7 +126,7 @@ def _on_after_run_node( ): """Run after a node is executed""" # Send complete event on successful runs - self.services.events.emit_invocation_complete( + self._services.events.emit_invocation_complete( queue_batch_id=queue_item.batch_id, queue_item_id=queue_item.item_id, queue_id=queue_item.queue_id, @@ -150,13 +150,13 @@ def _on_node_error( stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) queue_item.session.set_node_error(invocation.id, stacktrace) - self.services.logger.error( + self._services.logger.error( f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_type}" ) - self.services.logger.error(stacktrace) + self._services.logger.error(stacktrace) # Send error event - self.services.events.emit_invocation_error( + self._services.events.emit_invocation_error( queue_batch_id=queue_item.session_id, queue_item_id=queue_item.item_id, queue_id=queue_item.queue_id, @@ -176,7 +176,7 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run a single node in the graph""" try: # Any unhandled exception is an invocation error & will fail the graph - with 
self.services.performance_statistics.collect_stats(invocation, queue_item.session_id): + with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id): self._on_before_run_node(invocation, queue_item) data = InvocationContextData( @@ -186,12 +186,12 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): ) context = build_invocation_context( data=data, - services=self.services, - cancel_event=self.cancel_event, + services=self._services, + cancel_event=self._cancel_event, ) # Invoke the node - outputs = invocation.invoke_internal(context=context, services=self.services) + outputs = invocation.invoke_internal(context=context, services=self._services) # Save outputs and history queue_item.session.complete(invocation.id, outputs) From cef1585dfbbd12cd94c0f3c1a0a6efd019911362 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 19:05:49 +1000 Subject: [PATCH 07/34] feat(app): support multiple processor lifecycle callbacks --- invokeai/app/api/dependencies.py | 12 ++--- .../session_processor_default.py | 54 ++++++++++--------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 87df06d569f..d9cefb0acff 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -130,13 +130,13 @@ def on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item= session_processor = DefaultSessionProcessor( DefaultSessionRunner( - on_before_run_session=on_before_run_session, - on_before_run_node=on_before_run_node, - on_after_run_node=on_after_run_node, - on_node_error=on_node_error, - on_after_run_session=on_after_run_session, + on_before_run_session_callbacks=[on_before_run_session], + on_before_run_node_callbacks=[on_before_run_node], + on_after_run_node_callbacks=[on_after_run_node], + on_node_error_callbacks=[on_node_error], + 
on_after_run_session_callbacks=[on_after_run_session], ), - on_non_fatal_processor_error, + on_non_fatal_processor_error_callbacks=[on_non_fatal_processor_error], ) session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 6b4b84e0998..74e7dd3debc 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -38,17 +38,17 @@ class DefaultSessionRunner(SessionRunnerBase): def __init__( self, - on_before_run_session: Optional[OnBeforeRunSession] = None, - on_before_run_node: Optional[OnBeforeRunNode] = None, - on_after_run_node: Optional[OnAfterRunNode] = None, - on_node_error: Optional[OnNodeError] = None, - on_after_run_session: Optional[OnAfterRunSession] = None, + on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None, + on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None, + on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None, + on_node_error_callbacks: Optional[list[OnNodeError]] = None, + on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None, ): - self.on_before_run_session = on_before_run_session - self.on_before_run_node = on_before_run_node - self.on_after_run_node = on_after_run_node - self.on_node_error = on_node_error - self.on_after_run_session = on_after_run_session + self._on_before_run_session_callbacks = on_before_run_session_callbacks or [] + self._on_before_run_node_callbacks = on_before_run_node_callbacks or [] + self._on_after_run_node_callbacks = on_after_run_node_callbacks or [] + self._on_node_error_callbacks = on_node_error_callbacks or [] + self._on_after_run_session_callbacks = on_after_run_session_callbacks or [] def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: 
Optional[Profiler] = None): """Start the session runner""" @@ -77,8 +77,8 @@ def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: if self._profiler is not None: self._profiler.start(profile_id=queue_item.session_id) - if self.on_before_run_session: - self.on_before_run_session(queue_item=queue_item) + for callback in self._on_before_run_session_callbacks: + callback(queue_item=queue_item) def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: # If we are profiling, stop the profiler and dump the profile & stats @@ -103,8 +103,8 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: self._services.performance_statistics.log_stats(queue_item.session.id) self._services.performance_statistics.reset_stats() - if self.on_after_run_session: - self.on_after_run_session(queue_item) + for callback in self._on_after_run_session_callbacks: + callback(queue_item) def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run before a node is executed""" @@ -117,9 +117,9 @@ def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQue node=invocation.model_dump(), source_node_id=queue_item.session.prepared_source_mapping[invocation.id], ) - # And run lifecycle callbacks - if self.on_before_run_node is not None: - self.on_before_run_node(invocation, queue_item) + + for callback in self._on_before_run_node_callbacks: + callback(invocation, queue_item) def _on_after_run_node( self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput @@ -135,9 +135,9 @@ def _on_after_run_node( source_node_id=queue_item.session.prepared_source_mapping[invocation.id], result=outputs.model_dump(), ) - # And run lifecycle callbacks - if self.on_after_run_node is not None: - self.on_after_run_node(invocation, queue_item, outputs) + + for callback in self._on_after_run_node_callbacks: + callback(invocation, queue_item, outputs) def _on_node_error( self, @@ -169,8 
+169,8 @@ def _on_node_error( project_id=None, ) - if self.on_node_error is not None: - self.on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback) + for callback in self._on_node_error_callbacks: + callback(invocation, queue_item, exc_type, exc_value, exc_traceback) def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run a single node in the graph""" @@ -213,6 +213,7 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): # loop go to its next iteration, and the cancel event will be handled correctly. pass except Exception as e: + # Must extract the exception traceback here to not lose its stacktrace when we change scope exc_type = type(e) exc_value = e exc_traceback = e.__traceback__ @@ -224,14 +225,14 @@ class DefaultSessionProcessor(SessionProcessorBase): def __init__( self, session_runner: Optional[SessionRunnerBase] = None, - on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None, + on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None, thread_limit: int = 1, polling_interval: int = 1, ) -> None: super().__init__() self.session_runner = session_runner if session_runner else DefaultSessionRunner() - self.on_non_fatal_processor_error = on_non_fatal_processor_error + self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or [] self._thread_limit = thread_limit self._polling_interval = polling_interval @@ -250,8 +251,8 @@ def _on_non_fatal_processor_error( self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) - if self.on_non_fatal_processor_error: - self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item) + for callback in self._on_non_fatal_processor_error_callbacks: + callback(exc_type, exc_value, exc_traceback, queue_item) def start(self, 
invoker: Invoker) -> None: self._invoker: Invoker = invoker @@ -377,6 +378,7 @@ def _process( self.session_runner.run(queue_item=self._queue_item) except Exception as e: + # Must extract the exception traceback here to not lose its stacktrace when we change scope exc_type = type(e) exc_value = e exc_traceback = e.__traceback__ From eff359625a16821d3499aab951f178c74f34252f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 19:07:57 +1000 Subject: [PATCH 08/34] tidy(app): rearrange proccessor --- .../session_processor_default.py | 134 +++++++++--------- 1 file changed, 68 insertions(+), 66 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 74e7dd3debc..6e60d208530 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -30,6 +30,8 @@ def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str: + """Formats a stacktrace as a string""" + return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) @@ -72,6 +74,54 @@ def run(self, queue_item: SessionQueueItem): self._on_after_run_session(queue_item=queue_item) + def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): + """Run a single node in the graph""" + try: + # Any unhandled exception is an invocation error & will fail the graph + with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id): + self._on_before_run_node(invocation, queue_item) + + data = InvocationContextData( + invocation=invocation, + source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id], + queue_item=queue_item, + ) + context = build_invocation_context( + data=data, + services=self._services, + cancel_event=self._cancel_event, 
+ ) + + # Invoke the node + outputs = invocation.invoke_internal(context=context, services=self._services) + # Save outputs and history + queue_item.session.complete(invocation.id, outputs) + + self._on_after_run_node(invocation, queue_item, outputs) + + except KeyboardInterrupt: + # TODO(MM2): Create an event for this + pass + except CanceledException: + # When the user cancels the graph, we first set the cancel event. The event is checked + # between invocations, in this loop. Some invocations are long-running, and we need to + # be able to cancel them mid-execution. + # + # For example, denoising is a long-running invocation with many steps. A step callback + # is executed after each step. This step callback checks if the canceled event is set, + # then raises a CanceledException to stop execution immediately. + # + # When we get a CanceledException, we don't need to do anything - just pass and let the + # loop go to its next iteration, and the cancel event will be handled correctly. 
+ pass + except Exception as e: + # Must extract the exception traceback here to not lose its stacktrace when we change scope + exc_type = type(e) + exc_value = e + exc_traceback = e.__traceback__ + assert exc_traceback is not None + self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback) + def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: # If profiling is enabled, start the profiler if self._profiler is not None: @@ -172,54 +222,6 @@ def _on_node_error( for callback in self._on_node_error_callbacks: callback(invocation, queue_item, exc_type, exc_value, exc_traceback) - def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): - """Run a single node in the graph""" - try: - # Any unhandled exception is an invocation error & will fail the graph - with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id): - self._on_before_run_node(invocation, queue_item) - - data = InvocationContextData( - invocation=invocation, - source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id], - queue_item=queue_item, - ) - context = build_invocation_context( - data=data, - services=self._services, - cancel_event=self._cancel_event, - ) - - # Invoke the node - outputs = invocation.invoke_internal(context=context, services=self._services) - # Save outputs and history - queue_item.session.complete(invocation.id, outputs) - - self._on_after_run_node(invocation, queue_item, outputs) - - except KeyboardInterrupt: - # TODO(MM2): Create an event for this - pass - except CanceledException: - # When the user cancels the graph, we first set the cancel event. The event is checked - # between invocations, in this loop. Some invocations are long-running, and we need to - # be able to cancel them mid-execution. - # - # For example, denoising is a long-running invocation with many steps. A step callback - # is executed after each step. 
This step callback checks if the canceled event is set, - # then raises a CanceledException to stop execution immediately. - # - # When we get a CanceledException, we don't need to do anything - just pass and let the - # loop go to its next iteration, and the cancel event will be handled correctly. - pass - except Exception as e: - # Must extract the exception traceback here to not lose its stacktrace when we change scope - exc_type = type(e) - exc_value = e - exc_traceback = e.__traceback__ - assert exc_traceback is not None - self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback) - class DefaultSessionProcessor(SessionProcessorBase): def __init__( @@ -236,24 +238,6 @@ def __init__( self._thread_limit = thread_limit self._polling_interval = polling_interval - def _on_non_fatal_processor_error( - self, - queue_item: Optional[SessionQueueItem], - exc_type: type, - exc_value: BaseException, - exc_traceback: TracebackType, - ) -> None: - stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) - # Non-fatal error in processor - self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}") - # Cancel the queue item - if queue_item is not None: - self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) - self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) - - for callback in self._on_non_fatal_processor_error_callbacks: - callback(exc_type, exc_value, exc_traceback, queue_item) - def start(self, invoker: Invoker) -> None: self._invoker: Invoker = invoker self._queue_item: Optional[SessionQueueItem] = None @@ -396,3 +380,21 @@ def _process( poll_now_event.clear() self._queue_item = None self._thread_semaphore.release() + + def _on_non_fatal_processor_error( + self, + queue_item: Optional[SessionQueueItem], + exc_type: type, + exc_value: BaseException, + exc_traceback: TracebackType, + ) -> None: + stacktrace = 
get_stacktrace(exc_type, exc_value, exc_traceback) + # Non-fatal error in processor + self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}") + # Cancel the queue item + if queue_item is not None: + self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) + self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) + + for callback in self._on_non_fatal_processor_error_callbacks: + callback(exc_type, exc_value, exc_traceback, queue_item) From b1f819ae8d3a09a3389e81fc4b97a038f3adf7c8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 19:08:43 +1000 Subject: [PATCH 09/34] tidy(app): "outputs" -> "output" --- .../session_processor/session_processor_default.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 6e60d208530..852b330f033 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -93,11 +93,11 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): ) # Invoke the node - outputs = invocation.invoke_internal(context=context, services=self._services) - # Save outputs and history - queue_item.session.complete(invocation.id, outputs) + output = invocation.invoke_internal(context=context, services=self._services) + # Save output and history + queue_item.session.complete(invocation.id, output) - self._on_after_run_node(invocation, queue_item, outputs) + self._on_after_run_node(invocation, queue_item, output) except KeyboardInterrupt: # TODO(MM2): Create an event for this @@ -172,7 +172,7 @@ def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQue 
callback(invocation, queue_item) def _on_after_run_node( - self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput + self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput ): """Run after a node is executed""" # Send complete event on successful runs @@ -183,11 +183,11 @@ def _on_after_run_node( graph_execution_state_id=queue_item.session.id, node=invocation.model_dump(), source_node_id=queue_item.session.prepared_source_mapping[invocation.id], - result=outputs.model_dump(), + result=output.model_dump(), ) for callback in self._on_after_run_node_callbacks: - callback(invocation, queue_item, outputs) + callback(invocation, queue_item, output) def _on_node_error( self, From d30c1ad6dc55be22b1c6eca3f48e102fa914dda7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 19:37:55 +1000 Subject: [PATCH 10/34] docs(app): explain why errors are handled poorly --- .../session_processor_default.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 852b330f033..636304ef861 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -60,11 +60,23 @@ def start(self, services: InvocationServices, cancel_event: ThreadEvent, profile def run(self, queue_item: SessionQueueItem): """Run the graph""" - # Loop over invocations until the session is complete or canceled + # Exceptions raised outside `run_node` are handled by the processor. self._on_before_run_session(queue_item=queue_item) + # Loop over invocations until the session is complete or canceled while True: + # TODO(psyche): Sessions only support errors on nodes, not on the session itself. 
When an error occurs outside + # node execution, it bubbles up to the processor where it is treated as a queue item error. + # + # Nodes are pydantic models. When we prepare a node in `session.next()`, we set its inputs. This can cause a + # pydantic validation error. For example, consider a resize image node which has a constraint on its `width` + # input field - it must be greater than zero. During preparation, if the width is set to zero, pydantic will + # raise a validation error. + # + # When this happens, it breaks the flow before `invocation` is set. We can't set an error on the invocation + # because we didn't get far enough to get it - we don't know its id. Hence, we just set it as a queue item error. + invocation = queue_item.session.next() if invocation is None or self._cancel_event.is_set(): break @@ -77,7 +89,7 @@ def run(self, queue_item: SessionQueueItem): def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run a single node in the graph""" try: - # Any unhandled exception is an invocation error & will fail the graph + # Any unhandled exception in this scope is an invocation error & will fail the graph with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id): self._on_before_run_node(invocation, queue_item) From df5457231fa6b60f6ff38c31a39c0ed085dade6d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 20:45:05 +1000 Subject: [PATCH 11/34] feat(app): handle preparation errors as node errors We were not handling node preparation errors as node errors before. Here's the explanation, copied from a comment that is no longer required: --- TODO(psyche): Sessions only support errors on nodes, not on the session itself. When an error occurs outside node execution, it bubbles up to the processor where it is treated as a queue item error. Nodes are pydantic models. When we prepare a node in `session.next()`, we set its inputs. 
This can cause a pydantic validation error. For example, consider a resize image node which has a constraint on its `width` input field - it must be greater than zero. During preparation, if the width is set to zero, pydantic will raise a validation error. When this happens, it breaks the flow before `invocation` is set. We can't set an error on the invocation because we didn't get far enough to get it - we don't know its id. Hence, we just set it as a queue item error. --- This change wraps the node preparation step with exception handling. A new `NodeInputError` exception is raised when there is a validation error. This error has the node (in the state it was in just prior to the error) and an identifier of the input that failed. This allows us to mark the node that failed preparation as errored, correctly making such errors _node_ errors and not _processor_ errors. It's much easier to diagnose these situations. The error messages look like this: > Node b5ac87c6-0678-4b8c-96b9-d215aee12175 has invalid incoming input for height Some of the exception handling logic is cleaned up. 
--- .../session_processor_default.py | 21 +++++----- invokeai/app/services/shared/graph.py | 39 ++++++++++++++++++- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 636304ef861..734cec1d0d7 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -21,6 +21,7 @@ ) from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem +from invokeai.app.services.shared.graph import NodeInputError from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import Profiler @@ -66,18 +67,16 @@ def run(self, queue_item: SessionQueueItem): # Loop over invocations until the session is complete or canceled while True: - # TODO(psyche): Sessions only support errors on nodes, not on the session itself. When an error occurs outside - # node execution, it bubbles up to the processor where it is treated as a queue item error. - # - # Nodes are pydantic models. When we prepare a node in `session.next()`, we set its inputs. This can cause a - # pydantic validation error. For example, consider a resize image node which has a constraint on its `width` - # input field - it must be greater than zero. During preparation, if the width is set to zero, pydantic will - # raise a validation error. - # - # When this happens, it breaks the flow before `invocation` is set. We can't set an error on the invocation - # because we didn't get far enough to get it - we don't know its id. Hence, we just set it as a queue item error. 
+ try: + invocation = queue_item.session.next() + # Anything other than a `NodeInputError` is handled as a processor error + except NodeInputError as e: + # Must extract the exception traceback here to not lose its stacktrace when we change scope + traceback = e.__traceback__ + assert traceback is not None + self._on_node_error(e.node, queue_item, type(e), e, traceback) + break - invocation = queue_item.session.next() if invocation is None or self._cancel_event.is_set(): break self.run_node(invocation, queue_item) diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index cc2ea5cedb3..8508d2484c8 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -8,6 +8,7 @@ from pydantic import ( BaseModel, GetJsonSchemaHandler, + ValidationError, field_validator, ) from pydantic.fields import Field @@ -190,6 +191,39 @@ class UnknownGraphValidationError(ValueError): pass +class NodeInputError(ValueError): + """Raised when a node fails preparation. This occurs when a node's inputs are being set from its incomers, but an + input fails validation. + + Attributes: + node: The node that failed preparation. Note: only successfully set fields will be accurate. Review the error to + determine which field caused the failure. + """ + + def __init__(self, node: BaseInvocation, e: ValidationError): + self.original_error = e + self.node = node + # When preparing a node, we set each input one-at-a-time. We may thus safely assume that the first error + # represents the first input that failed. + self.failed_input = loc_to_dot_sep(e.errors()[0]["loc"]) + super().__init__(f"Node {node.id} has invalid incoming input for {self.failed_input}") + + +def loc_to_dot_sep(loc: tuple[Union[str, int], ...]) -> str: + """Helper to pretty-print pydantic error locations as dot-separated strings. 
+ Taken from https://docs.pydantic.dev/latest/errors/errors/#customize-error-messages + """ + path = "" + for i, x in enumerate(loc): + if isinstance(x, str): + if i > 0: + path += "." + path += x + else: + path += f"[{x}]" + return path + + @invocation_output("iterate_output") class IterateInvocationOutput(BaseInvocationOutput): """Used to connect iteration outputs. Will be expanded to a specific output.""" @@ -821,7 +855,10 @@ def next(self) -> Optional[BaseInvocation]: # Get values from edges if next_node is not None: - self._prepare_inputs(next_node) + try: + self._prepare_inputs(next_node) + except ValidationError as e: + raise NodeInputError(next_node, e) # If next is still none, there's no next node, return None return next_node From 80905ff3ea09b60e2439e23bcbf8a00b876af0a1 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 20:45:38 +1000 Subject: [PATCH 12/34] fix(app): fix logging of error classes instead of class names --- .../session_processor/session_processor_default.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 734cec1d0d7..c4cffd998e5 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -127,11 +127,9 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): pass except Exception as e: # Must extract the exception traceback here to not lose its stacktrace when we change scope - exc_type = type(e) - exc_value = e exc_traceback = e.__traceback__ assert exc_traceback is not None - self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback) + self._on_node_error(invocation, queue_item, type(e), e, exc_traceback) def _on_before_run_session(self, queue_item: 
SessionQueueItem) -> None: # If profiling is enabled, start the profiler @@ -212,7 +210,7 @@ def _on_node_error( queue_item.session.set_node_error(invocation.id, stacktrace) self._services.logger.error( - f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_type}" + f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {exc_type.__name__}" ) self._services.logger.error(stacktrace) @@ -374,11 +372,9 @@ def _process( except Exception as e: # Must extract the exception traceback here to not lose its stacktrace when we change scope - exc_type = type(e) - exc_value = e exc_traceback = e.__traceback__ assert exc_traceback is not None - self._on_non_fatal_processor_error(self._queue_item, exc_type, exc_value, exc_traceback) + self._on_non_fatal_processor_error(self._queue_item, type(e), e, exc_traceback) # Immediately poll for next queue item poll_now_event.wait(self._polling_interval) continue @@ -401,7 +397,8 @@ def _on_non_fatal_processor_error( ) -> None: stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) # Non-fatal error in processor - self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}") + self._invoker.services.logger.error(f"Non-fatal error in session processor: {exc_type.__name__}") + self._invoker.services.logger.error(stacktrace) # Cancel the queue item if queue_item is not None: self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) From 23b05344a3442a564b063343e893567cc03b852e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 23 May 2024 15:57:30 +1000 Subject: [PATCH 13/34] feat(processor): get user/project from queue item w/ fallback --- .../services/session_processor/session_processor_default.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index c4cffd998e5..0c04b4190ba 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -224,8 +224,8 @@ def _on_node_error( source_node_id=queue_item.session.prepared_source_mapping[invocation.id], error_type=exc_type.__name__, error=stacktrace, - user_id=None, - project_id=None, + user_id=getattr(queue_item, 'user_id', None), + project_id=getattr(queue_item, 'project_id', None), ) for callback in self._on_node_error_callbacks: From a55b2f09e2b8a349aaa53be99c26246003af498e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 08:35:19 +1000 Subject: [PATCH 14/34] chore: ruff --- .../services/session_processor/session_processor_default.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 0c04b4190ba..1d404baa515 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -224,8 +224,8 @@ def _on_node_error( source_node_id=queue_item.session.prepared_source_mapping[invocation.id], error_type=exc_type.__name__, error=stacktrace, - user_id=getattr(queue_item, 'user_id', None), - project_id=getattr(queue_item, 'project_id', None), + user_id=getattr(queue_item, "user_id", None), + project_id=getattr(queue_item, "project_id", None), ) for callback in self._on_node_error_callbacks: From 7652fbc2e90e9f630b267d85fa2859e0e10f8c28 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 09:26:33 +1000 Subject: [PATCH 15/34] fix(processor): 
restore missing update of session --- .../services/session_processor/session_processor_default.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 1d404baa515..cddb7cdc037 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -148,6 +148,9 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: graph_execution_state_id=queue_item.session.id, output_path=stats_path ) + # Update the queue item with the completed session + self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) + # Send complete event self._services.events.emit_graph_execution_complete( queue_batch_id=queue_item.batch_id, @@ -399,9 +402,10 @@ def _on_non_fatal_processor_error( # Non-fatal error in processor self._invoker.services.logger.error(f"Non-fatal error in session processor: {exc_type.__name__}") self._invoker.services.logger.error(stacktrace) - # Cancel the queue item if queue_item is not None: + # Update the queue item with the completed session self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) + # And cancel the queue item with an error self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) for callback in self._on_non_fatal_processor_error_callbacks: From 0e81e7b4605adb0a8c847c982b7975e989afab5e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 23 May 2024 15:10:05 +1000 Subject: [PATCH 16/34] feat(db): add `error_type`, `error_message`, rename `error` -> `error_traceback` to `session_queue` table --- .../app/services/shared/sqlite/sqlite_util.py | 2 ++ .../migrations/migration_10.py | 35 +++++++++++++++++++ 2 files changed, 
37 insertions(+) create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 1eed0b44092..cadf09f4575 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -12,6 +12,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -41,6 +42,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_7()) migrator.register_migration(build_migration_8(app_config=config)) migrator.register_migration(build_migration_9()) + migrator.register_migration(build_migration_10()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py new file mode 100644 index 00000000000..ce2cd2e965e --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py @@ -0,0 +1,35 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration10Callback: + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._update_error_cols(cursor) + + def _update_error_cols(self, cursor: sqlite3.Cursor) -> None: + """ + - Adds `error_type` and `error_message` columns to the session queue table. + - Renames the `error` column to `error_traceback`. 
+ """ + + cursor.execute("ALTER TABLE session_queue ADD COLUMN error_type TEXT;") + cursor.execute("ALTER TABLE session_queue ADD COLUMN error_message TEXT;") + cursor.execute("ALTER TABLE session_queue RENAME COLUMN error TO error_traceback;") + + +def build_migration_10() -> Migration: + """ + Build the migration from database version 9 to 10. + + This migration does the following: + - Adds `error_type` and `error_message` columns to the session queue table. + - Renames the `error` column to `error_traceback`. + """ + migration_10 = Migration( + from_version=9, + to_version=10, + callback=Migration10Callback(), + ) + + return migration_10 From d6696a7b9793a5c9dd155b9baadfc8fa334d0b5f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 23 May 2024 15:15:53 +1000 Subject: [PATCH 17/34] feat(queue): session queue error handling - Add handling for new error columns `error_type`, `error_message`, `error_traceback`. - Update queue item model to include the new data. The `error_traceback` field has an alias of `error` for backwards compatibility. - Add `fail_queue_item` method. This was previously handled by `cancel_queue_item`. Splitting this functionality makes failing a queue item a bit more explicit. We also don't need to handle multiple optional error args. 
- --- .../session_queue/session_queue_base.py | 7 ++- .../session_queue/session_queue_common.py | 19 ++++++- .../session_queue/session_queue_sqlite.py | 56 ++++++++++++++++--- 3 files changed, 70 insertions(+), 12 deletions(-) diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index f46463f528f..8b21998f193 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -74,10 +74,15 @@ def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus: pass @abstractmethod - def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: + def cancel_queue_item(self, item_id: int) -> SessionQueueItem: """Cancels a session queue item""" pass + @abstractmethod + def fail_queue_item(self, item_id: int, error_type: str, error_message: str, error_traceback: str) -> SessionQueueItem: + """Fails a session queue item""" + pass + @abstractmethod def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: """Cancels all queue items with matching batch IDs""" diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 94db6999c2b..7f4601eba73 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -3,7 +3,16 @@ from itertools import chain, product from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast -from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator +from pydantic import ( + AliasChoices, + BaseModel, + ConfigDict, + Field, + StrictStr, + TypeAdapter, + field_validator, + model_validator, +) from pydantic_core import to_jsonable_python from invokeai.app.invocations.baseinvocation import 
BaseInvocation @@ -189,7 +198,13 @@ class SessionQueueItemWithoutGraph(BaseModel): session_id: str = Field( description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." ) - error: Optional[str] = Field(default=None, description="The error message if this queue item errored") + error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored") + error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored") + error_traceback: Optional[str] = Field( + default=None, + description="The error traceback if this queue item errored", + validation_alias=AliasChoices("error_traceback", "error"), + ) created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created") updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated") started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started") diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 87c22c496fd..dfd00a78094 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -82,10 +82,18 @@ async def _handle_complete_event(self, event: FastAPIEvent) -> None: async def _handle_error_event(self, event: FastAPIEvent) -> None: try: item_id = event[1]["data"]["queue_item_id"] - error = event[1]["data"]["error"] + error_type = event[1]["data"]["error_type"] + error_message = event[1]["data"]["error_message"] + error_traceback = event[1]["data"]["error_traceback"] queue_item = self.get_queue_item(item_id) # always set to failed if have an error, even if previously the item was marked completed or canceled - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, 
status="failed", error=error) + queue_item = self._set_queue_item_status( + item_id=queue_item.item_id, + status="failed", + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) except SessionQueueItemNotFoundError: return @@ -272,17 +280,22 @@ def get_current(self, queue_id: str) -> Optional[SessionQueueItem]: return SessionQueueItem.queue_item_from_dict(dict(result)) def _set_queue_item_status( - self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None + self, + item_id: int, + status: QUEUE_ITEM_STATUS, + error_type: Optional[str] = None, + error_message: Optional[str] = None, + error_traceback: Optional[str] = None, ) -> SessionQueueItem: try: self.__lock.acquire() self.__cursor.execute( """--sql UPDATE session_queue - SET status = ?, error = ? + SET status = ?, error_type = ?, error_message = ?, error_traceback = ? WHERE item_id = ? """, - (status, error, item_id), + (status, error_type, error_message, error_traceback, item_id), ) self.__conn.commit() except Exception: @@ -425,11 +438,34 @@ def prune(self, queue_id: str) -> PruneResult: self.__lock.release() return PruneResult(deleted=count) - def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: + def cancel_queue_item(self, item_id: int) -> SessionQueueItem: queue_item = self.get_queue_item(item_id) if queue_item.status not in ["canceled", "failed", "completed"]: - status = "failed" if error is not None else "canceled" - queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here + queue_item = self._set_queue_item_status(item_id=item_id, status="canceled") + self.__invoker.services.events.emit_session_canceled( + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + queue_batch_id=queue_item.batch_id, + graph_execution_state_id=queue_item.session_id, + ) + return queue_item + + def fail_queue_item( + 
self, + item_id: int, + error_type: str, + error_message: str, + error_traceback: str, + ) -> SessionQueueItem: + queue_item = self.get_queue_item(item_id) + if queue_item.status not in ["canceled", "failed", "completed"]: + queue_item = self._set_queue_item_status( + item_id=item_id, + status="failed", + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) self.__invoker.services.events.emit_session_canceled( queue_item_id=queue_item.item_id, queue_id=queue_item.queue_id, @@ -602,7 +638,9 @@ def list_queue_items( status, priority, field_values, - error, + error_type, + error_message, + error_traceback, created_at, updated_at, completed_at, From 6a34176376ec318987e1386c097983bc425ac4fe Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 23 May 2024 15:19:34 +1000 Subject: [PATCH 18/34] feat(events): add enriched errors to events --- invokeai/app/services/events/events_base.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index aa91cdaec8f..1ddda2921d4 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -121,7 +121,8 @@ def emit_invocation_error( node: dict, source_node_id: str, error_type: str, - error: str, + error_message: str, + error_traceback: str, user_id: str | None, project_id: str | None, ) -> None: @@ -136,7 +137,8 @@ def emit_invocation_error( "node": node, "source_node_id": source_node_id, "error_type": error_type, - "error": error, + "error_message": error_message, + "error_traceback": error_traceback, "user_id": user_id, "project_id": project_id, }, @@ -257,7 +259,9 @@ def emit_queue_item_status_changed( "status": session_queue_item.status, "batch_id": session_queue_item.batch_id, "session_id": session_queue_item.session_id, - "error": session_queue_item.error, + "error_type": 
session_queue_item.error_type, + "error_message": session_queue_item.error_message, + "error_traceback": session_queue_item.error_traceback, "created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None, "updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None, "started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None, From db0ef8d316bd400d51dad33f8c1f890829edcf4f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 23 May 2024 15:20:22 +1000 Subject: [PATCH 19/34] feat(processor): update enriched errors & `fail_queue_item()` --- .../session_processor_base.py | 15 +-- .../session_processor_default.py | 121 +++++++++++------- 2 files changed, 83 insertions(+), 53 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py index bfae74e5fea..1436627a9ea 100644 --- a/invokeai/app/services/session_processor/session_processor_base.py +++ b/invokeai/app/services/session_processor/session_processor_base.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from threading import Event -from types import TracebackType from typing import Optional, Protocol from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput @@ -71,9 +70,9 @@ def __call__( self, invocation: BaseInvocation, queue_item: SessionQueueItem, - exc_type: type, - exc_value: BaseException, - exc_traceback: TracebackType, + error_type: str, + error_message: str, + error_traceback: str, ) -> bool: ... @@ -88,8 +87,8 @@ def __call__(self, queue_item: SessionQueueItem) -> bool: ... 
class OnNonFatalProcessorError(Protocol): def __call__( self, - exc_type: type, - exc_value: BaseException, - exc_traceback: TracebackType, - queue_item: Optional[SessionQueueItem] = None, + queue_item: Optional[SessionQueueItem], + error_type: str, + error_message: str, + error_traceback: str, ) -> bool: ... diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index cddb7cdc037..49277a105d7 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -2,7 +2,6 @@ from contextlib import suppress from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent -from types import TracebackType from typing import Optional from fastapi_events.handlers.local import local_handler @@ -30,12 +29,6 @@ from .session_processor_common import SessionProcessorStatus -def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str: - """Formats a stacktrace as a string""" - - return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) - - class DefaultSessionRunner(SessionRunnerBase): """Processes a single session's invocations""" @@ -71,10 +64,16 @@ def run(self, queue_item: SessionQueueItem): invocation = queue_item.session.next() # Anything other than a `NodeInputError` is handled as a processor error except NodeInputError as e: - # Must extract the exception traceback here to not lose its stacktrace when we change scope - traceback = e.__traceback__ - assert traceback is not None - self._on_node_error(e.node, queue_item, type(e), e, traceback) + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._on_node_error( + invocation=e.node, + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + 
) break if invocation is None or self._cancel_event.is_set(): @@ -126,10 +125,16 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): # loop go to its next iteration, and the cancel event will be handled correctly. pass except Exception as e: - # Must extract the exception traceback here to not lose its stacktrace when we change scope - exc_traceback = e.__traceback__ - assert exc_traceback is not None - self._on_node_error(invocation, queue_item, type(e), e, exc_traceback) + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._on_node_error( + invocation=invocation, + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: # If profiling is enabled, start the profiler @@ -166,7 +171,7 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: self._services.performance_statistics.reset_stats() for callback in self._on_after_run_session_callbacks: - callback(queue_item) + callback(queue_item=queue_item) def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run before a node is executed""" @@ -181,7 +186,7 @@ def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQue ) for callback in self._on_before_run_node_callbacks: - callback(invocation, queue_item) + callback(invocation=invocation, queue_item=queue_item) def _on_after_run_node( self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput @@ -199,23 +204,23 @@ def _on_after_run_node( ) for callback in self._on_after_run_node_callbacks: - callback(invocation, queue_item, output) + callback(invocation=invocation, queue_item=queue_item, output=output) def _on_node_error( self, invocation: BaseInvocation, queue_item: SessionQueueItem, - exc_type: type, - exc_value: BaseException, - 
exc_traceback: TracebackType, + error_type: str, + error_message: str, + error_traceback: str, ): - stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) - - queue_item.session.set_node_error(invocation.id, stacktrace) + # Node errors do not get the full traceback. Only the queue item gets the full traceback. + node_error = f"{error_type}: {error_message}" + queue_item.session.set_node_error(invocation.id, node_error) self._services.logger.error( - f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {exc_type.__name__}" + f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}" ) - self._services.logger.error(stacktrace) + self._services.logger.error(error_traceback) # Send error event self._services.events.emit_invocation_error( @@ -225,14 +230,21 @@ def _on_node_error( graph_execution_state_id=queue_item.session.id, node=invocation.model_dump(), source_node_id=queue_item.session.prepared_source_mapping[invocation.id], - error_type=exc_type.__name__, - error=stacktrace, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, user_id=getattr(queue_item, "user_id", None), project_id=getattr(queue_item, "project_id", None), ) for callback in self._on_node_error_callbacks: - callback(invocation, queue_item, exc_type, exc_value, exc_traceback) + callback( + invocation=invocation, + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) class DefaultSessionProcessor(SessionProcessorBase): @@ -374,16 +386,25 @@ def _process( self.session_runner.run(queue_item=self._queue_item) except Exception as e: - # Must extract the exception traceback here to not lose its stacktrace when we change scope - exc_traceback = e.__traceback__ - assert exc_traceback is not None - self._on_non_fatal_processor_error(self._queue_item, type(e), e, exc_traceback) 
- # Immediately poll for next queue item + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._on_non_fatal_processor_error( + queue_item=self._queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + # Wait for next polling interval or event to try again poll_now_event.wait(self._polling_interval) continue - except Exception: + except Exception as e: # Fatal error in processor, log and pass - we're done here - self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}") + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}") + self._invoker.services.logger.error(error_traceback) pass finally: stop_event.clear() @@ -394,19 +415,29 @@ def _process( def _on_non_fatal_processor_error( self, queue_item: Optional[SessionQueueItem], - exc_type: type, - exc_value: BaseException, - exc_traceback: TracebackType, + error_type: str, + error_message: str, + error_traceback: str, ) -> None: - stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) # Non-fatal error in processor - self._invoker.services.logger.error(f"Non-fatal error in session processor: {exc_type.__name__}") - self._invoker.services.logger.error(stacktrace) + self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}") + self._invoker.services.logger.error(error_traceback) + if queue_item is not None: # Update the queue item with the completed session self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) - # And cancel the queue item with an error - self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) + # Fail the queue item + 
self._invoker.services.session_queue.fail_queue_item( + item_id=queue_item.item_id, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) for callback in self._on_non_fatal_processor_error_callbacks: - callback(exc_type, exc_value, exc_traceback, queue_item) + callback( + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) From 19227fe4e619279875f1a84040f6673c81c17a94 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 23 May 2024 15:20:41 +1000 Subject: [PATCH 20/34] feat(app): update test event callbacks --- invokeai/app/api/dependencies.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index d9cefb0acff..b3c2acfb947 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -116,7 +116,7 @@ def on_after_run_node(invocation, queue_item, output): print("AFTER RUN NODE", invocation.id) return True - def on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback): + def on_node_error(invocation, queue_item, error_type, error_message, error_traceback): print("NODE ERROR", invocation.id) return True @@ -124,8 +124,8 @@ def on_after_run_session(queue_item): print("AFTER RUN SESSION", queue_item.item_id) return True - def on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item=None): - print("NON FATAL PROCESSOR ERROR", exc_value) + def on_non_fatal_processor_error(queue_item, error_type, error_message, error_traceback): + print("NON FATAL PROCESSOR ERROR", error_message) return True session_processor = DefaultSessionProcessor( From 9a4c16734287042811295e2708cdfadffab4264a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 23 May 2024 15:20:51 +1000 Subject: [PATCH 21/34] chore(ui): typegen --- 
.../frontend/web/src/services/api/schema.ts | 243 ++++++++++-------- 1 file changed, 130 insertions(+), 113 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index cb3d11c06b9..33e795fd466 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1175,11 +1175,8 @@ export type components = { * Format: binary */ file: Blob; - /** - * Metadata - * @description The metadata to associate with the image - */ - metadata?: Record | null; + /** @description The metadata to associate with the image */ + metadata?: components["schemas"]["JsonValue"] | null; }; /** * Boolean Collection Primitive @@ -4261,7 +4258,7 @@ export type components = { * @description The nodes in this graph */ nodes: { - [key: string]: components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | 
components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | 
components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["MetadataInvocation"] | 
components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | 
components["schemas"]["CalculateImageTilesEvenSplitInvocation"]; + [key: string]: components["schemas"]["RandomRangeInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | 
components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ImageInvocation"] | 
components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["SeamlessModeInvocation"] | 
components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["T2IAdapterInvocation"]; }; /** * Edges @@ -4298,7 +4295,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["NoiseOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["BooleanOutput"] | 
components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["String2Output"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["MetadataOutput"]; + [key: string]: components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["DenoiseMaskOutput"] | 
components["schemas"]["String2Output"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["LoRALoaderOutput"]; }; /** * Errors @@ -9878,10 +9875,20 @@ export type components = { */ session_id: string; /** - * Error + * Error Type + * @description The error type if this queue item errored + */ 
+ error_type?: string | null; + /** + * Error Message * @description The error message if this queue item errored */ - error?: string | null; + error_message?: string | null; + /** + * Error Traceback + * @description The error traceback if this queue item errored + */ + error_traceback?: string | null; /** * Created At * @description When this queue item was created @@ -9948,10 +9955,20 @@ export type components = { */ session_id: string; /** - * Error + * Error Type + * @description The error type if this queue item errored + */ + error_type?: string | null; + /** + * Error Message * @description The error message if this queue item errored */ - error?: string | null; + error_message?: string | null; + /** + * Error Traceback + * @description The error traceback if this queue item errored + */ + error_traceback?: string | null; /** * Created At * @description When this queue item was created @@ -11901,144 +11918,144 @@ export type components = { */ UIType: "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | 
"DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; InvocationOutputMap: { - calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; + random_range: components["schemas"]["IntegerCollectionOutput"]; + t2i_adapter: components["schemas"]["T2IAdapterOutput"]; + latents: components["schemas"]["LatentsOutput"]; + infill_rgba: components["schemas"]["ImageOutput"]; + conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; + model_identifier: components["schemas"]["ModelIdentifierOutput"]; + img_nsfw: components["schemas"]["ImageOutput"]; + blank_image: components["schemas"]["ImageOutput"]; + ideal_size: components["schemas"]["IdealSizeOutput"]; + integer_math: components["schemas"]["IntegerOutput"]; + add: components["schemas"]["IntegerOutput"]; + img_crop: components["schemas"]["ImageOutput"]; + string_split: components["schemas"]["String2Output"]; + face_mask_detection: components["schemas"]["FaceMaskOutput"]; + color: components["schemas"]["ColorOutput"]; clip_skip: components["schemas"]["CLIPSkipInvocationOutput"]; + img_conv: components["schemas"]["ImageOutput"]; range_of_size: components["schemas"]["IntegerCollectionOutput"]; - lora_loader: components["schemas"]["LoRALoaderOutput"]; - img_mul: components["schemas"]["ImageOutput"]; + calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; + face_identifier: components["schemas"]["ImageOutput"]; + zoe_depth_image_processor: components["schemas"]["ImageOutput"]; div: components["schemas"]["IntegerOutput"]; + lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; + 
sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; + integer: components["schemas"]["IntegerOutput"]; crop_latents: components["schemas"]["LatentsOutput"]; - scheduler: components["schemas"]["SchedulerOutput"]; - blank_image: components["schemas"]["ImageOutput"]; invert_tensor_mask: components["schemas"]["MaskOutput"]; - controlnet: components["schemas"]["ControlOutput"]; + save_image: components["schemas"]["ImageOutput"]; + create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; create_gradient_mask: components["schemas"]["GradientMaskOutput"]; - img_crop: components["schemas"]["ImageOutput"]; - img_chan: components["schemas"]["ImageOutput"]; - iterate: components["schemas"]["IterateInvocationOutput"]; - img_hue_adjust: components["schemas"]["ImageOutput"]; - merge_tiles_to_image: components["schemas"]["ImageOutput"]; + img_channel_offset: components["schemas"]["ImageOutput"]; + cv_inpaint: components["schemas"]["ImageOutput"]; + sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; + boolean_collection: components["schemas"]["BooleanCollectionOutput"]; + img_paste: components["schemas"]["ImageOutput"]; + midas_depth_image_processor: components["schemas"]["ImageOutput"]; + tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; + lora_selector: components["schemas"]["LoRASelectorOutput"]; denoise_latents: components["schemas"]["LatentsOutput"]; + mask_combine: components["schemas"]["ImageOutput"]; + hed_image_processor: components["schemas"]["ImageOutput"]; + boolean: components["schemas"]["BooleanOutput"]; + lineart_anime_image_processor: components["schemas"]["ImageOutput"]; string_split_neg: components["schemas"]["StringPosNegOutput"]; metadata_item: components["schemas"]["MetadataItemOutput"]; - face_off: components["schemas"]["FaceOffOutput"]; - zoe_depth_image_processor: components["schemas"]["ImageOutput"]; + img_lerp: components["schemas"]["ImageOutput"]; + mlsd_image_processor: 
components["schemas"]["ImageOutput"]; prompt_from_file: components["schemas"]["StringCollectionOutput"]; - img_nsfw: components["schemas"]["ImageOutput"]; - infill_lama: components["schemas"]["ImageOutput"]; - vae_loader: components["schemas"]["VAEOutput"]; + infill_tile: components["schemas"]["ImageOutput"]; + iterate: components["schemas"]["IterateInvocationOutput"]; + calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; + lora_loader: components["schemas"]["LoRALoaderOutput"]; + calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; + face_off: components["schemas"]["FaceOffOutput"]; noise: components["schemas"]["NoiseOutput"]; - midas_depth_image_processor: components["schemas"]["ImageOutput"]; - string: components["schemas"]["StringOutput"]; - img_conv: components["schemas"]["ImageOutput"]; - mlsd_image_processor: components["schemas"]["ImageOutput"]; + mask_edge: components["schemas"]["ImageOutput"]; + normalbae_image_processor: components["schemas"]["ImageOutput"]; + tile_image_processor: components["schemas"]["ImageOutput"]; + canvas_paste_back: components["schemas"]["ImageOutput"]; + color_map_image_processor: components["schemas"]["ImageOutput"]; + img_channel_multiply: components["schemas"]["ImageOutput"]; core_metadata: components["schemas"]["MetadataOutput"]; float_math: components["schemas"]["FloatOutput"]; - hed_image_processor: components["schemas"]["ImageOutput"]; - lineart_anime_image_processor: components["schemas"]["ImageOutput"]; - main_model_loader: components["schemas"]["ModelLoaderOutput"]; - infill_cv2: components["schemas"]["ImageOutput"]; - image: components["schemas"]["ImageOutput"]; - normalbae_image_processor: components["schemas"]["ImageOutput"]; + string: components["schemas"]["StringOutput"]; + lineart_image_processor: components["schemas"]["ImageOutput"]; + img_watermark: components["schemas"]["ImageOutput"]; + color_correct: components["schemas"]["ImageOutput"]; rand_int: 
components["schemas"]["IntegerOutput"]; - image_collection: components["schemas"]["ImageCollectionOutput"]; - step_param_easing: components["schemas"]["FloatCollectionOutput"]; - infill_patchmatch: components["schemas"]["ImageOutput"]; + merge_tiles_to_image: components["schemas"]["ImageOutput"]; + collect: components["schemas"]["CollectInvocationOutput"]; + scheduler: components["schemas"]["SchedulerOutput"]; + mask_from_id: components["schemas"]["ImageOutput"]; sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; - string_collection: components["schemas"]["StringCollectionOutput"]; - img_paste: components["schemas"]["ImageOutput"]; - infill_rgba: components["schemas"]["ImageOutput"]; - integer_collection: components["schemas"]["IntegerCollectionOutput"]; float_to_int: components["schemas"]["IntegerOutput"]; - tile_image_processor: components["schemas"]["ImageOutput"]; - mask_combine: components["schemas"]["ImageOutput"]; + image: components["schemas"]["ImageOutput"]; + unsharp_mask: components["schemas"]["ImageOutput"]; merge_metadata: components["schemas"]["MetadataOutput"]; + image_collection: components["schemas"]["ImageCollectionOutput"]; rectangle_mask: components["schemas"]["MaskOutput"]; - color_map_image_processor: components["schemas"]["ImageOutput"]; - img_lerp: components["schemas"]["ImageOutput"]; - mask_edge: components["schemas"]["ImageOutput"]; - ip_adapter: components["schemas"]["IPAdapterOutput"]; - lineart_image_processor: components["schemas"]["ImageOutput"]; - seamless: components["schemas"]["SeamlessModeOutput"]; - img_channel_offset: components["schemas"]["ImageOutput"]; - sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; - range: components["schemas"]["IntegerCollectionOutput"]; + img_pad_crop: components["schemas"]["ImageOutput"]; + heuristic_resize: components["schemas"]["ImageOutput"]; + string_collection: components["schemas"]["StringCollectionOutput"]; + leres_image_processor: 
components["schemas"]["ImageOutput"]; + vae_loader: components["schemas"]["VAEOutput"]; + tomask: components["schemas"]["ImageOutput"]; + mediapipe_face_processor: components["schemas"]["ImageOutput"]; lresize: components["schemas"]["LatentsOutput"]; - freeu: components["schemas"]["UNetOutput"]; - string_join: components["schemas"]["StringOutput"]; - compel: components["schemas"]["ConditioningOutput"]; - collect: components["schemas"]["CollectInvocationOutput"]; - img_watermark: components["schemas"]["ImageOutput"]; - float_range: components["schemas"]["FloatCollectionOutput"]; + range: components["schemas"]["IntegerCollectionOutput"]; sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; - i2l: components["schemas"]["LatentsOutput"]; + sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; + float_range: components["schemas"]["FloatCollectionOutput"]; float: components["schemas"]["FloatOutput"]; - dynamic_prompt: components["schemas"]["StringCollectionOutput"]; - save_image: components["schemas"]["ImageOutput"]; - heuristic_resize: components["schemas"]["ImageOutput"]; - lblend: components["schemas"]["LatentsOutput"]; - tomask: components["schemas"]["ImageOutput"]; - leres_image_processor: components["schemas"]["ImageOutput"]; - lscale: components["schemas"]["LatentsOutput"]; - conditioning: components["schemas"]["ConditioningOutput"]; - mediapipe_face_processor: components["schemas"]["ImageOutput"]; - esrgan: components["schemas"]["ImageOutput"]; - img_pad_crop: components["schemas"]["ImageOutput"]; - content_shuffle_image_processor: components["schemas"]["ImageOutput"]; - color_correct: components["schemas"]["ImageOutput"]; - unsharp_mask: components["schemas"]["ImageOutput"]; - infill_tile: components["schemas"]["ImageOutput"]; + string_join: components["schemas"]["StringOutput"]; canny_image_processor: components["schemas"]["ImageOutput"]; + 
main_model_loader: components["schemas"]["ModelLoaderOutput"]; + compel: components["schemas"]["ConditioningOutput"]; + i2l: components["schemas"]["LatentsOutput"]; show_image: components["schemas"]["ImageOutput"]; + img_scale: components["schemas"]["ImageOutput"]; + content_shuffle_image_processor: components["schemas"]["ImageOutput"]; + integer_collection: components["schemas"]["IntegerCollectionOutput"]; + ip_adapter: components["schemas"]["IPAdapterOutput"]; + dynamic_prompt: components["schemas"]["StringCollectionOutput"]; + lscale: components["schemas"]["LatentsOutput"]; pidi_image_processor: components["schemas"]["ImageOutput"]; - pair_tile_image: components["schemas"]["PairTileImageOutput"]; segment_anything_processor: components["schemas"]["ImageOutput"]; - rand_float: components["schemas"]["FloatOutput"]; - canvas_paste_back: components["schemas"]["ImageOutput"]; + lblend: components["schemas"]["LatentsOutput"]; + img_resize: components["schemas"]["ImageOutput"]; depth_anything_image_processor: components["schemas"]["ImageOutput"]; - img_channel_multiply: components["schemas"]["ImageOutput"]; + img_blur: components["schemas"]["ImageOutput"]; + esrgan: components["schemas"]["ImageOutput"]; + seamless: components["schemas"]["SeamlessModeOutput"]; + dw_openpose_image_processor: components["schemas"]["ImageOutput"]; + freeu: components["schemas"]["UNetOutput"]; + img_ilerp: components["schemas"]["ImageOutput"]; + rand_float: components["schemas"]["FloatOutput"]; + conditioning: components["schemas"]["ConditioningOutput"]; + step_param_easing: components["schemas"]["FloatCollectionOutput"]; + controlnet: components["schemas"]["ControlOutput"]; metadata: components["schemas"]["MetadataOutput"]; string_replace: components["schemas"]["StringOutput"]; image_mask_to_tensor: components["schemas"]["MaskOutput"]; + infill_lama: components["schemas"]["ImageOutput"]; mul: components["schemas"]["IntegerOutput"]; - img_scale: components["schemas"]["ImageOutput"]; - 
model_identifier: components["schemas"]["ModelIdentifierOutput"]; + img_hue_adjust: components["schemas"]["ImageOutput"]; + latents_collection: components["schemas"]["LatentsCollectionOutput"]; alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; - latents: components["schemas"]["LatentsOutput"]; - dw_openpose_image_processor: components["schemas"]["ImageOutput"]; - mask_from_id: components["schemas"]["ImageOutput"]; - conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; round_float: components["schemas"]["FloatOutput"]; - face_mask_detection: components["schemas"]["FaceMaskOutput"]; - calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; - img_resize: components["schemas"]["ImageOutput"]; + float_collection: components["schemas"]["FloatCollectionOutput"]; + img_mul: components["schemas"]["ImageOutput"]; l2i: components["schemas"]["ImageOutput"]; - color: components["schemas"]["ColorOutput"]; - lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; - sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; string_join_three: components["schemas"]["StringOutput"]; + infill_cv2: components["schemas"]["ImageOutput"]; sub: components["schemas"]["IntegerOutput"]; - img_blur: components["schemas"]["ImageOutput"]; - float_collection: components["schemas"]["FloatCollectionOutput"]; - integer: components["schemas"]["IntegerOutput"]; - face_identifier: components["schemas"]["ImageOutput"]; - latents_collection: components["schemas"]["LatentsCollectionOutput"]; - cv_inpaint: components["schemas"]["ImageOutput"]; - t2i_adapter: components["schemas"]["T2IAdapterOutput"]; - create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; - random_range: components["schemas"]["IntegerCollectionOutput"]; - sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - ideal_size: components["schemas"]["IdealSizeOutput"]; - tile_to_properties: 
components["schemas"]["TileToPropertiesOutput"]; - img_ilerp: components["schemas"]["ImageOutput"]; - integer_math: components["schemas"]["IntegerOutput"]; - add: components["schemas"]["IntegerOutput"]; - boolean: components["schemas"]["BooleanOutput"]; - string_split: components["schemas"]["String2Output"]; - lora_selector: components["schemas"]["LoRASelectorOutput"]; - boolean_collection: components["schemas"]["BooleanCollectionOutput"]; - calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; + img_chan: components["schemas"]["ImageOutput"]; + pair_tile_image: components["schemas"]["PairTileImageOutput"]; + infill_patchmatch: components["schemas"]["ImageOutput"]; }; }; responses: never; From 6063487b20898b5fe65f11124fb5b4a327155a6e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 23 May 2024 15:21:01 +1000 Subject: [PATCH 22/34] feat(ui): handle enriched events --- .../listeners/socketio/socketInvocationError.ts | 12 ++++++++++-- .../web/src/features/nodes/types/invocation.ts | 7 ++++++- .../queue/components/QueueList/QueueItemDetail.tsx | 4 ++-- invokeai/frontend/web/src/services/events/types.ts | 7 +++++-- 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts index 07cfa08e918..8cf79462c98 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts @@ -39,13 +39,21 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe actionCreator: socketInvocationError, effect: (action, { getState }) => { log.error(action.payload, `Invocation error 
(${action.payload.data.node.type})`); - const { source_node_id, error_type, graph_execution_state_id } = action.payload.data; + const { source_node_id, error_type, error_message, error_traceback, graph_execution_state_id } = + action.payload.data; const nes = deepClone($nodeExecutionStates.get()[source_node_id]); if (nes) { nes.status = zNodeStatus.enum.FAILED; - nes.error = action.payload.data.error; nes.progress = null; nes.progressImage = null; + + if (error_type && error_message && error_traceback) { + nes.error = { + error_type, + error_message, + error_traceback, + }; + } upsertExecutionState(nes.nodeId, nes); } diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 66a3db62bf4..0a7149bd6bb 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -70,13 +70,18 @@ export const isInvocationNodeData = (node?: AnyNodeData | null): node is Invocat // #region NodeExecutionState export const zNodeStatus = z.enum(['PENDING', 'IN_PROGRESS', 'COMPLETED', 'FAILED']); +const zNodeError = z.object({ + error_type: z.string(), + error_message: z.string(), + error_traceback: z.string(), +}); const zNodeExecutionState = z.object({ nodeId: z.string().trim().min(1), status: zNodeStatus, progress: z.number().nullable(), progressImage: zProgressImage.nullable(), - error: z.string().nullable(), outputs: z.array(z.any()), + error: zNodeError.nullable(), }); export type NodeExecutionState = z.infer; // #endregion diff --git a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemDetail.tsx b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemDetail.tsx index b719ae0a92b..e3f2436aca9 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemDetail.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemDetail.tsx @@ -76,7 
+76,7 @@ const QueueItemComponent = ({ queueItemDTO }: Props) => { - {queueItem?.error && ( + {(queueItem?.error_traceback || queueItem?.error_message) && ( { {t('common.error')} -
{queueItem.error}
+
{queueItem?.error_traceback ?? queueItem?.error_message}
)} diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index 161a85b8f6a..e1dea1563b6 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -116,7 +116,8 @@ export type InvocationErrorEvent = { node: BaseNode; source_node_id: string; error_type: string; - error: string; + error_message: string; + error_traceback: string; }; /** @@ -187,7 +188,9 @@ export type QueueItemStatusChangedEvent = { batch_id: string; session_id: string; status: components['schemas']['SessionQueueItemDTO']['status']; - error: string | undefined; + error_type?: string | null; + error_message?: string | null; + error_traceback?: string | null; created_at: string; updated_at: string; started_at: string | undefined; From a98ddedb9576a4393fc2d34bfd5aff343ff55eb3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 10:20:20 +1000 Subject: [PATCH 23/34] docs(processor): update docstrings, comments --- .../session_processor_base.py | 81 ++++++++++++++++--- .../session_processor_default.py | 36 ++++++--- 2 files changed, 97 insertions(+), 20 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py index 1436627a9ea..15611bb5f87 100644 --- a/invokeai/app/services/session_processor/session_processor_base.py +++ b/invokeai/app/services/session_processor/session_processor_base.py @@ -16,17 +16,33 @@ class SessionRunnerBase(ABC): @abstractmethod def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None: - """Starts the session runner""" + """Starts the session runner. + + Args: + services: The invocation services. + cancel_event: The cancel event. + profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. 
Basic session + stats will still be recorded and logged when profiling is disabled. + """ pass @abstractmethod def run(self, queue_item: SessionQueueItem) -> None: - """Runs the session""" + """Runs a session. + + Args: + queue_item: The session to run. + """ pass @abstractmethod def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None: - """Runs an already prepared node on the session""" + """Run a single node in the graph. + + Args: + invocation: The invocation to run. + queue_item: The session queue item. + """ pass @@ -56,13 +72,25 @@ def get_status(self) -> SessionProcessorStatus: class OnBeforeRunNode(Protocol): - def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool: ... + def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None: + """Callback to run before executing a node. + + Args: + invocation: The invocation that will be executed. + queue_item: The session queue item. + """ + ... class OnAfterRunNode(Protocol): - def __call__( - self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput - ) -> bool: ... + def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None: + """Callback to run after executing a node. + + Args: + invocation: The invocation that was executed. + queue_item: The session queue item. + """ + ... class OnNodeError(Protocol): @@ -73,15 +101,37 @@ def __call__( error_type: str, error_message: str, error_traceback: str, - ) -> bool: ... + ) -> None: + """Callback to run when a node has an error. + + Args: + invocation: The invocation that errored. + queue_item: The session queue item. + error_type: The type of error, e.g. "ValueError". + error_message: The error message, e.g. "Invalid value". + error_traceback: The stringified error traceback. + """ + ... class OnBeforeRunSession(Protocol): - def __call__(self, queue_item: SessionQueueItem) -> bool: ...
+ def __call__(self, queue_item: SessionQueueItem) -> None: + """Callback to run before executing a session. + + Args: + queue_item: The session queue item. + """ + ... class OnAfterRunSession(Protocol): - def __call__(self, queue_item: SessionQueueItem) -> bool: ... + def __call__(self, queue_item: SessionQueueItem) -> None: + """Callback to run after executing a session. + + Args: + queue_item: The session queue item. + """ + ... class OnNonFatalProcessorError(Protocol): @@ -91,4 +141,13 @@ def __call__( error_type: str, error_message: str, error_traceback: str, - ) -> bool: ... + ) -> None: + """Callback to run when a non-fatal error occurs in the processor. + + Args: + queue_item: The session queue item, if one was being executed when the error occurred. + error_type: The type of error, e.g. "ValueError". + error_message: The error message, e.g. "Invalid value". + error_traceback: The stringified error traceback. + """ + ... diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 49277a105d7..eec835af87f 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -30,7 +30,7 @@ class DefaultSessionRunner(SessionRunnerBase): - """Processes a single session's invocations""" + """Processes a single session's invocations.""" def __init__( self, @@ -40,6 +40,15 @@ def __init__( on_node_error_callbacks: Optional[list[OnNodeError]] = None, on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None, ): + """ + Args: + on_before_run_session_callbacks: Callbacks to run before the session starts. + on_before_run_node_callbacks: Callbacks to run before each node starts. + on_after_run_node_callbacks: Callbacks to run after each node completes. + on_node_error_callbacks: Callbacks to run when a node errors. 
+ on_after_run_session_callbacks: Callbacks to run after the session completes. + """ + self._on_before_run_session_callbacks = on_before_run_session_callbacks or [] self._on_before_run_node_callbacks = on_before_run_node_callbacks or [] self._on_after_run_node_callbacks = on_after_run_node_callbacks or [] @@ -47,14 +56,12 @@ def __init__( self._on_after_run_session_callbacks = on_after_run_session_callbacks or [] def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None): - """Start the session runner""" self._services = services self._cancel_event = cancel_event self._profiler = profiler def run(self, queue_item: SessionQueueItem): - """Run the graph""" - # Exceptions raised outside `run_node` are handled by the processor. + # Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here. self._on_before_run_session(queue_item=queue_item) @@ -78,14 +85,16 @@ def run(self, queue_item: SessionQueueItem): if invocation is None or self._cancel_event.is_set(): break + self.run_node(invocation, queue_item) + + # The session is complete if all invocations have been run or there is an error on the session. if queue_item.session.is_complete() or self._cancel_event.is_set(): break self._on_after_run_session(queue_item=queue_item) def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): - """Run a single node in the graph""" try: # Any unhandled exception in this scope is an invocation error & will fail the graph with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id): @@ -110,7 +119,7 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): self._on_after_run_node(invocation, queue_item, output) except KeyboardInterrupt: - # TODO(MM2): Create an event for this + # TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here? 
pass except CanceledException: # When the user cancels the graph, we first set the cancel event. The event is checked @@ -137,6 +146,8 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): ) def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: + """Run before a session is executed""" + # If profiling is enabled, start the profiler if self._profiler is not None: self._profiler.start(profile_id=queue_item.session_id) @@ -145,6 +156,8 @@ def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: callback(queue_item=queue_item) def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: + """Run after a session is executed""" + # If we are profiling, stop the profiler and dump the profile & stats if self._profiler is not None: profile_path = self._profiler.stop() @@ -156,7 +169,8 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: # Update the queue item with the completed session self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) - # Send complete event + # TODO(psyche): This feels jumbled - we should review separation of concerns here. + # Send complete event. The events service will receive this and update the queue item's status. 
self._services.events.emit_graph_execution_complete( queue_batch_id=queue_item.batch_id, queue_item_id=queue_item.item_id, @@ -175,6 +189,7 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run before a node is executed""" + # Send starting event self._services.events.emit_invocation_started( queue_batch_id=queue_item.batch_id, @@ -192,6 +207,7 @@ def _on_after_run_node( self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput ): """Run after a node is executed""" + # Send complete event on successful runs self._services.events.emit_invocation_complete( queue_batch_id=queue_item.batch_id, @@ -214,6 +230,8 @@ def _on_node_error( error_message: str, error_traceback: str, ): + """Run when a node errors""" + # Node errors do not get the full traceback. Only the queue item gets the full traceback. node_error = f"{error_type}: {error_message}" queue_item.session.set_node_error(invocation.id, node_error) @@ -356,8 +374,8 @@ def _process( resume_event: ThreadEvent, cancel_event: ThreadEvent, ): - # Outermost processor try block; any unhandled exception is a fatal processor error try: + # Any unhandled exception in this block is a fatal processor error and will stop the processor. self._thread_semaphore.acquire() stop_event.clear() resume_event.set() @@ -365,8 +383,8 @@ def _process( while not stop_event.is_set(): poll_now_event.clear() - # Middle processor try block; any unhandled exception is a non-fatal processor error try: + # Any unhandled exception in this block is a nonfatal processor error and will be handled. 
# If we are paused, wait for resume event resume_event.wait() From 7d1844eaf2838467a79237c4b7f9c2dfe2dfba42 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 10:21:01 +1000 Subject: [PATCH 24/34] chore: ruff --- invokeai/app/services/session_queue/session_queue_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index 8b21998f193..fc45183aef4 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -79,7 +79,9 @@ def cancel_queue_item(self, item_id: int) -> SessionQueueItem: pass @abstractmethod - def fail_queue_item(self, item_id: int, error_type: str, error_message: str, error_traceback: str) -> SessionQueueItem: + def fail_queue_item( + self, item_id: int, error_type: str, error_message: str, error_traceback: str + ) -> SessionQueueItem: """Fails a session queue item""" pass From c88de180e735a9db31cef26c9a0dc9f421dcfa1f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 10:48:33 +1000 Subject: [PATCH 25/34] tidy(queue): delete unused `delete_queue_item` method --- .../session_queue/session_queue_sqlite.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index dfd00a78094..9401eabecf2 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -352,26 +352,6 @@ def is_full(self, queue_id: str) -> IsFullResult: self.__lock.release() return IsFullResult(is_full=is_full) - def delete_queue_item(self, item_id: int) -> SessionQueueItem: - queue_item = self.get_queue_item(item_id=item_id) - 
try: - self.__lock.acquire() - self.__cursor.execute( - """--sql - DELETE FROM session_queue - WHERE - item_id = ? - """, - (item_id,), - ) - self.__conn.commit() - except Exception: - self.__conn.rollback() - raise - finally: - self.__lock.release() - return queue_item - def clear(self, queue_id: str) -> ClearResult: try: self.__lock.acquire() From 169b75b2b7f0ffe71d856ed5905bbd9caf52538b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 11:23:26 +1000 Subject: [PATCH 26/34] tidy(processor): remove test callbacks --- invokeai/app/api/dependencies.py | 35 +------------------------------- 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index b3c2acfb947..bb62158c6b0 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -104,40 +104,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger names = SimpleNameService() performance_statistics = InvocationStatsService() - def on_before_run_session(queue_item): - print("BEFORE RUN SESSION", queue_item.item_id) - return True - - def on_before_run_node(invocation, queue_item): - print("BEFORE RUN NODE", invocation.id) - return True - - def on_after_run_node(invocation, queue_item, output): - print("AFTER RUN NODE", invocation.id) - return True - - def on_node_error(invocation, queue_item, error_type, error_message, error_traceback): - print("NODE ERROR", invocation.id) - return True - - def on_after_run_session(queue_item): - print("AFTER RUN SESSION", queue_item.item_id) - return True - - def on_non_fatal_processor_error(queue_item, error_type, error_message, error_traceback): - print("NON FATAL PROCESSOR ERROR", error_message) - return True - - session_processor = DefaultSessionProcessor( - DefaultSessionRunner( - on_before_run_session_callbacks=[on_before_run_session], - 
on_before_run_node_callbacks=[on_before_run_node], - on_after_run_node_callbacks=[on_after_run_node], - on_node_error_callbacks=[on_node_error], - on_after_run_session_callbacks=[on_after_run_session], - ), - on_non_fatal_processor_error_callbacks=[on_non_fatal_processor_error], - ) + session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner()) session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() workflow_records = SqliteWorkflowRecordsStorage(db=db) From 350feeed566f4635ee7fad7fa7a856ac53df88a7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 11:26:34 +1000 Subject: [PATCH 27/34] fix(processor): fix race condition related to clearing the queue --- .../session_processor/session_processor_default.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index eec835af87f..c87108e5a08 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -19,7 +19,7 @@ OnNonFatalProcessorError, ) from invokeai.app.services.session_processor.session_processor_common import CanceledException -from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem +from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError from invokeai.app.services.shared.graph import NodeInputError from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import Profiler @@ -166,8 +166,11 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: graph_execution_state_id=queue_item.session.id, output_path=stats_path ) - # Update the queue item with the completed 
session - self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) + try: + # Update the queue item with the completed session. If the queue item has been removed from the queue, + # we'll get a SessionQueueItemNotFoundError and we can ignore it. This can happen if the queue is cleared + # while the session is running. + queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) # TODO(psyche): This feels jumbled - we should review separation of concerns here. # Send complete event. The events service will receive this and update the queue item's status. @@ -186,6 +189,8 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: for callback in self._on_after_run_session_callbacks: callback(queue_item=queue_item) + except SessionQueueItemNotFoundError: + pass def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run before a node is executed""" @@ -349,6 +354,7 @@ async def _on_queue_event(self, event: FastAPIEvent) -> None: "failed", "canceled", ]: + self._cancel_event.set() self._poll_now() def resume(self) -> SessionProcessorStatus: From fb93e686b2f35219fa87bb89a99716a6c5e9fce6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 11:28:55 +1000 Subject: [PATCH 28/34] feat(processor): add debug log stmts to session running callbacks --- .../session_processor_default.py | 50 +++++++++++++------ 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index c87108e5a08..8cb92168216 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -148,6 +148,10 @@ def run_node(self, invocation: BaseInvocation, queue_item: 
SessionQueueItem): def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: """Run before a session is executed""" + self._services.logger.debug( + f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}" + ) + # If profiling is enabled, start the profiler if self._profiler is not None: self._profiler.start(profile_id=queue_item.session_id) @@ -158,6 +162,10 @@ def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: """Run after a session is executed""" + self._services.logger.debug( + f"On after run session: queue item {queue_item.item_id}, session {queue_item.session_id}" + ) + # If we are profiling, stop the profiler and dump the profile & stats if self._profiler is not None: profile_path = self._profiler.stop() @@ -172,29 +180,33 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: # while the session is running. queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) - # TODO(psyche): This feels jumbled - we should review separation of concerns here. - # Send complete event. The events service will receive this and update the queue item's status. - self._services.events.emit_graph_execution_complete( - queue_batch_id=queue_item.batch_id, - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - graph_execution_state_id=queue_item.session.id, - ) + # TODO(psyche): This feels jumbled - we should review separation of concerns here. + # Send complete event. The events service will receive this and update the queue item's status. 
+ self._services.events.emit_graph_execution_complete( + queue_batch_id=queue_item.batch_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session.id, + ) - # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor - # we don't care about that - suppress the error. - with suppress(GESStatsNotFoundError): - self._services.performance_statistics.log_stats(queue_item.session.id) - self._services.performance_statistics.reset_stats() + # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor + # we don't care about that - suppress the error. + with suppress(GESStatsNotFoundError): + self._services.performance_statistics.log_stats(queue_item.session.id) + self._services.performance_statistics.reset_stats() - for callback in self._on_after_run_session_callbacks: - callback(queue_item=queue_item) + for callback in self._on_after_run_session_callbacks: + callback(queue_item=queue_item) except SessionQueueItemNotFoundError: pass def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run before a node is executed""" + self._services.logger.debug( + f"On before run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})" + ) + # Send starting event self._services.events.emit_invocation_started( queue_batch_id=queue_item.batch_id, @@ -213,6 +225,10 @@ def _on_after_run_node( ): """Run after a node is executed""" + self._services.logger.debug( + f"On after run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})" + ) + # Send complete event on successful runs self._services.events.emit_invocation_complete( queue_batch_id=queue_item.batch_id, @@ -237,6 +253,10 @@ def _on_node_error( ): """Run when a node errors""" + self._services.logger.debug( + f"On node error: 
queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})" + ) + # Node errors do not get the full traceback. Only the queue item gets the full traceback. node_error = f"{error_type}: {error_message}" queue_item.session.set_node_error(invocation.id, node_error) From 0758e9cb9b3290e5af3831596bb512cea476d010 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 12:01:02 +1000 Subject: [PATCH 29/34] fix(ui): race condition with progress There's a race condition where a canceled session may emit a progress event or two after it's been canceled, and the progress image isn't cleared out. To resolve this, the system slice tracks canceled session ids. When a progress event comes in, we check the cancellations and skip setting the progress if canceled. --- .../web/src/features/system/store/systemSlice.ts | 11 ++++++++++- .../frontend/web/src/features/system/store/types.ts | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 65903460ed7..4d87b2c4ec9 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -31,6 +31,7 @@ const initialSystemState: SystemState = { shouldUseWatermarker: false, shouldEnableInformationalPopovers: false, status: 'DISCONNECTED', + cancellations: [], }; export const systemSlice = createSlice({ @@ -88,6 +89,7 @@ export const systemSlice = createSlice({ * Invocation Started */ builder.addCase(socketInvocationStarted, (state) => { + state.cancellations = []; state.denoiseProgress = null; state.status = 'PROCESSING'; }); @@ -105,6 +107,12 @@ export const systemSlice = createSlice({ queue_batch_id: batch_id, } = action.payload.data; + if (state.cancellations.includes(session_id)) { + // Do not update 
the progress if this session has been cancelled. This prevents a race condition where we get a + // progress update after the session has been cancelled. + return; + } + state.denoiseProgress = { step, total_steps, @@ -146,6 +154,7 @@ export const systemSlice = createSlice({ if (['completed', 'canceled', 'failed'].includes(action.payload.data.queue_item.status)) { state.status = 'CONNECTED'; state.denoiseProgress = null; + state.cancellations.push(action.payload.data.queue_item.session_id); } }); }, @@ -177,5 +186,5 @@ export const systemPersistConfig: PersistConfig = { name: systemSlice.name, initialState: initialSystemState, migrate: migrateSystemState, - persistDenylist: ['isConnected', 'denoiseProgress', 'status'], + persistDenylist: ['isConnected', 'denoiseProgress', 'status', 'cancellations'], }; diff --git a/invokeai/frontend/web/src/features/system/store/types.ts b/invokeai/frontend/web/src/features/system/store/types.ts index d8bc8cd08a9..d32e347870c 100644 --- a/invokeai/frontend/web/src/features/system/store/types.ts +++ b/invokeai/frontend/web/src/features/system/store/types.ts @@ -55,4 +55,5 @@ export interface SystemState { shouldUseWatermarker: boolean; status: SystemStatus; shouldEnableInformationalPopovers: boolean; + cancellations: string[] } From 08a42c3c03aadbc1bfd79ebc507d39e9c08f0c2c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 12:14:48 +1000 Subject: [PATCH 30/34] tidy(ui): remove extraneous condition in socketInvocationError --- .../listeners/socketio/socketInvocationError.ts | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts index 8cf79462c98..2d34ffdde6f 100644 --- 
a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts @@ -47,13 +47,11 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe nes.progress = null; nes.progressImage = null; - if (error_type && error_message && error_traceback) { - nes.error = { - error_type, - error_message, - error_traceback, - }; - } + nes.error = { + error_type, + error_message, + error_traceback, + }; upsertExecutionState(nes.nodeId, nes); } From dc78a0e69924aa6ba397e2c94929b72dc28cb7f3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 12:15:51 +1000 Subject: [PATCH 31/34] fix(ui): correctly fallback to error message when traceback is empty string --- .../src/features/queue/components/QueueList/QueueItemDetail.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemDetail.tsx b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemDetail.tsx index e3f2436aca9..d5b1e7dc59a 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemDetail.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemDetail.tsx @@ -89,7 +89,7 @@ const QueueItemComponent = ({ queueItemDTO }: Props) => { {t('common.error')} -
{queueItem?.error_traceback ?? queueItem?.error_message}
+
{queueItem?.error_traceback || queueItem?.error_message}
)} From 65e85d1162a228c92120f218d930f151884c1155 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 12:32:29 +1000 Subject: [PATCH 32/34] tidy: remove unnecessary whitespace changes --- invokeai/app/api/dependencies.py | 1 - .../listeners/socketio/socketInvocationError.ts | 1 - invokeai/frontend/web/src/features/system/store/types.ts | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index bb62158c6b0..aa16974a8f4 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -103,7 +103,6 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger ) names = SimpleNameService() performance_statistics = InvocationStatsService() - session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner()) session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts index 2d34ffdde6f..b09b57bd0c1 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts @@ -46,7 +46,6 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe nes.status = zNodeStatus.enum.FAILED; nes.progress = null; nes.progressImage = null; - nes.error = { error_type, error_message, diff --git a/invokeai/frontend/web/src/features/system/store/types.ts b/invokeai/frontend/web/src/features/system/store/types.ts index d32e347870c..e0fa5634a2b 100644 --- a/invokeai/frontend/web/src/features/system/store/types.ts +++ 
b/invokeai/frontend/web/src/features/system/store/types.ts @@ -55,5 +55,5 @@ export interface SystemState { shouldUseWatermarker: boolean; status: SystemStatus; shouldEnableInformationalPopovers: boolean; - cancellations: string[] + cancellations: string[]; } From 5edc825331644eedcbb87fab83388cf89f088e97 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 18:29:41 +1000 Subject: [PATCH 33/34] fix(processor): race condition that could result in node errors not getting reported I had set the cancel event at some point during troubleshooting an unrelated issue. It seemed logical that it should be set there, and didn't seem to break anything. However, this is not correct. The cancel event should not be set in response to a queue status change event. Doing so can cause a race condition when nodes are executed very quickly. It's possible that a previously-executed session's queue item status change event is handled after the next session starts executing. The cancel event is set and the session runner sees it aborting the session run early. In hindsight, it doesn't make sense to set the cancel event here either. It should be set in response to user action, e.g. the user cancelled the session or cleared the queue (which implicitly cancels the current session). These events actually trigger the queue item status changed event, so if we set the cancel event here, we'd be setting it twice per cancellation. 
--- .../app/services/session_processor/session_processor_default.py | 1 - 1 file changed, 1 deletion(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 8cb92168216..2207e71176f 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -374,7 +374,6 @@ async def _on_queue_event(self, event: FastAPIEvent) -> None: "failed", "canceled", ]: - self._cancel_event.set() self._poll_now() def resume(self) -> SessionProcessorStatus: From 5ee9ff722eb1e6f080ca11fd7158df35ecc9510f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 19:03:35 +1000 Subject: [PATCH 34/34] feat(ui): toast on queue item errors, improved error descriptions Show error toasts on queue item error events instead of invocation error events. This allows errors that occurred outside node execution to be surfaced to the user. The error description component is updated to show the new error message if available. Commercial handling is retained, but local now uses the same component to display the error message itself. 
--- .../socketio/socketInvocationError.ts | 45 +------------- ...ed.ts => socketQueueItemStatusChanged.tsx} | 26 +++++++- .../features/toast/ErrorToastDescription.tsx | 60 +++++++++++++++++++ .../toast/ToastWithSessionRefDescription.tsx | 30 ---------- 4 files changed, 86 insertions(+), 75 deletions(-) rename invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/{socketQueueItemStatusChanged.ts => socketQueueItemStatusChanged.tsx} (73%) create mode 100644 invokeai/frontend/web/src/features/toast/ErrorToastDescription.tsx delete mode 100644 invokeai/frontend/web/src/features/toast/ToastWithSessionRefDescription.tsx diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts index b09b57bd0c1..df1759f3a9a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts @@ -3,44 +3,16 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { deepClone } from 'common/util/deepClone'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; -import { toast } from 'features/toast/toast'; -import ToastWithSessionRefDescription from 'features/toast/ToastWithSessionRefDescription'; -import { t } from 'i18next'; -import { startCase } from 'lodash-es'; import { socketInvocationError } from 'services/events/actions'; const log = logger('socketio'); -const getTitle = (errorType: string) => { - if (errorType === 'OutOfMemoryError') { - return t('toast.outOfMemoryError'); - } - return t('toast.serverError'); -}; - -const getDescription = (errorType: string, sessionId: 
string, isLocal?: boolean) => { - if (!isLocal) { - if (errorType === 'OutOfMemoryError') { - return ToastWithSessionRefDescription({ - message: t('toast.outOfMemoryDescription'), - sessionId, - }); - } - return ToastWithSessionRefDescription({ - message: errorType, - sessionId, - }); - } - return errorType; -}; - export const addInvocationErrorEventListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: socketInvocationError, - effect: (action, { getState }) => { + effect: (action) => { log.error(action.payload, `Invocation error (${action.payload.data.node.type})`); - const { source_node_id, error_type, error_message, error_traceback, graph_execution_state_id } = - action.payload.data; + const { source_node_id, error_type, error_message, error_traceback } = action.payload.data; const nes = deepClone($nodeExecutionStates.get()[source_node_id]); if (nes) { nes.status = zNodeStatus.enum.FAILED; @@ -53,19 +25,6 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe }; upsertExecutionState(nes.nodeId, nes); } - - const errorType = startCase(error_type); - const sessionId = graph_execution_state_id; - const { isLocal } = getState().config; - - toast({ - id: `INVOCATION_ERROR_${errorType}`, - title: getTitle(errorType), - status: 'error', - duration: null, - description: getDescription(errorType, sessionId, isLocal), - updateDescription: isLocal ? 
true : false, - }); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.tsx similarity index 73% rename from invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts rename to invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.tsx index 3b274b28891..b72401d9155 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.tsx @@ -3,6 +3,8 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { deepClone } from 'common/util/deepClone'; import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; +import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription'; +import { toast } from 'features/toast/toast'; import { forEach } from 'lodash-es'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; import { socketQueueItemStatusChanged } from 'services/events/actions'; @@ -12,7 +14,7 @@ const log = logger('socketio'); export const addSocketQueueItemStatusChangedEventListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: socketQueueItemStatusChanged, - effect: async (action, { dispatch }) => { + effect: async (action, { dispatch, getState }) => { // we've got new status for the queue item, batch and queue const { queue_item, batch_status, queue_status } = action.payload.data; @@ -54,7 +56,7 @@ export const addSocketQueueItemStatusChangedEventListener = 
(startAppListening: ]) ); - if (['in_progress'].includes(action.payload.data.queue_item.status)) { + if (queue_item.status === 'in_progress') { forEach($nodeExecutionStates.get(), (nes) => { if (!nes) { return; @@ -67,6 +69,26 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening: clone.outputs = []; $nodeExecutionStates.setKey(clone.nodeId, clone); }); + } else if (queue_item.status === 'failed' && queue_item.error_type) { + const { error_type, error_message, session_id } = queue_item; + const isLocal = getState().config.isLocal ?? true; + const sessionId = session_id; + + toast({ + id: `INVOCATION_ERROR_${error_type}`, + title: getTitleFromErrorType(error_type), + status: 'error', + duration: null, + description: ( + + ), + updateDescription: isLocal ? true : false, + }); } }, }); diff --git a/invokeai/frontend/web/src/features/toast/ErrorToastDescription.tsx b/invokeai/frontend/web/src/features/toast/ErrorToastDescription.tsx new file mode 100644 index 00000000000..b9729c15103 --- /dev/null +++ b/invokeai/frontend/web/src/features/toast/ErrorToastDescription.tsx @@ -0,0 +1,60 @@ +import { Flex, IconButton, Text } from '@invoke-ai/ui-library'; +import { t } from 'i18next'; +import { upperFirst } from 'lodash-es'; +import { useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiCopyBold } from 'react-icons/pi'; + +function onCopy(sessionId: string) { + navigator.clipboard.writeText(sessionId); +} + +const ERROR_TYPE_TO_TITLE: Record = { + OutOfMemoryError: 'toast.outOfMemoryError', +}; + +const COMMERCIAL_ERROR_TYPE_TO_DESC: Record = { + OutOfMemoryError: 'toast.outOfMemoryErrorDesc', +}; + +export const getTitleFromErrorType = (errorType: string) => { + return t(ERROR_TYPE_TO_TITLE[errorType] ?? 
'toast.serverError'); +}; + +type Props = { errorType: string; errorMessage?: string | null; sessionId: string; isLocal: boolean }; + +export default function ErrorToastDescription({ errorType, errorMessage, sessionId, isLocal }: Props) { + const { t } = useTranslation(); + const description = useMemo(() => { + // Special handling for commercial error types + const descriptionTKey = isLocal ? null : COMMERCIAL_ERROR_TYPE_TO_DESC[errorType]; + if (descriptionTKey) { + return t(descriptionTKey); + } + if (errorMessage) { + return upperFirst(errorMessage); + } + }, [errorMessage, errorType, isLocal, t]); + return ( + + {description && {description}} + {!isLocal && ( + + + {t('toast.sessionRef', { sessionId })} + + } + onClick={onCopy.bind(null, sessionId)} + variant="ghost" + sx={sx} + /> + + )} + + ); +} + +const sx = { svg: { fill: 'base.50' } }; diff --git a/invokeai/frontend/web/src/features/toast/ToastWithSessionRefDescription.tsx b/invokeai/frontend/web/src/features/toast/ToastWithSessionRefDescription.tsx deleted file mode 100644 index 9d2999e765c..00000000000 --- a/invokeai/frontend/web/src/features/toast/ToastWithSessionRefDescription.tsx +++ /dev/null @@ -1,30 +0,0 @@ -import { Flex, IconButton, Text } from '@invoke-ai/ui-library'; -import { t } from 'i18next'; -import { PiCopyBold } from 'react-icons/pi'; - -function onCopy(sessionId: string) { - navigator.clipboard.writeText(sessionId); -} - -type Props = { message: string; sessionId: string }; - -export default function ToastWithSessionRefDescription({ message, sessionId }: Props) { - return ( - - {message} - - {t('toast.sessionRef', { sessionId })} - } - onClick={onCopy.bind(null, sessionId)} - variant="ghost" - sx={sx} - /> - - - ); -} - -const sx = { svg: { fill: 'base.50' } };