Skip to content

Commit f7c356d

Browse files
feat(app): iterate on processor split 2
- Use protocol to define callbacks, this allows them to have kwargs - Shuffle the profiler around a bit - Move `thread_limit` and `polling_interval` to `__init__`; `start` is called programmatically and will never get these args in practice
1 parent efb069d commit f7c356d

File tree

3 files changed

+80
-37
lines changed

3 files changed

+80
-37
lines changed

invokeai/app/api/dependencies.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def on_before_run_node(invocation, queue_item):
112112
print("BEFORE RUN NODE", invocation.id)
113113
return True
114114

115-
def on_after_run_node(invocation, queue_item, outputs):
115+
def on_after_run_node(invocation, queue_item, output):
116116
print("AFTER RUN NODE", invocation.id)
117117
return True
118118

@@ -124,17 +124,17 @@ def on_after_run_session(queue_item):
124124
print("AFTER RUN SESSION", queue_item.item_id)
125125
return True
126126

127-
def on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback):
127+
def on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item=None):
128128
print("NON FATAL PROCESSOR ERROR", exc_value)
129129
return True
130130

131131
session_processor = DefaultSessionProcessor(
132132
DefaultSessionRunner(
133-
on_before_run_session,
134-
on_before_run_node,
135-
on_after_run_node,
136-
on_node_error,
137-
on_after_run_session,
133+
on_before_run_session=on_before_run_session,
134+
on_before_run_node=on_before_run_node,
135+
on_after_run_node=on_after_run_node,
136+
on_node_error=on_node_error,
137+
on_after_run_session=on_after_run_session,
138138
),
139139
on_non_fatal_processor_error,
140140
)

invokeai/app/services/session_processor/session_processor_base.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from abc import ABC, abstractmethod
22
from threading import Event
3+
from types import TracebackType
4+
from typing import Optional, Protocol
35

4-
from invokeai.app.invocations.baseinvocation import BaseInvocation
6+
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
57
from invokeai.app.services.invocation_services import InvocationServices
68
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
79
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
10+
from invokeai.app.util.profiler import Profiler
811

912

