Skip to content

Commit 028d8d8

Browse files
feat(app): avoid nested cursors in model_records service
1 parent 657095d commit 028d8d8

File tree

1 file changed

+33
-22
lines changed

1 file changed

+33
-22
lines changed

invokeai/app/services/model_records/model_records_sql.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def __init__(self, db: SqliteDatabase, logger: logging.Logger):
7878
"""
7979
super().__init__()
8080
self._db = db
81-
self._cursor = db.conn.cursor()
8281
self._logger = logger
8382

8483
@property
@@ -97,7 +96,8 @@ def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
9796
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
9897
"""
9998
try:
100-
self._cursor.execute(
99+
cursor = self._db.conn.cursor()
100+
cursor.execute(
101101
"""--sql
102102
INSERT INTO models (
103103
id,
@@ -139,14 +139,15 @@ def del_model(self, key: str) -> None:
139139
Can raise an UnknownModelException
140140
"""
141141
try:
142-
self._cursor.execute(
142+
cursor = self._db.conn.cursor()
143+
cursor.execute(
143144
"""--sql
144145
DELETE FROM models
145146
WHERE id=?;
146147
""",
147148
(key,),
148149
)
149-
if self._cursor.rowcount == 0:
150+
if cursor.rowcount == 0:
150151
raise UnknownModelException("model not found")
151152
self._db.conn.commit()
152153
except sqlite3.Error as e:
@@ -163,7 +164,8 @@ def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
163164
json_serialized = record.model_dump_json()
164165

165166
try:
166-
self._cursor.execute(
167+
cursor = self._db.conn.cursor()
168+
cursor.execute(
167169
"""--sql
168170
UPDATE models
169171
SET
@@ -172,7 +174,7 @@ def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
172174
""",
173175
(json_serialized, key),
174176
)
175-
if self._cursor.rowcount == 0:
177+
if cursor.rowcount == 0:
176178
raise UnknownModelException("model not found")
177179
self._db.conn.commit()
178180
except sqlite3.Error as e:
@@ -189,28 +191,30 @@ def get_model(self, key: str) -> AnyModelConfig:
189191
190192
Exceptions: UnknownModelException
191193
"""
192-
self._cursor.execute(
194+
cursor = self._db.conn.cursor()
195+
cursor.execute(
193196
"""--sql
194197
SELECT config, strftime('%s',updated_at) FROM models
195198
WHERE id=?;
196199
""",
197200
(key,),
198201
)
199-
rows = self._cursor.fetchone()
202+
rows = cursor.fetchone()
200203
if not rows:
201204
raise UnknownModelException("model not found")
202205
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
203206
return model
204207

205208
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
206-
self._cursor.execute(
209+
cursor = self._db.conn.cursor()
210+
cursor.execute(
207211
"""--sql
208212
SELECT config, strftime('%s',updated_at) FROM models
209213
WHERE hash=?;
210214
""",
211215
(hash,),
212216
)
213-
rows = self._cursor.fetchone()
217+
rows = cursor.fetchone()
214218
if not rows:
215219
raise UnknownModelException("model not found")
216220
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
@@ -222,14 +226,15 @@ def exists(self, key: str) -> bool:
222226
223227
:param key: Unique key for the model to be deleted
224228
"""
225-
self._cursor.execute(
229+
cursor = self._db.conn.cursor()
230+
cursor.execute(
226231
"""--sql
227232
select count(*) FROM models
228233
WHERE id=?;
229234
""",
230235
(key,),
231236
)
232-
count = self._cursor.fetchone()[0]
237+
count = cursor.fetchone()[0]
233238
return count > 0
234239

235240
def search_by_attr(
@@ -277,7 +282,9 @@ def search_by_attr(
277282
where_clause.append("format=?")
278283
bindings.append(model_format)
279284
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
280-
self._cursor.execute(
285+
286+
cursor = self._db.conn.cursor()
287+
cursor.execute(
281288
f"""--sql
282289
SELECT config, strftime('%s',updated_at)
283290
FROM models
@@ -286,7 +293,7 @@ def search_by_attr(
286293
""",
287294
tuple(bindings),
288295
)
289-
result = self._cursor.fetchall()
296+
result = cursor.fetchall()
290297

291298
# Parse the model configs.
292299
results: list[AnyModelConfig] = []
@@ -305,26 +312,28 @@ def search_by_attr(
305312

306313
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
307314
"""Return models with the indicated path."""
308-
self._cursor.execute(
315+
cursor = self._db.conn.cursor()
316+
cursor.execute(
309317
"""--sql
310318
SELECT config, strftime('%s',updated_at) FROM models
311319
WHERE path=?;
312320
""",
313321
(str(path),),
314322
)
315-
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
323+
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
316324
return results
317325

318326
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
319327
"""Return models with the indicated hash."""
320-
self._cursor.execute(
328+
cursor = self._db.conn.cursor()
329+
cursor.execute(
321330
"""--sql
322331
SELECT config, strftime('%s',updated_at) FROM models
323332
WHERE hash=?;
324333
""",
325334
(hash,),
326335
)
327-
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
336+
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
328337
return results
329338

330339
def list_models(
@@ -340,18 +349,20 @@ def list_models(
340349
ModelRecordOrderBy.Format: "format",
341350
}
342351

352+
cursor = self._db.conn.cursor()
353+
343354
# Lock so that the database isn't updated while we're doing the two queries.
344355
# query1: get the total number of model configs
345-
self._cursor.execute(
356+
cursor.execute(
346357
"""--sql
347358
select count(*) from models;
348359
""",
349360
(),
350361
)
351-
total = int(self._cursor.fetchone()[0])
362+
total = int(cursor.fetchone()[0])
352363

353364
# query2: fetch key fields
354-
self._cursor.execute(
365+
cursor.execute(
355366
f"""--sql
356367
SELECT config
357368
FROM models
@@ -364,6 +375,6 @@ def list_models(
364375
page * per_page,
365376
),
366377
)
367-
rows = self._cursor.fetchall()
378+
rows = cursor.fetchall()
368379
items = [ModelSummary.model_validate(dict(x)) for x in rows]
369380
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)

0 commit comments

Comments
 (0)