@@ -136,16 +136,22 @@ def set(self, name: str, value, ex: None | int | timedelta = None) -> None:
136
136
expire_time = datetime .now () + expire
137
137
138
138
try :
139
- cache_item = self .db .session .query (Cache ).filter (Cache .cache_key == name ).first ()
140
- if cache_item :
141
- cache_item .cache_value = value
142
- cache_item .expire_time = expire_time
143
- else :
144
- cache_item = Cache ()
145
- cache_item .cache_key = name
146
- cache_item .cache_value = value
147
- cache_item .expire_time = expire_time
148
- self .db .session .add (cache_item )
139
+ # 使用 INSERT ... ON DUPLICATE KEY UPDATE 避免竞态条件
140
+ sql = """
141
+ INSERT INTO caches (cache_key, cache_value, expire_time)
142
+ VALUES (:cache_key, :cache_value, :expire_time)
143
+ ON DUPLICATE KEY UPDATE
144
+ cache_value = VALUES(cache_value),
145
+ expire_time = VALUES(expire_time)
146
+ """
147
+ self .db .session .execute (
148
+ db .text (sql ),
149
+ {
150
+ 'cache_key' : name ,
151
+ 'cache_value' : value ,
152
+ 'expire_time' : expire_time
153
+ }
154
+ )
149
155
self .db .session .commit ()
150
156
except Exception as e :
151
157
logger .warning ("MySQLRedisClient.set " + str (name ) + " got exception: " + str (e ))
@@ -162,16 +168,22 @@ def setex(self, name: str, time: int | timedelta, value) -> None:
162
168
expire_time = datetime .now () + expire
163
169
164
170
try :
165
- cache_item = self .db .session .query (Cache ).filter (Cache .cache_key == name ).first ()
166
- if cache_item :
167
- cache_item .cache_value = value
168
- cache_item .expire_time = expire_time
169
- else :
170
- cache_item = Cache ()
171
- cache_item .cache_key = name
172
- cache_item .cache_value = value
173
- cache_item .expire_time = expire_time
174
- self .db .session .add (cache_item )
171
+ # 使用 INSERT ... ON DUPLICATE KEY UPDATE 避免竞态条件
172
+ sql = """
173
+ INSERT INTO caches (cache_key, cache_value, expire_time)
174
+ VALUES (:cache_key, :cache_value, :expire_time)
175
+ ON DUPLICATE KEY UPDATE
176
+ cache_value = VALUES(cache_value),
177
+ expire_time = VALUES(expire_time)
178
+ """
179
+ self .db .session .execute (
180
+ db .text (sql ),
181
+ {
182
+ 'cache_key' : name ,
183
+ 'cache_value' : value ,
184
+ 'expire_time' : expire_time
185
+ }
186
+ )
175
187
self .db .session .commit ()
176
188
except Exception as e :
177
189
logger .warning ("MySQLRedisClient.setex " + str (name ) + " got exception: " + str (e ))
@@ -185,15 +197,19 @@ def setnx(self, name: str, value) -> None:
185
197
value = (str (value )).encode ('utf-8' )
186
198
187
199
try :
188
- cache_item = self .db .session .query (Cache ).filter (Cache .cache_key == name ).first ()
189
- if cache_item :
190
- return
191
-
192
- cache_item = Cache ()
193
- cache_item .cache_key = name
194
- cache_item .cache_value = value
195
- cache_item .expire_time = None
196
- self .db .session .add (cache_item )
200
+ # 使用 INSERT IGNORE 避免竞态条件,仅在不存在时插入
201
+ sql = """
202
+ INSERT IGNORE INTO caches (cache_key, cache_value, expire_time)
203
+ VALUES (:cache_key, :cache_value, :expire_time)
204
+ """
205
+ self .db .session .execute (
206
+ db .text (sql ),
207
+ {
208
+ 'cache_key' : name ,
209
+ 'cache_value' : value ,
210
+ 'expire_time' : None
211
+ }
212
+ )
197
213
self .db .session .commit ()
198
214
except Exception as e :
199
215
logger .warning ("MySQLRedisClient.setnx " + str (name ) + " got exception: " + str (e ))
@@ -215,24 +231,34 @@ def incr(self, name: str, amount: int = 1) -> bytes:
215
231
return b'0'
216
232
217
233
try :
218
- cache_item = self .db .session .query (Cache ).filter (Cache .cache_key == name ).first ()
219
-
220
- if cache_item :
221
- current_value = int (cache_item .cache_value .decode ('utf-8' ))
234
+ # 使用事务确保原子性,避免并发问题
235
+ with self .db .session .begin ():
236
+ # 1. 获取当前值(在事务中)
237
+ current_item = self .db .session .query (Cache ).filter (Cache .cache_key == name ).first ()
238
+ current_value = 0
239
+ if current_item :
240
+ try :
241
+ current_value = int (current_item .cache_value .decode ('utf-8' ))
242
+ except (ValueError , UnicodeDecodeError ):
243
+ current_value = 0
244
+
222
245
new_value = current_value + amount
223
- cache_item .cache_value = str (new_value ).encode ('utf-8' )
224
- else :
225
- cache_item = Cache ()
226
- cache_item .cache_key = name
227
- cache_item .cache_value = str (amount ).encode ('utf-8' )
228
- cache_item .expire_time = None
229
- self .db .session .add (cache_item )
230
-
231
- self .db .session .commit ()
232
- return cache_item .cache_value
246
+
247
+ # 2. 原子更新(在同一个事务中)
248
+ if current_item :
249
+ current_item .cache_value = str (new_value ).encode ('utf-8' )
250
+ else :
251
+ cache_item = Cache ()
252
+ cache_item .cache_key = name
253
+ cache_item .cache_value = str (new_value ).encode ('utf-8' )
254
+ cache_item .expire_time = None
255
+ self .db .session .add (cache_item )
256
+
257
+ # 3. 事务自动提交,返回结果
258
+ return str (new_value ).encode ('utf-8' )
259
+
233
260
except Exception as e :
234
261
logger .warning ("MySQLRedisClient.incr " + str (name ) + " got exception: " + str (e ))
235
- self .db .session .rollback ()
236
262
return b'0'
237
263
238
264
def expire (self , name : str , time : int | timedelta ) -> None :
@@ -243,10 +269,20 @@ def expire(self, name: str, time: int | timedelta) -> None:
243
269
expire_time = datetime .now () + expire
244
270
245
271
try :
246
- cache_item = self .db .session .query (Cache ).filter (Cache .cache_key == name ).first ()
247
- if cache_item :
248
- cache_item .expire_time = expire_time
249
- self .db .session .commit ()
272
+ # 使用 UPDATE 语句避免竞态条件
273
+ sql = """
274
+ UPDATE caches
275
+ SET expire_time = :expire_time
276
+ WHERE cache_key = :cache_key
277
+ """
278
+ result = self .db .session .execute (
279
+ db .text (sql ),
280
+ {
281
+ 'cache_key' : name ,
282
+ 'expire_time' : expire_time
283
+ }
284
+ )
285
+ self .db .session .commit ()
250
286
except Exception as e :
251
287
logger .warning ("MySQLRedisClient.expire " + str (name ) + " got exception: " + str (e ))
252
288
self .db .session .rollback ()
@@ -258,29 +294,32 @@ def zadd(self, name: str, mapping: Mapping) -> None:
258
294
try :
259
295
import json
260
296
261
- cache_item = self .db .session .query (Cache ).filter (Cache .cache_key == name ).first ()
262
-
263
- if cache_item :
264
- try :
265
- existing_data = json .loads (cache_item .cache_value .decode ('utf-8' ))
266
- if not isinstance (existing_data , dict ):
297
+ # 使用事务确保原子性,避免并发问题
298
+ with self .db .session .begin ():
299
+ # 1. 在事务中获取现有数据
300
+ cache_item = self .db .session .query (Cache ).filter (Cache .cache_key == name ).first ()
301
+
302
+ if cache_item :
303
+ try :
304
+ existing_data = json .loads (cache_item .cache_value .decode ('utf-8' ))
305
+ if not isinstance (existing_data , dict ):
306
+ existing_data = {}
307
+ except (json .JSONDecodeError , UnicodeDecodeError ):
267
308
existing_data = {}
268
- except (json .JSONDecodeError , UnicodeDecodeError ):
269
- existing_data = {}
270
-
271
- existing_data .update (mapping )
272
- cache_item .cache_value = json .dumps (existing_data ).encode ('utf-8' )
273
- else :
274
- cache_item = Cache ()
275
- cache_item .cache_key = name
276
- cache_item .cache_value = json .dumps (dict (mapping )).encode ('utf-8' )
277
- cache_item .expire_time = None
278
- self .db .session .add (cache_item )
279
-
280
- self .db .session .commit ()
309
+
310
+ existing_data .update (mapping )
311
+ cache_item .cache_value = json .dumps (existing_data ).encode ('utf-8' )
312
+ else :
313
+ cache_item = Cache ()
314
+ cache_item .cache_key = name
315
+ cache_item .cache_value = json .dumps (dict (mapping )).encode ('utf-8' )
316
+ cache_item .expire_time = None
317
+ self .db .session .add (cache_item )
318
+
319
+ # 2. 事务自动提交
320
+
281
321
except Exception as e :
282
322
logger .warning ("MySQLRedisClient.zadd " + str (name ) + " got exception: " + str (e ))
283
- self .db .session .rollback ()
284
323
285
324
def zremrangebyscore (self , name : str , min : int | float | str , max : int | float | str ):
286
325
if not self .db :
@@ -289,39 +328,42 @@ def zremrangebyscore(self, name: str, min: int | float | str, max: int | float |
289
328
try :
290
329
import json
291
330
292
- cache_item = self .db .session .query (Cache ).filter (Cache .cache_key == name ).first ()
293
- if not cache_item :
294
- return 0
295
-
296
- try :
297
- existing_data = json .loads (cache_item .cache_value .decode ('utf-8' ))
298
- if not isinstance (existing_data , dict ):
331
+ # 使用事务确保原子性,避免并发问题
332
+ with self .db .session .begin ():
333
+ # 1. 在事务中获取现有数据
334
+ cache_item = self .db .session .query (Cache ).filter (Cache .cache_key == name ).first ()
335
+ if not cache_item :
299
336
return 0
300
- except (json .JSONDecodeError , UnicodeDecodeError ):
301
- return 0
302
337
303
- min_score = float (min ) if min != '-inf' else float ('-inf' )
304
- max_score = float (max ) if max != '+inf' else float ('inf' )
305
-
306
- members_to_remove = []
307
- for member , score in existing_data .items ():
308
338
try :
309
- score_float = float ( score )
310
- if min_score <= score_float <= max_score :
311
- members_to_remove . append ( member )
312
- except (ValueError , TypeError ):
313
- continue
339
+ existing_data = json . loads ( cache_item . cache_value . decode ( 'utf-8' ) )
340
+ if not isinstance ( existing_data , dict ) :
341
+ return 0
342
+ except (json . JSONDecodeError , UnicodeDecodeError ):
343
+ return 0
314
344
315
- for member in members_to_remove :
316
- del existing_data [ member ]
345
+ min_score = float ( min ) if min != '-inf' else float ( '-inf' )
346
+ max_score = float ( max ) if max != '+inf' else float ( 'inf' )
317
347
318
- cache_item .cache_value = json .dumps (existing_data ).encode ('utf-8' )
319
- self .db .session .commit ()
348
+ members_to_remove = []
349
+ for member , score in existing_data .items ():
350
+ try :
351
+ score_float = float (score )
352
+ if min_score <= score_float <= max_score :
353
+ members_to_remove .append (member )
354
+ except (ValueError , TypeError ):
355
+ continue
320
356
321
- return len (members_to_remove )
357
+ for member in members_to_remove :
358
+ del existing_data [member ]
359
+
360
+ cache_item .cache_value = json .dumps (existing_data ).encode ('utf-8' )
361
+
362
+ # 2. 事务自动提交,返回结果
363
+ return len (members_to_remove )
364
+
322
365
except Exception as e :
323
366
logger .warning ("MySQLRedisClient.zremrangebyscore " + str (name ) + " got exception: " + str (e ))
324
- self .db .session .rollback ()
325
367
return 0
326
368
327
369
def zcard (self , name : str ) -> int :
0 commit comments