1013
class SessionRunnerBase(ABC):
@@ -13,7 +16,7 @@ class SessionRunnerBase(ABC):
1316
"""
1417

1518
@abstractmethod
16-
def start(self, services: InvocationServices, cancel_event: Event) -> None:
19+
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
1720
"""Starts the session runner"""
1821
pass
1922

@@ -51,3 +54,42 @@ def pause(self) -> SessionProcessorStatus:
5154
def get_status(self) -> SessionProcessorStatus:
5255
"""Gets the status of the session processor"""
5356
pass
57+
58+
59+
class OnBeforeRunNode(Protocol):
60+
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool: ...
61+
62+
63+
class OnAfterRunNode(Protocol):
64+
def __call__(
65+
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
66+
) -> bool: ...
67+
68+
69+
class OnNodeError(Protocol):
70+
def __call__(
71+
self,
72+
invocation: BaseInvocation,
73+
queue_item: SessionQueueItem,
74+
exc_type: type,
75+
exc_value: BaseException,
76+
exc_traceback: TracebackType,
77+
) -> bool: ...
78+
79+
80+
class OnBeforeRunSession(Protocol):
81+
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
82+
83+
84+
class OnAfterRunSession(Protocol):
85+
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
86+
87+
88+
class OnNonFatalProcessorError(Protocol):
89+
def __call__(
90+
self,
91+
exc_type: type,
92+
exc_value: BaseException,
93+
exc_traceback: TracebackType,
94+
queue_item: Optional[SessionQueueItem] = None,
95+
) -> bool: ...

invokeai/app/services/session_processor/session_processor_default.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,22 @@
33
from threading import BoundedSemaphore, Thread
44
from threading import Event as ThreadEvent
55
from types import TracebackType
6-
from typing import Callable, Optional, TypeAlias
6+
from typing import Optional
77

88
from fastapi_events.handlers.local import local_handler
99
from fastapi_events.typing import Event as FastAPIEvent
1010

1111
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
1212
from invokeai.app.services.events.events_base import EventServiceBase
1313
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
14+
from invokeai.app.services.session_processor.session_processor_base import (
15+
OnAfterRunNode,
16+
OnAfterRunSession,
17+
OnBeforeRunNode,
18+
OnBeforeRunSession,
19+
OnNodeError,
20+
OnNonFatalProcessorError,
21+
)
1422
from invokeai.app.services.session_processor.session_processor_common import CanceledException
1523
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
1624
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
@@ -20,13 +28,6 @@
2028
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
2129
from .session_processor_common import SessionProcessorStatus
2230

23-
OnBeforeRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem], bool]
24-
OnAfterRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, BaseInvocationOutput], bool]
25-
OnNodeError: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, type, BaseException, TracebackType], bool]
26-
OnBeforeRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
27-
OnAfterRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
28-
OnNonFatalProcessorError: TypeAlias = Callable[[Optional[SessionQueueItem], type, BaseException, TracebackType], bool]
29-
3031

3132
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
3233
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
@@ -49,37 +50,40 @@ def __init__(
4950
self.on_node_error = on_node_error
5051
self.on_after_run_session = on_after_run_session
5152

52-
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
53+
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
5354
"""Start the session runner"""
5455
self.services = services
5556
self.cancel_event = cancel_event
57+
self.profiler = profiler
5658

57-
def run(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None):
59+
def run(self, queue_item: SessionQueueItem):
5860
"""Run the graph"""
5961
# Loop over invocations until the session is complete or canceled
6062

6163
self._on_before_run_session(queue_item=queue_item)
64+
6265
while True:
6366
invocation = queue_item.session.next()
6467
if invocation is None or self.cancel_event.is_set():
6568
break
6669
self.run_node(invocation, queue_item)
6770
if queue_item.session.is_complete() or self.cancel_event.is_set():
6871
break
72+
6973
self._on_after_run_session(queue_item=queue_item)
7074

71-
def _on_before_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
75+
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
7276
# If profiling is enabled, start the profiler
73-
if profiler is not None:
74-
profiler.start(profile_id=queue_item.session_id)
77+
if self.profiler is not None:
78+
self.profiler.start(profile_id=queue_item.session_id)
7579

7680
if self.on_before_run_session:
77-
self.on_before_run_session(queue_item)
81+
self.on_before_run_session(queue_item=queue_item)
7882

79-
def _on_after_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
83+
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
8084
# If we are profiling, stop the profiler and dump the profile & stats
81-
if profiler:
82-
profile_path = profiler.stop()
85+
if self.profiler is not None:
86+
profile_path = self.profiler.stop()
8387
stats_path = profile_path.with_suffix(".json")
8488
self.services.performance_statistics.dump_stats(
8589
graph_execution_state_id=queue_item.session.id, output_path=stats_path
@@ -221,11 +225,15 @@ def __init__(
221225
self,
222226
session_runner: Optional[SessionRunnerBase] = None,
223227
on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
228+
thread_limit: int = 1,
229+
polling_interval: int = 1,
224230
) -> None:
225231
super().__init__()
226232

227233
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
228234
self.on_non_fatal_processor_error = on_non_fatal_processor_error
235+
self._thread_limit = thread_limit
236+
self._polling_interval = polling_interval
229237

230238
def _on_non_fatal_processor_error(
231239
self,
@@ -243,14 +251,9 @@ def _on_non_fatal_processor_error(
243251
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
244252

245253
if self.on_non_fatal_processor_error:
246-
self.on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback)
254+
self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item)
247255

248-
def start(
249-
self,
250-
invoker: Invoker,
251-
thread_limit: int = 1,
252-
polling_interval: int = 1,
253-
) -> None:
256+
def start(self, invoker: Invoker) -> None:
254257
self._invoker: Invoker = invoker
255258
self._queue_item: Optional[SessionQueueItem] = None
256259
self._invocation: Optional[BaseInvocation] = None
@@ -262,9 +265,7 @@ def start(
262265

263266
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
264267

265-
self._thread_limit = thread_limit
266-
self._thread_semaphore = BoundedSemaphore(thread_limit)
267-
self._polling_interval = polling_interval
268+
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
268269

269270
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
270271
# the profiler will create a new profile for each session.
@@ -278,7 +279,7 @@ def start(
278279
else None
279280
)
280281

281-
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
282+
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
282283
self._thread = Thread(
283284
name="session_processor",
284285
target=self._process,

0 commit comments

Comments
 (0)