Skip to content

Commit a3de6b6

Browse files
feat(app): avoid nested cursors in board_image_records service
1 parent e57f0ff commit a3de6b6

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

invokeai/app/services/board_image_records/board_image_records_sqlite.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,18 @@
1212

1313

1414
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
15-
_conn: sqlite3.Connection
16-
_cursor: sqlite3.Cursor
17-
1815
def __init__(self, db: SqliteDatabase) -> None:
1916
super().__init__()
2017
self._conn = db.conn
21-
self._cursor = self._conn.cursor()
2218

2319
def add_image_to_board(
2420
self,
2521
board_id: str,
2622
image_name: str,
2723
) -> None:
2824
try:
29-
self._cursor.execute(
25+
cursor = self._conn.cursor()
26+
cursor.execute(
3027
"""--sql
3128
INSERT INTO board_images (board_id, image_name)
3229
VALUES (?, ?)
@@ -44,7 +41,8 @@ def remove_image_from_board(
4441
image_name: str,
4542
) -> None:
4643
try:
47-
self._cursor.execute(
44+
cursor = self._conn.cursor()
45+
cursor.execute(
4846
"""--sql
4947
DELETE FROM board_images
5048
WHERE image_name = ?;
@@ -63,7 +61,8 @@ def get_images_for_board(
6361
limit: int = 10,
6462
) -> OffsetPaginatedResults[ImageRecord]:
6563
# TODO: this isn't paginated yet?
66-
self._cursor.execute(
64+
cursor = self._conn.cursor()
65+
cursor.execute(
6766
"""--sql
6867
SELECT images.*
6968
FROM board_images
@@ -73,15 +72,15 @@ def get_images_for_board(
7372
""",
7473
(board_id,),
7574
)
76-
result = cast(list[sqlite3.Row], self._cursor.fetchall())
75+
result = cast(list[sqlite3.Row], cursor.fetchall())
7776
images = [deserialize_image_record(dict(r)) for r in result]
7877

79-
self._cursor.execute(
78+
cursor.execute(
8079
"""--sql
8180
SELECT COUNT(*) FROM images WHERE 1=1;
8281
"""
8382
)
84-
count = cast(int, self._cursor.fetchone()[0])
83+
count = cast(int, cursor.fetchone()[0])
8584

8685
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
8786

@@ -128,31 +127,34 @@ def get_all_board_image_names_for_board(
128127
stmt += ";"
129128

130129
# Execute the query
131-
self._cursor.execute(stmt, params)
130+
cursor = self._conn.cursor()
131+
cursor.execute(stmt, params)
132132

133-
result = cast(list[sqlite3.Row], self._cursor.fetchall())
133+
result = cast(list[sqlite3.Row], cursor.fetchall())
134134
image_names = [r[0] for r in result]
135135
return image_names
136136

137137
def get_board_for_image(
138138
self,
139139
image_name: str,
140140
) -> Optional[str]:
141-
self._cursor.execute(
141+
cursor = self._conn.cursor()
142+
cursor.execute(
142143
"""--sql
143144
SELECT board_id
144145
FROM board_images
145146
WHERE image_name = ?;
146147
""",
147148
(image_name,),
148149
)
149-
result = self._cursor.fetchone()
150+
result = cursor.fetchone()
150151
if result is None:
151152
return None
152153
return cast(str, result[0])
153154

154155
def get_image_count_for_board(self, board_id: str) -> int:
155-
self._cursor.execute(
156+
cursor = self._conn.cursor()
157+
cursor.execute(
156158
"""--sql
157159
SELECT COUNT(*)
158160
FROM board_images
@@ -162,5 +164,5 @@ def get_image_count_for_board(self, board_id: str) -> int:
162164
""",
163165
(board_id,),
164166
)
165-
count = cast(int, self._cursor.fetchone()[0])
167+
count = cast(int, cursor.fetchone()[0])
166168
return count

0 commit comments

Comments
 (0)