2
2
from contextlib import suppress
3
3
from threading import BoundedSemaphore , Thread
4
4
from threading import Event as ThreadEvent
5
- from typing import Optional , Union , Callable
5
+ from typing import Callable , Optional , Union
6
6
7
7
from fastapi_events .handlers .local import local_handler
8
8
from fastapi_events .typing import Event as FastAPIEvent
16
16
from invokeai .app .util .profiler import Profiler
17
17
18
18
from ..invoker import Invoker
19
- from .session_processor_base import SessionProcessorBase , SessionRunnerBase , InvocationServices
19
+ from .session_processor_base import InvocationServices , SessionProcessorBase , SessionRunnerBase
20
20
from .session_processor_common import SessionProcessorStatus
21
21
22
22
@@ -36,7 +36,9 @@ def start(self, services: InvocationServices, cancel_event: ThreadEvent):
36
36
self .services = services
37
37
self .cancel_event = cancel_event
38
38
39
- def next_invocation (self , previous_invocation : Optional [BaseInvocation ], queue_item : SessionQueueItem , cancel_event : ThreadEvent ) -> Optional [BaseInvocation ]:
39
+ def next_invocation (
40
+ self , previous_invocation : Optional [BaseInvocation ], queue_item : SessionQueueItem , cancel_event : ThreadEvent
41
+ ) -> Optional [BaseInvocation ]:
40
42
invocation = None
41
43
if not (queue_item .session .is_complete () or cancel_event .is_set ()):
42
44
try :
@@ -57,7 +59,9 @@ def next_invocation(self, previous_invocation: Optional[BaseInvocation], queue_i
57
59
queue_id = queue_item .queue_id ,
58
60
graph_execution_state_id = queue_item .session .id ,
59
61
node = previous_invocation .model_dump () if previous_invocation else {},
60
- source_node_id = queue_item .session .prepared_source_mapping [previous_invocation .id ] if previous_invocation else "" ,
62
+ source_node_id = queue_item .session .prepared_source_mapping [previous_invocation .id ]
63
+ if previous_invocation
64
+ else "" ,
61
65
error_type = exc .__class__ .__name__ ,
62
66
error = node_error ,
63
67
)
@@ -67,7 +71,6 @@ def next_invocation(self, previous_invocation: Optional[BaseInvocation], queue_i
67
71
invocation = None
68
72
return invocation
69
73
70
-
71
74
def run (self , queue_item : SessionQueueItem ):
72
75
"""Run the graph"""
73
76
if not queue_item .session :
@@ -192,12 +195,11 @@ def run_node(self, node_id: str, queue_item: SessionQueueItem):
192
195
)
193
196
194
197
195
-
196
198
class DefaultSessionProcessor (SessionProcessorBase ):
197
-
198
199
def __init__ (self , session_runner : Union [SessionRunnerBase , None ] = None ) -> None :
199
200
super ().__init__ ()
200
201
self .session_runner = session_runner if session_runner else DefaultSessionRunner ()
202
+
201
203
def start (
202
204
self ,
203
205
invoker : Invoker ,
0 commit comments