
Commit 3f96ab0

Author: ochafik
server: fix cancel tests
1 parent 88c9b54 commit 3f96ab0

File tree (3 files changed: +53 −32 lines changed)

  examples/server/server.cpp
  examples/server/tests/features/cancel.feature
  examples/server/tests/features/steps/steps.py

examples/server/server.cpp

Lines changed: 4 additions & 3 deletions
@@ -2349,6 +2349,7 @@ struct server_context {
 
             completion_token_output result;
             if (params.testing_sampler_delay_millis > 0) {
+                LOG_DBG("sleeping for %dms before sampling (for tests!)\n", params.testing_sampler_delay_millis);
                 std::this_thread::sleep_for(std::chrono::milliseconds(params.testing_sampler_delay_millis));
             }
             const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);

@@ -3006,7 +3007,7 @@ int main(int argc, char ** argv) {
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
         bool stream = json_value(data, "stream", false);
-
+
         handle_tasks(stream, res, ctx_server, [data, &ctx_server](const std::function<bool()> & is_alive) {
             std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, is_alive);
             ctx_server.queue_results.add_waiting_tasks(tasks);

@@ -3136,7 +3137,7 @@ int main(int argc, char ** argv) {
             return;
         }
 
-
+
         handle_tasks(false, res, ctx_server, [prompt, &ctx_server](const std::function<bool()> & is_alive) {
             std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, is_alive);
             ctx_server.queue_results.add_waiting_tasks(tasks);

@@ -3164,7 +3165,7 @@ int main(int argc, char ** argv) {
             json root = is_openai
                 ? format_embeddings_response_oaicompat(body, responses)
                 : responses[0];
-
+
             res_ok(res, &sink, root);
         });
     };
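
Note: each handler touched above passes an is_alive callback into ctx_server.create_tasks_cmpl; the cancel tests rely on the server noticing a dropped connection through this kind of check. The sketch below is a minimal Python illustration of that pattern using hypothetical names, not the server's actual C++ API:

# is_alive pattern (hypothetical names): task creation and processing receive a
# callable that reports whether the HTTP client is still connected, so work for a
# disconnected request is dropped and its slot freed instead of running to completion.
from typing import Callable

def create_tasks_cmpl(data: dict, is_alive: Callable[[], bool]) -> list[dict]:
    # Refuse to queue work for a client that has already gone away.
    return [{"prompt": data.get("prompt", "")}] if is_alive() else []

def process(tasks: list[dict], is_alive: Callable[[], bool]) -> None:
    for task in tasks:
        if not is_alive():   # re-checked between expensive steps (e.g. sampling)
            break            # release the slot instead of finishing the generation
        print("sampling for:", task["prompt"])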

examples/server/tests/features/cancel.feature

Lines changed: 30 additions & 14 deletions
@@ -4,7 +4,6 @@ Feature: Cancellation of llama.cpp server requests
 
   Background: Server startup
     Given a server listening on localhost:8080
-    And 500 milliseconds delay in sampler for testing
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And a model file test-model.gguf
     And a model alias tinyllama-2

@@ -13,28 +12,45 @@ Feature: Cancellation of llama.cpp server requests
     # KV Cache corresponds to the total amount of tokens
     # that can be stored across all independent sequences: #4130
     # see --ctx-size and #5568
-    And 256 KV cache size
+    And 512 KV cache size
     And 32 as batch size
-    And 1 slots
+    And 2 slots
     And 64 server max tokens to predict
+    And prometheus compatible metrics exposed
+    And 300 milliseconds delay in sampler for testing
+    And no warmup
     Then the server is starting
     Then the server is healthy
+    # Then the server is healthy with timeout 10 seconds
 
-  # Scenario: Health
-  #   Then the server is ready
-  #   And all slots are idle
 
-  @wip
-  Scenario Outline: Cancelling completion request frees up slot
-    Given a prompt:
-    """
-    Once upon
-    """
+  Scenario Outline: Cancelling an OAI chat completion request frees up slot (streaming <enable_streaming>)
+    Given a model llama-2
+    And a user prompt Once upon a time
+    And a system prompt You tell lengthy stories
     And 256 max tokens to predict
     And 256 server max tokens to predict
     And streaming is <enable_streaming>
-    And a completion request cancelled after 100 milliseconds
-    # And wait for 50 milliseconds
+    And disconnect after 100 milliseconds
+    Given concurrent OAI completions requests
+    And wait for 700 milliseconds
+    Then all slots are idle
+
+    Examples: Prompts
+      | enable_streaming |
+      | disabled         |
+      | enabled          |
+
+
+  Scenario Outline: Cancelling a completion request frees up slot (streaming <enable_streaming>)
+    Given a model llama-2
+    Given a prompt Once upon a time
+    And 256 max tokens to predict
+    And 256 server max tokens to predict
+    And streaming is <enable_streaming>
+    And disconnect after 100 milliseconds
+    Given a completion request with no api error
+    And wait for 700 milliseconds
     Then all slots are idle
 
     Examples: Prompts
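
The timing in these scenarios fits together roughly as follows; this is a back-of-the-envelope sketch of the assumed margins, not something the tests measure:

# Rough timing budget behind the cancel scenarios (assumed margins):
sampler_delay_ms    = 300   # test-only delay injected before each sampling step
disconnect_after_ms = 100   # client drops the connection while the first token is pending
wait_before_assert  = 700   # sleep before "Then all slots are idle"

# The disconnect lands inside a sampler delay, so the server should notice the dead
# connection within roughly one delay period and release the slot well before the check.
worst_case_release_ms = disconnect_after_ms + sampler_delay_ms
assert worst_case_release_ms < wait_before_assert   # 400 < 700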

examples/server/tests/features/steps/steps.py

Lines changed: 19 additions & 15 deletions
@@ -80,6 +80,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.lora_file = None
     context.testing_sampler_delay_millis = None
     context.disable_ctx_shift = False
+    context.disconnect_after_millis = None
 
     context.tasks_result = []
     context.concurrent_tasks = []

@@ -279,6 +280,7 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
                                          n_predict=context.n_predict,
                                          cache_prompt=context.cache_prompt,
                                          id_slot=context.id_slot,
+                                         disconnect_after_millis=context.disconnect_after_millis,
                                          expect_api_error=expect_api_error,
                                          user_api_key=context.user_api_key,
                                          temperature=context.temperature)

@@ -296,20 +298,12 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
 async def step_request_completion(context, millis: int):
     await asyncio.sleep(millis / 1000.0)
 
-@step('a completion request cancelled after {disconnect_after_millis:d} milliseconds')
+
+@step('disconnect after {disconnect_after_millis:d} milliseconds')
 @async_run_until_complete
-async def step_request_completion(context, disconnect_after_millis: int):
-    seeds = await completions_seed(context, num_seeds=1)
-    await request_completion(context.prompts.pop(),
-                             seeds[0] if seeds is not None else seeds,
-                             context.base_url,
-                             debug=context.debug,
-                             n_predict=context.n_predict,
-                             cache_prompt=context.cache_prompt,
-                             id_slot=context.id_slot,
-                             disconnect_after_millis=disconnect_after_millis,
-                             user_api_key=context.user_api_key,
-                             temperature=context.temperature)
+async def step_disconnect_after(context, disconnect_after_millis: int):
+    context.disconnect_after_millis = disconnect_after_millis
+
 
 @step('{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):

@@ -519,6 +513,7 @@ async def step_oai_chat_completions(context, api_error):
     print(f"Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
     seeds = await completions_seed(context, num_seeds=1),
+    seeds = await completions_seed(context, num_seeds=1)
     completion = await oai_chat_completions(context.prompts.pop(),
                                             seeds[0] if seeds is not None else seeds,
                                             context.system_prompt,

@@ -539,6 +534,8 @@ async def step_oai_chat_completions(context, api_error):
                                             user_api_key=context.user_api_key
                                             if hasattr(context, 'user_api_key') else None,
 
+                                            disconnect_after_millis=context.disconnect_after_millis,
+
                                             expect_api_error=expect_api_error)
     context.tasks_result.append(completion)
     if context.debug:

@@ -606,6 +603,7 @@ async def step_oai_chat_completions(context):
                                 if hasattr(context, 'enable_streaming') else None,
                                 response_format=context.response_format
                                 if hasattr(context, 'response_format') else None,
+                                disconnect_after_millis=context.disconnect_after_millis,
                                 user_api_key=context.user_api_key
                                 if hasattr(context, 'user_api_key') else None)
 

@@ -1029,9 +1027,9 @@ async def request_completion(prompt,
                                 },
                                 headers=headers) as response:
             if disconnect_after_millis is not None:
-                await asyncio.sleep(disconnect_after_millis / 1000)
+                await asyncio.sleep(disconnect_after_millis / 1000.0)
                 return 0
-
+
             if expect_api_error is None or not expect_api_error:
                 assert response.status == 200
                 assert response.headers['Access-Control-Allow-Origin'] == origin

@@ -1050,6 +1048,7 @@ async def oai_chat_completions(user_prompt,
                               temperature=None,
                               model=None,
                               n_predict=None,
+                              disconnect_after_millis=None,
                               enable_streaming=None,
                               response_format=None,
                               user_api_key=None,

@@ -1093,6 +1092,10 @@ async def oai_chat_completions(user_prompt,
         async with session.post(f'{base_url}{base_path}',
                                 json=payload,
                                 headers=headers) as response:
+            if disconnect_after_millis is not None:
+                await asyncio.sleep(disconnect_after_millis / 1000.0)
+                return 0
+
             if enable_streaming:
                 assert response.status == 200
                 assert response.headers['Access-Control-Allow-Origin'] == origin

@@ -1133,6 +1136,7 @@ async def oai_chat_completions(user_prompt,
         else:
             return response.status
     else:
+        assert disconnect_after_millis is None, "disconnect_after_millis is not supported with sync client"
         try:
             openai.api_key = user_api_key
             openai.base_url = f'{base_url}{base_path.removesuffix("chat")}'
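
For reference, the disconnect mechanism added to request_completion and oai_chat_completions above amounts to sleeping briefly and then leaving the aiohttp response context without reading the body, which closes the connection. A standalone sketch of the same pattern, with a placeholder endpoint and payload:

# Disconnect-after-N-milliseconds sketch: exiting the `async with session.post(...)`
# block without consuming the response closes the connection, so the server sees the
# client vanish mid-request and should cancel the in-flight generation.
import asyncio
import aiohttp

async def post_then_disconnect(url: str, payload: dict, disconnect_after_millis: int) -> int:
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as response:
            await asyncio.sleep(disconnect_after_millis / 1000.0)
            return 0   # mirrors the early return in the test helpers

# Example (placeholder endpoint and prompt):
# asyncio.run(post_then_disconnect("http://localhost:8080/completion",
#                                  {"prompt": "Once upon a time", "n_predict": 256},
#                                  disconnect_after_millis=100))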
