@@ -80,6 +80,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.lora_file = None
     context.testing_sampler_delay_millis = None
     context.disable_ctx_shift = False
+    context.disconnect_after_millis = None

     context.tasks_result = []
     context.concurrent_tasks = []
@@ -279,6 +280,7 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
                                           n_predict=context.n_predict,
                                           cache_prompt=context.cache_prompt,
                                           id_slot=context.id_slot,
+                                          disconnect_after_millis=context.disconnect_after_millis,
                                           expect_api_error=expect_api_error,
                                           user_api_key=context.user_api_key,
                                           temperature=context.temperature)
@@ -296,20 +298,12 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
 async def step_request_completion(context, millis: int):
     await asyncio.sleep(millis / 1000.0)

-@step('a completion request cancelled after {disconnect_after_millis:d} milliseconds')
+
+@step('disconnect after {disconnect_after_millis:d} milliseconds')
 @async_run_until_complete
-async def step_request_completion(context, disconnect_after_millis: int):
-    seeds = await completions_seed(context, num_seeds=1)
-    await request_completion(context.prompts.pop(),
-                             seeds[0] if seeds is not None else seeds,
-                             context.base_url,
-                             debug=context.debug,
-                             n_predict=context.n_predict,
-                             cache_prompt=context.cache_prompt,
-                             id_slot=context.id_slot,
-                             disconnect_after_millis=disconnect_after_millis,
-                             user_api_key=context.user_api_key,
-                             temperature=context.temperature)
+async def step_disconnect_after(context, disconnect_after_millis: int):
+    context.disconnect_after_millis = disconnect_after_millis
+

 @step('{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
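Note: the hunk above replaces the self-contained cancellation step with a step that only records the delay on the behave context; the existing request steps then forward context.disconnect_after_millis to the HTTP helpers, so a disconnect can be combined with any request step instead of duplicating the request logic. A minimal sketch of that record-then-consume pattern (stand-in names and a SimpleNamespace in place of behave's context, not the real helpers):

import asyncio
from types import SimpleNamespace

context = SimpleNamespace(disconnect_after_millis=None)

def step_disconnect_after(context, disconnect_after_millis: int):
    # The step only records the delay ...
    context.disconnect_after_millis = disconnect_after_millis

async def request_completion(prompt: str, *, disconnect_after_millis=None):
    # ... and each request helper decides at call time whether to bail out early.
    if disconnect_after_millis is not None:
        await asyncio.sleep(disconnect_after_millis / 1000.0)
        return 0  # disconnected before reading the response
    return 200

step_disconnect_after(context, 100)
status = asyncio.run(request_completion(
    "Hello", disconnect_after_millis=context.disconnect_after_millis))
assert status == 0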
@@ -519,6 +513,7 @@ async def step_oai_chat_completions(context, api_error):
         print(f"Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
     seeds = await completions_seed(context, num_seeds=1),
+    seeds = await completions_seed(context, num_seeds=1)
     completion = await oai_chat_completions(context.prompts.pop(),
                                             seeds[0] if seeds is not None else seeds,
                                             context.system_prompt,
@@ -539,6 +534,8 @@ async def step_oai_chat_completions(context, api_error):
                                             user_api_key=context.user_api_key
                                             if hasattr(context, 'user_api_key') else None,

+                                            disconnect_after_millis=context.disconnect_after_millis,
+
                                             expect_api_error=expect_api_error)
     context.tasks_result.append(completion)
     if context.debug:
@@ -606,6 +603,7 @@ async def step_oai_chat_completions(context):
                                             if hasattr(context, 'enable_streaming') else None,
                                             response_format=context.response_format
                                             if hasattr(context, 'response_format') else None,
+                                            disconnect_after_millis=context.disconnect_after_millis,
                                             user_api_key=context.user_api_key
                                             if hasattr(context, 'user_api_key') else None)

@@ -1029,9 +1027,9 @@ async def request_completion(prompt,
                                 },
                                 headers=headers) as response:
             if disconnect_after_millis is not None:
-                await asyncio.sleep(disconnect_after_millis / 1000)
+                await asyncio.sleep(disconnect_after_millis / 1000.0)
                 return 0
-
+
             if expect_api_error is None or not expect_api_error:
                 assert response.status == 200
                 assert response.headers['Access-Control-Allow-Origin'] == origin
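The simulated disconnect works by sleeping inside the async with session.post(...) block and then returning with the body unread: on exit aiohttp releases the response, and an unread body forces the underlying connection to be closed, which the server observes as a client disconnect mid-generation. A self-contained sketch of the technique (URL and payload are illustrative assumptions, not taken from the test suite):

import asyncio
import aiohttp

async def post_then_disconnect(url: str, payload: dict, millis: int) -> int:
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as response:
            await asyncio.sleep(millis / 1000.0)
            # Leaving the block with the body unread closes the underlying
            # connection; the server sees the client drop mid-generation.
            return 0

# Hypothetical usage against a local server:
# asyncio.run(post_then_disconnect("http://localhost:8080/completion",
#                                  {"prompt": "Hello", "n_predict": 128}, 100))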
@@ -1050,6 +1048,7 @@ async def oai_chat_completions(user_prompt,
                                temperature=None,
                                model=None,
                                n_predict=None,
+                               disconnect_after_millis=None,
                                enable_streaming=None,
                                response_format=None,
                                user_api_key=None,
@@ -1093,6 +1092,10 @@ async def oai_chat_completions(user_prompt,
         async with session.post(f'{base_url}{base_path}',
                                 json=payload,
                                 headers=headers) as response:
+            if disconnect_after_millis is not None:
+                await asyncio.sleep(disconnect_after_millis / 1000.0)
+                return 0
+
             if enable_streaming:
                 assert response.status == 200
                 assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -1133,6 +1136,7 @@ async def oai_chat_completions(user_prompt,
             else:
                 return response.status
     else:
+        assert disconnect_after_millis is None, "disconnect_after_millis is not supported with sync client"
         try:
             openai.api_key = user_api_key
             openai.base_url = f'{base_url}{base_path.removesuffix("chat")}'
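The final assertion guards the synchronous branch: the blocking openai client offers no hook to drop the connection mid-request, so disconnect_after_millis is only honoured on the aiohttp path, and any scenario that sets it while using the sync client fails fast with a clear message.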