@@ -78,7 +78,6 @@ def __init__(self, db: SqliteDatabase, logger: logging.Logger):
78
78
"""
79
79
super ().__init__ ()
80
80
self ._db = db
81
- self ._cursor = db .conn .cursor ()
82
81
self ._logger = logger
83
82
84
83
@property
@@ -97,7 +96,8 @@ def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
97
96
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
98
97
"""
99
98
try :
100
- self ._cursor .execute (
99
+ cursor = self ._db .conn .cursor ()
100
+ cursor .execute (
101
101
"""--sql
102
102
INSERT INTO models (
103
103
id,
@@ -139,14 +139,15 @@ def del_model(self, key: str) -> None:
139
139
Can raise an UnknownModelException
140
140
"""
141
141
try :
142
- self ._cursor .execute (
142
+ cursor = self ._db .conn .cursor ()
143
+ cursor .execute (
143
144
"""--sql
144
145
DELETE FROM models
145
146
WHERE id=?;
146
147
""" ,
147
148
(key ,),
148
149
)
149
- if self . _cursor .rowcount == 0 :
150
+ if cursor .rowcount == 0 :
150
151
raise UnknownModelException ("model not found" )
151
152
self ._db .conn .commit ()
152
153
except sqlite3 .Error as e :
@@ -163,7 +164,8 @@ def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
163
164
json_serialized = record .model_dump_json ()
164
165
165
166
try :
166
- self ._cursor .execute (
167
+ cursor = self ._db .conn .cursor ()
168
+ cursor .execute (
167
169
"""--sql
168
170
UPDATE models
169
171
SET
@@ -172,7 +174,7 @@ def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
172
174
""" ,
173
175
(json_serialized , key ),
174
176
)
175
- if self . _cursor .rowcount == 0 :
177
+ if cursor .rowcount == 0 :
176
178
raise UnknownModelException ("model not found" )
177
179
self ._db .conn .commit ()
178
180
except sqlite3 .Error as e :
@@ -189,28 +191,30 @@ def get_model(self, key: str) -> AnyModelConfig:
189
191
190
192
Exceptions: UnknownModelException
191
193
"""
192
- self ._cursor .execute (
194
+ cursor = self ._db .conn .cursor ()
195
+ cursor .execute (
193
196
"""--sql
194
197
SELECT config, strftime('%s',updated_at) FROM models
195
198
WHERE id=?;
196
199
""" ,
197
200
(key ,),
198
201
)
199
- rows = self . _cursor .fetchone ()
202
+ rows = cursor .fetchone ()
200
203
if not rows :
201
204
raise UnknownModelException ("model not found" )
202
205
model = ModelConfigFactory .make_config (json .loads (rows [0 ]), timestamp = rows [1 ])
203
206
return model
204
207
205
208
def get_model_by_hash (self , hash : str ) -> AnyModelConfig :
206
- self ._cursor .execute (
209
+ cursor = self ._db .conn .cursor ()
210
+ cursor .execute (
207
211
"""--sql
208
212
SELECT config, strftime('%s',updated_at) FROM models
209
213
WHERE hash=?;
210
214
""" ,
211
215
(hash ,),
212
216
)
213
- rows = self . _cursor .fetchone ()
217
+ rows = cursor .fetchone ()
214
218
if not rows :
215
219
raise UnknownModelException ("model not found" )
216
220
model = ModelConfigFactory .make_config (json .loads (rows [0 ]), timestamp = rows [1 ])
@@ -222,14 +226,15 @@ def exists(self, key: str) -> bool:
222
226
223
227
:param key: Unique key for the model to be deleted
224
228
"""
225
- self ._cursor .execute (
229
+ cursor = self ._db .conn .cursor ()
230
+ cursor .execute (
226
231
"""--sql
227
232
select count(*) FROM models
228
233
WHERE id=?;
229
234
""" ,
230
235
(key ,),
231
236
)
232
- count = self . _cursor .fetchone ()[0 ]
237
+ count = cursor .fetchone ()[0 ]
233
238
return count > 0
234
239
235
240
def search_by_attr (
@@ -277,7 +282,9 @@ def search_by_attr(
277
282
where_clause .append ("format=?" )
278
283
bindings .append (model_format )
279
284
where = f"WHERE { ' AND ' .join (where_clause )} " if where_clause else ""
280
- self ._cursor .execute (
285
+
286
+ cursor = self ._db .conn .cursor ()
287
+ cursor .execute (
281
288
f"""--sql
282
289
SELECT config, strftime('%s',updated_at)
283
290
FROM models
@@ -286,7 +293,7 @@ def search_by_attr(
286
293
""" ,
287
294
tuple (bindings ),
288
295
)
289
- result = self . _cursor .fetchall ()
296
+ result = cursor .fetchall ()
290
297
291
298
# Parse the model configs.
292
299
results : list [AnyModelConfig ] = []
@@ -305,26 +312,28 @@ def search_by_attr(
305
312
306
313
def search_by_path (self , path : Union [str , Path ]) -> List [AnyModelConfig ]:
307
314
"""Return models with the indicated path."""
308
- self ._cursor .execute (
315
+ cursor = self ._db .conn .cursor ()
316
+ cursor .execute (
309
317
"""--sql
310
318
SELECT config, strftime('%s',updated_at) FROM models
311
319
WHERE path=?;
312
320
""" ,
313
321
(str (path ),),
314
322
)
315
- results = [ModelConfigFactory .make_config (json .loads (x [0 ]), timestamp = x [1 ]) for x in self . _cursor .fetchall ()]
323
+ results = [ModelConfigFactory .make_config (json .loads (x [0 ]), timestamp = x [1 ]) for x in cursor .fetchall ()]
316
324
return results
317
325
318
326
def search_by_hash (self , hash : str ) -> List [AnyModelConfig ]:
319
327
"""Return models with the indicated hash."""
320
- self ._cursor .execute (
328
+ cursor = self ._db .conn .cursor ()
329
+ cursor .execute (
321
330
"""--sql
322
331
SELECT config, strftime('%s',updated_at) FROM models
323
332
WHERE hash=?;
324
333
""" ,
325
334
(hash ,),
326
335
)
327
- results = [ModelConfigFactory .make_config (json .loads (x [0 ]), timestamp = x [1 ]) for x in self . _cursor .fetchall ()]
336
+ results = [ModelConfigFactory .make_config (json .loads (x [0 ]), timestamp = x [1 ]) for x in cursor .fetchall ()]
328
337
return results
329
338
330
339
def list_models (
@@ -340,18 +349,20 @@ def list_models(
340
349
ModelRecordOrderBy .Format : "format" ,
341
350
}
342
351
352
+ cursor = self ._db .conn .cursor ()
353
+
343
354
# Lock so that the database isn't updated while we're doing the two queries.
344
355
# query1: get the total number of model configs
345
- self . _cursor .execute (
356
+ cursor .execute (
346
357
"""--sql
347
358
select count(*) from models;
348
359
""" ,
349
360
(),
350
361
)
351
- total = int (self . _cursor .fetchone ()[0 ])
362
+ total = int (cursor .fetchone ()[0 ])
352
363
353
364
# query2: fetch key fields
354
- self . _cursor .execute (
365
+ cursor .execute (
355
366
f"""--sql
356
367
SELECT config
357
368
FROM models
@@ -364,6 +375,6 @@ def list_models(
364
375
page * per_page ,
365
376
),
366
377
)
367
- rows = self . _cursor .fetchall ()
378
+ rows = cursor .fetchall ()
368
379
items = [ModelSummary .model_validate (dict (x )) for x in rows ]
369
380
return PaginatedResults (page = page , pages = ceil (total / per_page ), per_page = per_page , total = total , items = items )
0 commit comments