Skip to content

Commit 1c47dc9

Browse files
feat(app): avoid nested cursors in board_records service
1 parent a3de6b6 commit 1c47dc9

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

invokeai/app/services/board_records/board_records_sqlite.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,14 @@
1818

1919

2020
class SqliteBoardRecordStorage(BoardRecordStorageBase):
21-
_conn: sqlite3.Connection
22-
_cursor: sqlite3.Cursor
23-
2421
def __init__(self, db: SqliteDatabase) -> None:
2522
super().__init__()
2623
self._conn = db.conn
27-
self._cursor = self._conn.cursor()
2824

2925
def delete(self, board_id: str) -> None:
3026
try:
31-
self._cursor.execute(
27+
cursor = self._conn.cursor()
28+
cursor.execute(
3229
"""--sql
3330
DELETE FROM boards
3431
WHERE board_id = ?;
@@ -46,7 +43,8 @@ def save(
4643
) -> BoardRecord:
4744
try:
4845
board_id = uuid_string()
49-
self._cursor.execute(
46+
cursor = self._conn.cursor()
47+
cursor.execute(
5048
"""--sql
5149
INSERT OR IGNORE INTO boards (board_id, board_name)
5250
VALUES (?, ?);
@@ -64,7 +62,8 @@ def get(
6462
board_id: str,
6563
) -> BoardRecord:
6664
try:
67-
self._cursor.execute(
65+
cursor = self._conn.cursor()
66+
cursor.execute(
6867
"""--sql
6968
SELECT *
7069
FROM boards
@@ -73,7 +72,7 @@ def get(
7372
(board_id,),
7473
)
7574

76-
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
75+
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
7776
except sqlite3.Error as e:
7877
raise BoardRecordNotFoundException from e
7978
if result is None:
@@ -86,9 +85,10 @@ def update(
8685
changes: BoardChanges,
8786
) -> BoardRecord:
8887
try:
88+
cursor = self._conn.cursor()
8989
# Change the name of a board
9090
if changes.board_name is not None:
91-
self._cursor.execute(
91+
cursor.execute(
9292
"""--sql
9393
UPDATE boards
9494
SET board_name = ?
@@ -99,7 +99,7 @@ def update(
9999

100100
# Change the cover image of a board
101101
if changes.cover_image_name is not None:
102-
self._cursor.execute(
102+
cursor.execute(
103103
"""--sql
104104
UPDATE boards
105105
SET cover_image_name = ?
@@ -110,7 +110,7 @@ def update(
110110

111111
# Change the archived status of a board
112112
if changes.archived is not None:
113-
self._cursor.execute(
113+
cursor.execute(
114114
"""--sql
115115
UPDATE boards
116116
SET archived = ?
@@ -133,6 +133,8 @@ def get_many(
133133
limit: int = 10,
134134
include_archived: bool = False,
135135
) -> OffsetPaginatedResults[BoardRecord]:
136+
cursor = self._conn.cursor()
137+
136138
# Build base query
137139
base_query = """
138140
SELECT *
@@ -150,9 +152,9 @@ def get_many(
150152
)
151153

152154
# Execute query to fetch boards
153-
self._cursor.execute(final_query, (limit, offset))
155+
cursor.execute(final_query, (limit, offset))
154156

155-
result = cast(list[sqlite3.Row], self._cursor.fetchall())
157+
result = cast(list[sqlite3.Row], cursor.fetchall())
156158
boards = [deserialize_board_record(dict(r)) for r in result]
157159

158160
# Determine count query
@@ -169,15 +171,16 @@ def get_many(
169171
"""
170172

171173
# Execute count query
172-
self._cursor.execute(count_query)
174+
cursor.execute(count_query)
173175

174-
count = cast(int, self._cursor.fetchone()[0])
176+
count = cast(int, cursor.fetchone()[0])
175177

176178
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
177179

178180
def get_all(
179181
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
180182
) -> list[BoardRecord]:
183+
cursor = self._conn.cursor()
181184
if order_by == BoardRecordOrderBy.Name:
182185
base_query = """
183186
SELECT *
@@ -199,9 +202,9 @@ def get_all(
199202
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
200203
)
201204

202-
self._cursor.execute(final_query)
205+
cursor.execute(final_query)
203206

204-
result = cast(list[sqlite3.Row], self._cursor.fetchall())
207+
result = cast(list[sqlite3.Row], cursor.fetchall())
205208
boards = [deserialize_board_record(dict(r)) for r in result]
206209

207210
return boards

0 commit comments

Comments
 (0)