Skip to content

Commit c2a3c66

Browse files
feat(app): avoid nested cursors in workflow_records service
1 parent c0a0d20 commit c2a3c66

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

invokeai/app/services/workflow_records/workflow_records_sqlite.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
2424
def __init__(self, db: SqliteDatabase) -> None:
2525
super().__init__()
2626
self._conn = db.conn
27-
self._cursor = self._conn.cursor()
2827

2928
def start(self, invoker: Invoker) -> None:
3029
self._invoker = invoker
3130
self._sync_default_workflows()
3231

3332
def get(self, workflow_id: str) -> WorkflowRecordDTO:
3433
"""Gets a workflow by ID. Updates the opened_at column."""
35-
self._cursor.execute(
34+
cursor = self._conn.cursor()
35+
cursor.execute(
3636
"""--sql
3737
UPDATE workflow_library
3838
SET opened_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
@@ -41,15 +41,15 @@ def get(self, workflow_id: str) -> WorkflowRecordDTO:
4141
(workflow_id,),
4242
)
4343
self._conn.commit()
44-
self._cursor.execute(
44+
cursor.execute(
4545
"""--sql
4646
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
4747
FROM workflow_library
4848
WHERE workflow_id = ?;
4949
""",
5050
(workflow_id,),
5151
)
52-
row = self._cursor.fetchone()
52+
row = cursor.fetchone()
5353
if row is None:
5454
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
5555
return WorkflowRecordDTO.from_dict(dict(row))
@@ -59,7 +59,8 @@ def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
5959
# Only user workflows may be created by this method
6060
assert workflow.meta.category is WorkflowCategory.User
6161
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
62-
self._cursor.execute(
62+
cursor = self._conn.cursor()
63+
cursor.execute(
6364
"""--sql
6465
INSERT OR IGNORE INTO workflow_library (
6566
workflow_id,
@@ -77,7 +78,8 @@ def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
7778

7879
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
7980
try:
80-
self._cursor.execute(
81+
cursor = self._conn.cursor()
82+
cursor.execute(
8183
"""--sql
8284
UPDATE workflow_library
8385
SET workflow = ?
@@ -93,7 +95,8 @@ def update(self, workflow: Workflow) -> WorkflowRecordDTO:
9395

9496
def delete(self, workflow_id: str) -> None:
9597
try:
96-
self._cursor.execute(
98+
cursor = self._conn.cursor()
99+
cursor.execute(
97100
"""--sql
98101
DELETE from workflow_library
99102
WHERE workflow_id = ? AND category = 'user';
@@ -149,12 +152,13 @@ def get_many(
149152
main_query += " LIMIT ? OFFSET ?"
150153
main_params.extend([per_page, page * per_page])
151154

152-
self._cursor.execute(main_query, main_params)
153-
rows = self._cursor.fetchall()
155+
cursor = self._conn.cursor()
156+
cursor.execute(main_query, main_params)
157+
rows = cursor.fetchall()
154158
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
155159

156-
self._cursor.execute(count_query, count_params)
157-
total = self._cursor.fetchone()[0]
160+
cursor.execute(count_query, count_params)
161+
total = cursor.fetchone()[0]
158162

159163
if per_page:
160164
pages = total // per_page + (total % per_page > 0)
@@ -193,14 +197,15 @@ def _sync_default_workflows(self) -> None:
193197
workflows.append(workflow)
194198
# Only default workflows may be managed by this method
195199
assert all(w.meta.category is WorkflowCategory.Default for w in workflows)
196-
self._cursor.execute(
200+
cursor = self._conn.cursor()
201+
cursor.execute(
197202
"""--sql
198203
DELETE FROM workflow_library
199204
WHERE category = 'default';
200205
"""
201206
)
202207
for w in workflows:
203-
self._cursor.execute(
208+
cursor.execute(
204209
"""--sql
205210
INSERT OR REPLACE INTO workflow_library (
206211
workflow_id,

0 commit comments

Comments
 (0)