Skip to content

Commit c117ffd

Browse files
feat(app): make things in session runner private
1 parent 06334c0 commit c117ffd

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

invokeai/app/services/session_processor/session_processor_default.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def __init__(
5252

5353
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
5454
"""Start the session runner"""
55-
self.services = services
56-
self.cancel_event = cancel_event
57-
self.profiler = profiler
55+
self._services = services
56+
self._cancel_event = cancel_event
57+
self._profiler = profiler
5858

5959
def run(self, queue_item: SessionQueueItem):
6060
"""Run the graph"""
@@ -64,33 +64,33 @@ def run(self, queue_item: SessionQueueItem):
6464

6565
while True:
6666
invocation = queue_item.session.next()
67-
if invocation is None or self.cancel_event.is_set():
67+
if invocation is None or self._cancel_event.is_set():
6868
break
6969
self.run_node(invocation, queue_item)
70-
if queue_item.session.is_complete() or self.cancel_event.is_set():
70+
if queue_item.session.is_complete() or self._cancel_event.is_set():
7171
break
7272

7373
self._on_after_run_session(queue_item=queue_item)
7474

7575
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
7676
# If profiling is enabled, start the profiler
77-
if self.profiler is not None:
78-
self.profiler.start(profile_id=queue_item.session_id)
77+
if self._profiler is not None:
78+
self._profiler.start(profile_id=queue_item.session_id)
7979

8080
if self.on_before_run_session:
8181
self.on_before_run_session(queue_item=queue_item)
8282

8383
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
8484
# If we are profiling, stop the profiler and dump the profile & stats
85-
if self.profiler is not None:
86-
profile_path = self.profiler.stop()
85+
if self._profiler is not None:
86+
profile_path = self._profiler.stop()
8787
stats_path = profile_path.with_suffix(".json")
88-
self.services.performance_statistics.dump_stats(
88+
self._services.performance_statistics.dump_stats(
8989
graph_execution_state_id=queue_item.session.id, output_path=stats_path
9090
)
9191

9292
# Send complete event
93-
self.services.events.emit_graph_execution_complete(
93+
self._services.events.emit_graph_execution_complete(
9494
queue_batch_id=queue_item.batch_id,
9595
queue_item_id=queue_item.item_id,
9696
queue_id=queue_item.queue_id,
@@ -100,16 +100,16 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
100100
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
101101
# we don't care about that - suppress the error.
102102
with suppress(GESStatsNotFoundError):
103-
self.services.performance_statistics.log_stats(queue_item.session.id)
104-
self.services.performance_statistics.reset_stats()
103+
self._services.performance_statistics.log_stats(queue_item.session.id)
104+
self._services.performance_statistics.reset_stats()
105105

106106
if self.on_after_run_session:
107107
self.on_after_run_session(queue_item)
108108

109109
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
110110
"""Run before a node is executed"""
111111
# Send starting event
112-
self.services.events.emit_invocation_started(
112+
self._services.events.emit_invocation_started(
113113
queue_batch_id=queue_item.batch_id,
114114
queue_item_id=queue_item.item_id,
115115
queue_id=queue_item.queue_id,
@@ -126,7 +126,7 @@ def _on_after_run_node(
126126
):
127127
"""Run after a node is executed"""
128128
# Send complete event on successful runs
129-
self.services.events.emit_invocation_complete(
129+
self._services.events.emit_invocation_complete(
130130
queue_batch_id=queue_item.batch_id,
131131
queue_item_id=queue_item.item_id,
132132
queue_id=queue_item.queue_id,
@@ -150,13 +150,13 @@ def _on_node_error(
150150
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
151151

152152
queue_item.session.set_node_error(invocation.id, stacktrace)
153-
self.services.logger.error(
153+
self._services.logger.error(
154154
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_type}"
155155
)
156-
self.services.logger.error(stacktrace)
156+
self._services.logger.error(stacktrace)
157157

158158
# Send error event
159-
self.services.events.emit_invocation_error(
159+
self._services.events.emit_invocation_error(
160160
queue_batch_id=queue_item.session_id,
161161
queue_item_id=queue_item.item_id,
162162
queue_id=queue_item.queue_id,
@@ -176,7 +176,7 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
176176
"""Run a single node in the graph"""
177177
try:
178178
# Any unhandled exception is an invocation error & will fail the graph
179-
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
179+
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
180180
self._on_before_run_node(invocation, queue_item)
181181

182182
data = InvocationContextData(
@@ -186,12 +186,12 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
186186
)
187187
context = build_invocation_context(
188188
data=data,
189-
services=self.services,
190-
cancel_event=self.cancel_event,
189+
services=self._services,
190+
cancel_event=self._cancel_event,
191191
)
192192

193193
# Invoke the node
194-
outputs = invocation.invoke_internal(context=context, services=self.services)
194+
outputs = invocation.invoke_internal(context=context, services=self._services)
195195
# Save outputs and history
196196
queue_item.session.complete(invocation.id, outputs)
197197

0 commit comments

Comments
 (0)