@@ -132,42 +132,116 @@ async def check_key_in_limits_v2(
132
132
):
133
133
## INCREMENT CURRENT USAGE
134
134
increment_list : List [Tuple [str , int ]] = []
135
+ decrement_list : List [Tuple [str , int ]] = []
136
+ slots_to_check : List [str ] = []
135
137
increment_value_by_group = {
136
138
"request_count" : 1 ,
137
139
"tpm" : 0 ,
138
140
"rpm" : 1 ,
139
141
}
140
- for group in ["request_count" , "rpm" , "tpm" ]:
141
- key = self ._get_current_usage_key (
142
- user_api_key_dict = user_api_key_dict ,
143
- precise_minute = precise_minute ,
144
- model = data .get ("model" , None ),
145
- rate_limit_type = rate_limit_type ,
146
- group = cast (RateLimitGroups , group ),
147
- )
148
- if key is None :
149
- continue
150
- increment_list .append ((key , increment_value_by_group [group ]))
142
+
143
+ # Get current time and calculate the last 4 15s slots
144
+ current_time = datetime .now ()
145
+ current_slot = (
146
+ current_time .second // 15
147
+ ) # This gives us 0-3 for the current 15s slot
148
+ slots_to_check = []
149
+ slot_cache_keys = []
150
+ # Calculate the last 4 slots, handling minute and hour boundaries (the date part is not rolled back across midnight)
151
+ for i in range (4 ):
152
+ slot_number = (current_slot - i ) % 4 # This ensures we wrap around properly
153
+ minute = current_time .minute
154
+ hour = current_time .hour
155
+
156
+ # If we need to look at previous minute
157
+ if current_slot - i < 0 :
158
+ if minute == 0 :
159
+ # If we're at minute 0, go to previous hour
160
+ hour = (current_time .hour - 1 ) % 24
161
+ minute = 59
162
+ else :
163
+ minute = current_time .minute - 1
164
+
165
+ slot_key = f"{ current_time .strftime ('%Y-%m-%d' )} -{ hour :02d} -{ minute :02d} -{ slot_number } "
166
+ slots_to_check .append (slot_key )
167
+
168
+ # For each slot, create keys for all rate limit groups
169
+ for slot_key in slots_to_check :
170
+ for group in ["request_count" , "rpm" , "tpm" ]:
171
+ key = self ._get_current_usage_key (
172
+ user_api_key_dict = user_api_key_dict ,
173
+ precise_minute = slot_key ,
174
+ model = data .get ("model" , None ),
175
+ rate_limit_type = rate_limit_type ,
176
+ group = cast (RateLimitGroups , group ),
177
+ )
178
+ if key is None :
179
+ continue
180
+ # Only increment the current slot
181
+ if slot_key == slots_to_check [0 ]:
182
+ increment_list .append ((key , increment_value_by_group [group ]))
183
+ decrement_list .append (
184
+ (key , - 1 if increment_value_by_group [group ] == 1 else 0 )
185
+ )
186
+ slot_cache_keys .append (key )
151
187
152
188
if (
153
189
not max_parallel_requests and not rpm_limit and not tpm_limit
154
190
): # no rate limits
155
191
return
156
192
157
- results = await self ._increment_value_list_in_current_window (
193
+ # Increment the current slot's counters; the result is not used here, limits are checked below from a separate batch read
194
+ await self ._increment_value_list_in_current_window (
158
195
increment_list = increment_list ,
159
196
ttl = 60 ,
160
197
)
198
+
199
+ # Get the current values for all slots to check limits
200
+ current_values = await self .internal_usage_cache .async_batch_get_cache (
201
+ slot_cache_keys
202
+ )
203
+ if current_values is None :
204
+ current_values = [None ] * len (slot_cache_keys )
205
+
206
+ # Calculate totals across all slots, handling None values
207
+ # Group values by type (request_count, rpm, tpm)
208
+ request_counts = []
209
+ rpm_counts = []
210
+ tpm_counts = []
211
+
212
+ for i in range (0 , len (current_values ), 3 ):
213
+ request_counts .append (
214
+ current_values [i ] if current_values [i ] is not None else 0
215
+ )
216
+ rpm_counts .append (
217
+ current_values [i + 1 ] if current_values [i + 1 ] is not None else 0
218
+ )
219
+ tpm_counts .append (
220
+ current_values [i + 2 ] if current_values [i + 2 ] is not None else 0
221
+ )
222
+
223
+ # Calculate totals across all slots
224
+ total_requests = sum (request_counts )
225
+ total_rpm = sum (rpm_counts )
226
+ total_tpm = sum (tpm_counts )
227
+
161
228
should_raise_error = False
162
229
if max_parallel_requests is not None :
163
- should_raise_error = results [ 0 ] > max_parallel_requests
230
+ should_raise_error = total_requests > max_parallel_requests
164
231
if rpm_limit is not None :
165
- should_raise_error = should_raise_error or results [ 1 ] > rpm_limit
232
+ should_raise_error = should_raise_error or total_rpm > rpm_limit
166
233
if tpm_limit is not None :
167
- should_raise_error = should_raise_error or results [2 ] > tpm_limit
234
+ should_raise_error = should_raise_error or total_tpm > tpm_limit
235
+
168
236
if should_raise_error :
237
+ ## DECREMENT CURRENT USAGE - so we don't keep failing subsequent requests
238
+ await self ._increment_value_list_in_current_window (
239
+ increment_list = decrement_list ,
240
+ ttl = 60 ,
241
+ )
242
+
169
243
raise self .raise_rate_limit_error (
170
- additional_details = f"{ CommonProxyErrors .max_parallel_request_limit_reached .value } . Hit limit for { rate_limit_type } . Current usage: max_parallel_requests: { results [ 0 ] } , current_rpm: { results [ 1 ] } , current_tpm: { results [ 2 ] } . Current limits: max_parallel_requests: { max_parallel_requests } , rpm_limit: { rpm_limit } , tpm_limit: { tpm_limit } ."
244
+ additional_details = f"{ CommonProxyErrors .max_parallel_request_limit_reached .value } . Hit limit for { rate_limit_type } . Current usage: max_parallel_requests: { total_requests } , current_rpm: { total_rpm } , current_tpm: { total_tpm } . Current limits: max_parallel_requests: { max_parallel_requests } , rpm_limit: { rpm_limit } , tpm_limit: { tpm_limit } ."
171
245
)
172
246
173
247
def time_to_next_minute (self ) -> float :
@@ -356,11 +430,18 @@ async def _update_usage_in_cache_post_call(
356
430
}
357
431
358
432
rate_limit_types = ["key" , "user" , "customer" , "team" , "model_per_key" ]
433
+ current_time = datetime .now ()
434
+ current_hour = current_time .hour
435
+ current_minute = current_time .minute
436
+ current_slot = (
437
+ current_time .second // 15
438
+ ) # This gives us 0-3 for the current 15s slot
439
+ slot_key = f"{ current_time .strftime ('%Y-%m-%d' )} -{ current_hour :02d} -{ current_minute :02d} -{ current_slot } "
359
440
for rate_limit_type in rate_limit_types :
360
441
for group in ["request_count" , "rpm" , "tpm" ]:
361
442
key = self ._get_current_usage_key (
362
443
user_api_key_dict = user_api_key_dict ,
363
- precise_minute = precise_minute ,
444
+ precise_minute = slot_key ,
364
445
model = model ,
365
446
rate_limit_type = cast (RateLimitTypes , rate_limit_type ),
366
447
group = cast (RateLimitGroups , group ),
0 commit comments