Skip to content

Commit c0a0d20

Browse files
feat(app): avoid nested cursors in style_preset_records service
1 parent 028d8d8 commit c0a0d20

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

invokeai/app/services/style_preset_records/style_preset_records_sqlite.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,32 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
1818
def __init__(self, db: SqliteDatabase) -> None:
1919
super().__init__()
2020
self._conn = db.conn
21-
self._cursor = self._conn.cursor()
2221

2322
def start(self, invoker: Invoker) -> None:
2423
self._invoker = invoker
2524
self._sync_default_style_presets()
2625

2726
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
2827
"""Gets a style preset by ID."""
29-
self._cursor.execute(
28+
cursor = self._conn.cursor()
29+
cursor.execute(
3030
"""--sql
3131
SELECT *
3232
FROM style_presets
3333
WHERE id = ?;
3434
""",
3535
(style_preset_id,),
3636
)
37-
row = self._cursor.fetchone()
37+
row = cursor.fetchone()
3838
if row is None:
3939
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
4040
return StylePresetRecordDTO.from_dict(dict(row))
4141

4242
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
4343
style_preset_id = uuid_string()
4444
try:
45-
self._cursor.execute(
45+
cursor = self._conn.cursor()
46+
cursor.execute(
4647
"""--sql
4748
INSERT OR IGNORE INTO style_presets (
4849
id,
@@ -68,10 +69,11 @@ def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
6869
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
6970
style_preset_ids = []
7071
try:
72+
cursor = self._conn.cursor()
7173
for style_preset in style_presets:
7274
style_preset_id = uuid_string()
7375
style_preset_ids.append(style_preset_id)
74-
self._cursor.execute(
76+
cursor.execute(
7577
"""--sql
7678
INSERT OR IGNORE INTO style_presets (
7779
id,
@@ -97,9 +99,10 @@ def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
9799

98100
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
99101
try:
102+
cursor = self._conn.cursor()
100103
# Change the name of a style preset
101104
if changes.name is not None:
102-
self._cursor.execute(
105+
cursor.execute(
103106
"""--sql
104107
UPDATE style_presets
105108
SET name = ?
@@ -110,7 +113,7 @@ def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePres
110113

111114
# Change the preset data for a style preset
112115
if changes.preset_data is not None:
113-
self._cursor.execute(
116+
cursor.execute(
114117
"""--sql
115118
UPDATE style_presets
116119
SET preset_data = ?
@@ -127,7 +130,8 @@ def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePres
127130

128131
def delete(self, style_preset_id: str) -> None:
129132
try:
130-
self._cursor.execute(
133+
cursor = self._conn.cursor()
134+
cursor.execute(
131135
"""--sql
132136
DELETE from style_presets
133137
WHERE id = ?;
@@ -152,12 +156,13 @@ def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]
152156

153157
main_query += "ORDER BY LOWER(name) ASC"
154158

159+
cursor = self._conn.cursor()
155160
if type is not None:
156-
self._cursor.execute(main_query, (type,))
161+
cursor.execute(main_query, (type,))
157162
else:
158-
self._cursor.execute(main_query)
163+
cursor.execute(main_query)
159164

160-
rows = self._cursor.fetchall()
165+
rows = cursor.fetchall()
161166
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
162167

163168
return style_presets
@@ -167,7 +172,8 @@ def _sync_default_style_presets(self) -> None:
167172

168173
# First delete all existing default style presets
169174
try:
170-
self._cursor.execute(
175+
cursor = self._conn.cursor()
176+
cursor.execute(
171177
"""--sql
172178
DELETE FROM style_presets
173179
WHERE type = "default";

0 commit comments

Comments
 (0)