Skip to content

Commit 657095d

Browse files
feat(app): avoid nested cursors in image_records service
1 parent 1c47dc9 commit 657095d

File tree

1 file changed

+36
-28
lines changed

1 file changed

+36
-28
lines changed

invokeai/app/services/image_records/image_records_sqlite.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,22 @@
2121

2222

2323
class SqliteImageRecordStorage(ImageRecordStorageBase):
24-
_conn: sqlite3.Connection
25-
_cursor: sqlite3.Cursor
26-
2724
def __init__(self, db: SqliteDatabase) -> None:
2825
super().__init__()
2926
self._conn = db.conn
30-
self._cursor = self._conn.cursor()
3127

3228
def get(self, image_name: str) -> ImageRecord:
3329
try:
34-
self._cursor.execute(
30+
cursor = self._conn.cursor()
31+
cursor.execute(
3532
f"""--sql
3633
SELECT {IMAGE_DTO_COLS} FROM images
3734
WHERE image_name = ?;
3835
""",
3936
(image_name,),
4037
)
4138

42-
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
39+
result = cast(Optional[sqlite3.Row], cursor.fetchone())
4340
except sqlite3.Error as e:
4441
raise ImageRecordNotFoundException from e
4542

@@ -50,15 +47,16 @@ def get(self, image_name: str) -> ImageRecord:
5047

5148
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
5249
try:
53-
self._cursor.execute(
50+
cursor = self._conn.cursor()
51+
cursor.execute(
5452
"""--sql
5553
SELECT metadata FROM images
5654
WHERE image_name = ?;
5755
""",
5856
(image_name,),
5957
)
6058

61-
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
59+
result = cast(Optional[sqlite3.Row], cursor.fetchone())
6260

6361
if not result:
6462
raise ImageRecordNotFoundException
@@ -75,9 +73,10 @@ def update(
7573
changes: ImageRecordChanges,
7674
) -> None:
7775
try:
76+
cursor = self._conn.cursor()
7877
# Change the category of the image
7978
if changes.image_category is not None:
80-
self._cursor.execute(
79+
cursor.execute(
8180
"""--sql
8281
UPDATE images
8382
SET image_category = ?
@@ -88,7 +87,7 @@ def update(
8887

8988
# Change the session associated with the image
9089
if changes.session_id is not None:
91-
self._cursor.execute(
90+
cursor.execute(
9291
"""--sql
9392
UPDATE images
9493
SET session_id = ?
@@ -99,7 +98,7 @@ def update(
9998

10099
# Change the image's `is_intermediate`` flag
101100
if changes.is_intermediate is not None:
102-
self._cursor.execute(
101+
cursor.execute(
103102
"""--sql
104103
UPDATE images
105104
SET is_intermediate = ?
@@ -110,7 +109,7 @@ def update(
110109

111110
# Change the image's `starred`` state
112111
if changes.starred is not None:
113-
self._cursor.execute(
112+
cursor.execute(
114113
"""--sql
115114
UPDATE images
116115
SET starred = ?
@@ -136,6 +135,8 @@ def get_many(
136135
board_id: Optional[str] = None,
137136
search_term: Optional[str] = None,
138137
) -> OffsetPaginatedResults[ImageRecord]:
138+
cursor = self._conn.cursor()
139+
139140
# Manually build two queries - one for the count, one for the records
140141
count_query = """--sql
141142
SELECT COUNT(*)
@@ -216,21 +217,22 @@ def get_many(
216217
images_params.extend([limit, offset])
217218

218219
# Build the list of images, deserializing each row
219-
self._cursor.execute(images_query, images_params)
220-
result = cast(list[sqlite3.Row], self._cursor.fetchall())
220+
cursor.execute(images_query, images_params)
221+
result = cast(list[sqlite3.Row], cursor.fetchall())
221222
images = [deserialize_image_record(dict(r)) for r in result]
222223

223224
# Set up and execute the count query, without pagination
224225
count_query += query_conditions + ";"
225226
count_params = query_params.copy()
226-
self._cursor.execute(count_query, count_params)
227-
count = cast(int, self._cursor.fetchone()[0])
227+
cursor.execute(count_query, count_params)
228+
count = cast(int, cursor.fetchone()[0])
228229

229230
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
230231

231232
def delete(self, image_name: str) -> None:
232233
try:
233-
self._cursor.execute(
234+
cursor = self._conn.cursor()
235+
cursor.execute(
234236
"""--sql
235237
DELETE FROM images
236238
WHERE image_name = ?;
@@ -244,41 +246,45 @@ def delete(self, image_name: str) -> None:
244246

245247
def delete_many(self, image_names: list[str]) -> None:
246248
try:
249+
cursor = self._conn.cursor()
250+
247251
placeholders = ",".join("?" for _ in image_names)
248252

249253
# Construct the SQLite query with the placeholders
250254
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
251255

252256
# Execute the query with the list of IDs as parameters
253-
self._cursor.execute(query, image_names)
257+
cursor.execute(query, image_names)
254258

255259
self._conn.commit()
256260
except sqlite3.Error as e:
257261
self._conn.rollback()
258262
raise ImageRecordDeleteException from e
259263

260264
def get_intermediates_count(self) -> int:
261-
self._cursor.execute(
265+
cursor = self._conn.cursor()
266+
cursor.execute(
262267
"""--sql
263268
SELECT COUNT(*) FROM images
264269
WHERE is_intermediate = TRUE;
265270
"""
266271
)
267-
count = cast(int, self._cursor.fetchone()[0])
272+
count = cast(int, cursor.fetchone()[0])
268273
self._conn.commit()
269274
return count
270275

271276
def delete_intermediates(self) -> list[str]:
272277
try:
273-
self._cursor.execute(
278+
cursor = self._conn.cursor()
279+
cursor.execute(
274280
"""--sql
275281
SELECT image_name FROM images
276282
WHERE is_intermediate = TRUE;
277283
"""
278284
)
279-
result = cast(list[sqlite3.Row], self._cursor.fetchall())
285+
result = cast(list[sqlite3.Row], cursor.fetchall())
280286
image_names = [r[0] for r in result]
281-
self._cursor.execute(
287+
cursor.execute(
282288
"""--sql
283289
DELETE FROM images
284290
WHERE is_intermediate = TRUE;
@@ -305,7 +311,8 @@ def save(
305311
metadata: Optional[str] = None,
306312
) -> datetime:
307313
try:
308-
self._cursor.execute(
314+
cursor = self._conn.cursor()
315+
cursor.execute(
309316
"""--sql
310317
INSERT OR IGNORE INTO images (
311318
image_name,
@@ -338,7 +345,7 @@ def save(
338345
)
339346
self._conn.commit()
340347

341-
self._cursor.execute(
348+
cursor.execute(
342349
"""--sql
343350
SELECT created_at
344351
FROM images
@@ -347,15 +354,16 @@ def save(
347354
(image_name,),
348355
)
349356

350-
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
357+
created_at = datetime.fromisoformat(cursor.fetchone()[0])
351358

352359
return created_at
353360
except sqlite3.Error as e:
354361
self._conn.rollback()
355362
raise ImageRecordSaveException from e
356363

357364
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
358-
self._cursor.execute(
365+
cursor = self._conn.cursor()
366+
cursor.execute(
359367
"""--sql
360368
SELECT images.*
361369
FROM images
@@ -368,7 +376,7 @@ def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord
368376
(board_id,),
369377
)
370378

371-
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
379+
result = cast(Optional[sqlite3.Row], cursor.fetchone())
372380

373381
if result is None:
374382
return None

0 commit comments

Comments
 (0)