3
3
from threading import BoundedSemaphore , Thread
4
4
from threading import Event as ThreadEvent
5
5
from types import TracebackType
6
- from typing import Callable , Optional , TypeAlias
6
+ from typing import Optional
7
7
8
8
from fastapi_events .handlers .local import local_handler
9
9
from fastapi_events .typing import Event as FastAPIEvent
10
10
11
11
from invokeai .app .invocations .baseinvocation import BaseInvocation , BaseInvocationOutput
12
12
from invokeai .app .services .events .events_base import EventServiceBase
13
13
from invokeai .app .services .invocation_stats .invocation_stats_common import GESStatsNotFoundError
14
+ from invokeai .app .services .session_processor .session_processor_base import (
15
+ OnAfterRunNode ,
16
+ OnAfterRunSession ,
17
+ OnBeforeRunNode ,
18
+ OnBeforeRunSession ,
19
+ OnNodeError ,
20
+ OnNonFatalProcessorError ,
21
+ )
14
22
from invokeai .app .services .session_processor .session_processor_common import CanceledException
15
23
from invokeai .app .services .session_queue .session_queue_common import SessionQueueItem
16
24
from invokeai .app .services .shared .invocation_context import InvocationContextData , build_invocation_context
20
28
from .session_processor_base import InvocationServices , SessionProcessorBase , SessionRunnerBase
21
29
from .session_processor_common import SessionProcessorStatus
22
30
23
- OnBeforeRunNode : TypeAlias = Callable [[BaseInvocation , SessionQueueItem ], bool ]
24
- OnAfterRunNode : TypeAlias = Callable [[BaseInvocation , SessionQueueItem , BaseInvocationOutput ], bool ]
25
- OnNodeError : TypeAlias = Callable [[BaseInvocation , SessionQueueItem , type , BaseException , TracebackType ], bool ]
26
- OnBeforeRunSession : TypeAlias = Callable [[SessionQueueItem ], bool ]
27
- OnAfterRunSession : TypeAlias = Callable [[SessionQueueItem ], bool ]
28
- OnNonFatalProcessorError : TypeAlias = Callable [[Optional [SessionQueueItem ], type , BaseException , TracebackType ], bool ]
29
-
30
31
31
32
def get_stacktrace (exc_type : type , exc_value : BaseException , exc_traceback : TracebackType ) -> str :
32
33
return "" .join (traceback .format_exception (exc_type , exc_value , exc_traceback ))
@@ -49,37 +50,40 @@ def __init__(
49
50
self .on_node_error = on_node_error
50
51
self .on_after_run_session = on_after_run_session
51
52
52
- def start (self , services : InvocationServices , cancel_event : ThreadEvent ):
53
+ def start (self , services : InvocationServices , cancel_event : ThreadEvent , profiler : Optional [ Profiler ] = None ):
53
54
"""Start the session runner"""
54
55
self .services = services
55
56
self .cancel_event = cancel_event
57
+ self .profiler = profiler
56
58
57
- def run (self , queue_item : SessionQueueItem , profiler : Optional [ Profiler ] = None ):
59
+ def run (self , queue_item : SessionQueueItem ):
58
60
"""Run the graph"""
59
61
# Loop over invocations until the session is complete or canceled
60
62
61
63
self ._on_before_run_session (queue_item = queue_item )
64
+
62
65
while True :
63
66
invocation = queue_item .session .next ()
64
67
if invocation is None or self .cancel_event .is_set ():
65
68
break
66
69
self .run_node (invocation , queue_item )
67
70
if queue_item .session .is_complete () or self .cancel_event .is_set ():
68
71
break
72
+
69
73
self ._on_after_run_session (queue_item = queue_item )
70
74
71
- def _on_before_run_session (self , queue_item : SessionQueueItem , profiler : Optional [ Profiler ] = None ) -> None :
75
+ def _on_before_run_session (self , queue_item : SessionQueueItem ) -> None :
72
76
# If profiling is enabled, start the profiler
73
- if profiler is not None :
74
- profiler .start (profile_id = queue_item .session_id )
77
+ if self . profiler is not None :
78
+ self . profiler .start (profile_id = queue_item .session_id )
75
79
76
80
if self .on_before_run_session :
77
- self .on_before_run_session (queue_item )
81
+ self .on_before_run_session (queue_item = queue_item )
78
82
79
- def _on_after_run_session (self , queue_item : SessionQueueItem , profiler : Optional [ Profiler ] = None ) -> None :
83
+ def _on_after_run_session (self , queue_item : SessionQueueItem ) -> None :
80
84
# If we are profiling, stop the profiler and dump the profile & stats
81
- if profiler :
82
- profile_path = profiler .stop ()
85
+ if self . profiler is not None :
86
+ profile_path = self . profiler .stop ()
83
87
stats_path = profile_path .with_suffix (".json" )
84
88
self .services .performance_statistics .dump_stats (
85
89
graph_execution_state_id = queue_item .session .id , output_path = stats_path
@@ -221,11 +225,15 @@ def __init__(
221
225
self ,
222
226
session_runner : Optional [SessionRunnerBase ] = None ,
223
227
on_non_fatal_processor_error : Optional [OnNonFatalProcessorError ] = None ,
228
+ thread_limit : int = 1 ,
229
+ polling_interval : int = 1 ,
224
230
) -> None :
225
231
super ().__init__ ()
226
232
227
233
self .session_runner = session_runner if session_runner else DefaultSessionRunner ()
228
234
self .on_non_fatal_processor_error = on_non_fatal_processor_error
235
+ self ._thread_limit = thread_limit
236
+ self ._polling_interval = polling_interval
229
237
230
238
def _on_non_fatal_processor_error (
231
239
self ,
@@ -243,14 +251,9 @@ def _on_non_fatal_processor_error(
243
251
self ._invoker .services .session_queue .cancel_queue_item (queue_item .item_id , error = stacktrace )
244
252
245
253
if self .on_non_fatal_processor_error :
246
- self .on_non_fatal_processor_error (queue_item , exc_type , exc_value , exc_traceback )
254
+ self .on_non_fatal_processor_error (exc_type , exc_value , exc_traceback , queue_item )
247
255
248
- def start (
249
- self ,
250
- invoker : Invoker ,
251
- thread_limit : int = 1 ,
252
- polling_interval : int = 1 ,
253
- ) -> None :
256
+ def start (self , invoker : Invoker ) -> None :
254
257
self ._invoker : Invoker = invoker
255
258
self ._queue_item : Optional [SessionQueueItem ] = None
256
259
self ._invocation : Optional [BaseInvocation ] = None
@@ -262,9 +265,7 @@ def start(
262
265
263
266
local_handler .register (event_name = EventServiceBase .queue_event , _func = self ._on_queue_event )
264
267
265
- self ._thread_limit = thread_limit
266
- self ._thread_semaphore = BoundedSemaphore (thread_limit )
267
- self ._polling_interval = polling_interval
268
+ self ._thread_semaphore = BoundedSemaphore (self ._thread_limit )
268
269
269
270
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
270
271
# the profiler will create a new profile for each session.
@@ -278,7 +279,7 @@ def start(
278
279
else None
279
280
)
280
281
281
- self .session_runner .start (services = invoker .services , cancel_event = self ._cancel_event )
282
+ self .session_runner .start (services = invoker .services , cancel_event = self ._cancel_event , profiler = self . _profiler )
282
283
self ._thread = Thread (
283
284
name = "session_processor" ,
284
285
target = self ._process ,
0 commit comments