Skip to content

Commit 2d65d56

Browse files
authored
[CB] Add scheduling tests (#329)
This PR adds a scheduling steps test where new prompts are joining during the decode of other sequences, when there is still room left in the batch for new sequences. Execution was tested on AIU as well (passing) --------- Signed-off-by: Sophie du Couédic <sop@zurich.ibm.com>
1 parent 15d3587 commit 2d65d56

File tree

2 files changed

+280
-0
lines changed

2 files changed

+280
-0
lines changed

tests/e2e/test_spyre_cb_scheduler_steps.py

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
3333
steps_add_reqs = [0, 0, 0] # add all requests in the beginning
3434
available_blocks = -1 # no restriction
3535
max_num_seqs = 2
36+
max_model_len = 256
3637

3738
checked_steps = [
3839
{
@@ -170,6 +171,7 @@ def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
170171
steps_add_reqs=steps_add_reqs,
171172
checked_steps=checked_steps,
172173
max_num_seqs=max_num_seqs,
174+
max_model_len=max_model_len,
173175
available_blocks=available_blocks,
174176
use_cb=True,
175177
)
@@ -197,6 +199,7 @@ def test_prompts_misaligned_with_tkv_boundaries(
197199
steps_add_reqs = [0, 0, 0] # add all requests in the beginning
198200
available_blocks = -1 # no restriction
199201
max_num_seqs = 2
202+
max_model_len = 256
200203

201204
checked_steps = [
202205
{
@@ -332,6 +335,7 @@ def test_prompts_misaligned_with_tkv_boundaries(
332335
steps_add_reqs=steps_add_reqs,
333336
checked_steps=checked_steps,
334337
max_num_seqs=max_num_seqs,
338+
max_model_len=max_model_len,
335339
available_blocks=available_blocks,
336340
use_cb=True,
337341
)
@@ -358,6 +362,7 @@ def test_two_sequences_finish_same_time_as_new_arrive(
358362
steps_add_reqs = [0, 0, 31]
359363
available_blocks = -1 # no restriction
360364
max_num_seqs = 2
365+
max_model_len = 256
361366

362367
checked_steps = [
363368
{
@@ -470,6 +475,270 @@ def test_two_sequences_finish_same_time_as_new_arrive(
470475
steps_add_reqs=steps_add_reqs,
471476
checked_steps=checked_steps,
472477
max_num_seqs=max_num_seqs,
478+
max_model_len=max_model_len,
479+
available_blocks=available_blocks,
480+
use_cb=True,
481+
)
482+
483+
484+
@pytest.mark.cb
485+
@pytest.mark.parametrize("model", get_spyre_model_list())
486+
@pytest.mark.parametrize("backend", get_spyre_backend_list())
487+
def test_new_sequence_joins_during_decode(model: str, backend: str,
488+
monkeypatch: pytest.MonkeyPatch):
489+
""" Scenario where a new sequence joins while decoding other sequences
490+
491+
Configuration:
492+
* max_num_seqs: 4
493+
* number of prompts: 4
494+
* 1: len = 49, max tokens = 119, step joining = 0
495+
* 2: len = 14, max tokens = 52, step joining = 0
496+
* 3: len = 89, max tokens = 104, step joining = 32
497+
* 4: len = 9, max tokens = 64, step joining = 131
498+
"""
499+
# TODO change to 65 max_tokens for last prompt if ever possible
500+
501+
seqs_max_tokens = [119, 52, 104, 64]
502+
prompts_lengths = [49, 14, 89, 9]
503+
steps_add_reqs = [0, 0, 32, 131]
504+
available_blocks = -1 # no restriction
505+
max_num_seqs = 4
506+
max_model_len = 256
507+
508+
checked_steps = [
509+
{
510+
"step": 0,
511+
"tkv": 0,
512+
"waiting": ["0", "1"],
513+
"running": [],
514+
"request_outputs": [],
515+
"n_reserved_blocks": 0,
516+
"n_used_blocks": 0
517+
},
518+
{
519+
# Prefill sequence 0
520+
"step": 1,
521+
"tkv": 64,
522+
"waiting": ["1"],
523+
"running": ["0"],
524+
"request_outputs": ["0"],
525+
"n_reserved_blocks": 3, # prefill (1 block) + 118 decodes (2 blocks)
526+
"n_used_blocks": 1
527+
},
528+
{
529+
# Prefill sequence 1
530+
"step": 2,
531+
"tkv": 64,
532+
"waiting": [],
533+
"running": ["1", "0"],
534+
"request_outputs": ["1"],
535+
"n_reserved_blocks": 5, # prefill (1 block) + 51 decodes (1 block)
536+
"n_used_blocks": 2
537+
},
538+
{
539+
# Decode sequences 0 and 1
540+
"step": 3,
541+
"tkv": 65,
542+
"waiting": [],
543+
"running": ["1", "0"],
544+
"request_outputs": ["1", "0"],
545+
"n_reserved_blocks": 5,
546+
"n_used_blocks": 4 # 2 blocks extended, one for each sequence
547+
},
548+
{
549+
# Sequence 2 joins: one iteration in waiting queue
550+
"step": 32,
551+
"tkv": 94,
552+
"waiting": ["2"],
553+
"running": ["1", "0"],
554+
"request_outputs": ["1", "0"],
555+
"n_reserved_blocks": 5,
556+
"n_used_blocks": 4
557+
},
558+
{
559+
# Prefill sequence 2
560+
"step": 33,
561+
"tkv": 94,
562+
"waiting": [],
563+
"running": ["2", "1", "0"],
564+
"request_outputs": ["2"],
565+
"n_reserved_blocks": 9, # prefill (2 block) + 103 decode (2 block)
566+
"n_used_blocks": 6
567+
},
568+
{
569+
# Decode sequences 0, 1, and 2
570+
"step": 34,
571+
"tkv": 95,
572+
"waiting": [],
573+
"running": ["2", "1", "0"],
574+
"request_outputs": ["2", "1", "0"],
575+
"n_reserved_blocks": 9,
576+
"n_used_blocks": 6
577+
},
578+
{
579+
# Sequence 1 finishes at step 54
580+
# (start step + 2 prefills + 51 decodes - 1) = 2 + 2 + 51 - 1 = 54
581+
"step": 54,
582+
"tkv": 115,
583+
"waiting": [],
584+
"running": ["2", "0"],
585+
"request_outputs": ["2", "1", "0"],
586+
"finished_requests": ["1"],
587+
"n_reserved_blocks": 9,
588+
"n_used_blocks": 6
589+
},
590+
{
591+
# Decode sequences 0 and 2
592+
"step": 55,
593+
"tkv": 116,
594+
"waiting": [],
595+
"running": ["2", "0"],
596+
"request_outputs": ["2", "0"],
597+
"n_reserved_blocks": 7, # two blocks released
598+
"n_used_blocks": 4 # two blocks released
599+
},
600+
{
601+
# Decode sequences 0 and 2, tkv arrives to new block
602+
"step": 68,
603+
"tkv": 129,
604+
"waiting": [],
605+
"running": ["2", "0"],
606+
"request_outputs": ["2", "0"],
607+
"n_reserved_blocks": 7,
608+
"n_used_blocks": 6 # 2 blocks extended, one for each sequence
609+
},
610+
{
611+
# Sequence 0 finishes at step 121
612+
# (start step + 3 prefills + 118 decode - 1) = 1 + 3 + 118 - 1 = 121
613+
"step": 121,
614+
"tkv": 182,
615+
"waiting": [],
616+
"running": ["2"],
617+
"request_outputs": ["2", "0"],
618+
"finished_requests": ["0"],
619+
"n_reserved_blocks": 7,
620+
"n_used_blocks": 6
621+
},
622+
{
623+
# Decode sequence 2
624+
"step": 122,
625+
"tkv": 183,
626+
"waiting": [],
627+
"running": ["2"],
628+
"request_outputs": ["2"],
629+
"n_reserved_blocks": 4, # 3 blocks released
630+
"n_used_blocks": 3 # 3 blocks released
631+
},
632+
{
633+
# Sequence 3 joins: one iteration in waiting queue
634+
"step": 131,
635+
"tkv": 192,
636+
"waiting": ["3"],
637+
"running": ["2"],
638+
"request_outputs": ["2"],
639+
"n_reserved_blocks": 4,
640+
"n_used_blocks": 3
641+
},
642+
{
643+
# Prefill sequence 3
644+
"step": 132,
645+
"tkv": 192,
646+
"waiting": [],
647+
"running": ["3", "2"],
648+
"request_outputs": ["3"],
649+
"n_reserved_blocks": 8, # prefill (3 blocks) + 63 decode (1 block)
650+
"n_used_blocks": 6 # prefill (3 block)
651+
},
652+
{
653+
# Decode sequences 2 and 3
654+
"step": 133,
655+
"tkv": 193,
656+
"waiting": [],
657+
"running": ["3", "2"],
658+
"request_outputs": ["3", "2"],
659+
"n_reserved_blocks": 8,
660+
"n_used_blocks": 8 # 2 blocks extended, one for each sequence
661+
},
662+
{
663+
# Sequence 2 finishes at step 137
664+
# (start step + 2 prefills + 103 decodes - 1) = 33 + 2 + 103 - 1 = 137
665+
"step": 137,
666+
"tkv": 197,
667+
"waiting": [],
668+
"running": ["3"],
669+
"request_outputs": ["3", "2"],
670+
"finished_requests": ["2"],
671+
"n_reserved_blocks": 8,
672+
"n_used_blocks": 8
673+
},
674+
{
675+
# Decode sequence 3
676+
"step": 138,
677+
"tkv": 70,
678+
"waiting": [],
679+
"running": ["3"],
680+
"request_outputs": ["3"],
681+
# 6 blocks freed: finished sequence (4) + left padding stripping (2)
682+
"n_reserved_blocks": 2,
683+
"n_used_blocks": 2
684+
},
685+
{
686+
# Sequence 3 finishes at step 195
687+
# (start step + 1 prefill + 63 decodes - 1) = 132 + 1 + 63 - 1 = 195
688+
"step": 195,
689+
"tkv": 127,
690+
"waiting": [],
691+
"running": [],
692+
"request_outputs": ["3"],
693+
"finished_requests": ["3"],
694+
"n_reserved_blocks": 2,
695+
"n_used_blocks": 2
696+
},
697+
{
698+
# Tkv should be cleared one step later
699+
"step": 196,
700+
"tkv": 0,
701+
"waiting": [],
702+
"running": [],
703+
"request_outputs": [],
704+
"n_reserved_blocks": 0,
705+
"n_used_blocks": 0
706+
},
707+
# TODO this is when max_tokens = 65 for last prompt
708+
# {
709+
# # Sequence 3 finishes at step 196
710+
# # (start step + 1 prefill + 64 decodes - 1) = 132 + 1 + 64 - 1 = 196
711+
# "step": 196,
712+
# "tkv": 128,
713+
# "waiting": [],
714+
# "running": [],
715+
# "request_outputs": ["3"],
716+
# "finished_requests": ["3"],
717+
# "n_reserved_blocks": 2,
718+
# "n_used_blocks": 2
719+
# },
720+
# {
721+
# # Tkv should be cleared one step later
722+
# "step": 197,
723+
# "tkv": 0,
724+
# "waiting": [],
725+
# "running": [],
726+
# "request_outputs": [],
727+
# "n_reserved_blocks": 0,
728+
# "n_used_blocks": 0
729+
# },
730+
]
731+
732+
check_scheduler_inference_steps(
733+
model=model,
734+
backend=backend,
735+
monkeypatch=monkeypatch,
736+
seqs_max_tokens=seqs_max_tokens,
737+
prompts_lengths=prompts_lengths,
738+
steps_add_reqs=steps_add_reqs,
739+
checked_steps=checked_steps,
740+
max_num_seqs=max_num_seqs,
741+
max_model_len=max_model_len,
473742
available_blocks=available_blocks,
474743
use_cb=True,
475744
)
@@ -494,6 +763,7 @@ def test_prompt_too_long_for_current_tkv(model: str, backend: str,
494763
steps_add_reqs = [0, 0]
495764
available_blocks = -1 # no restriction
496765
max_num_seqs = 2
766+
max_model_len = 256
497767

498768
checked_steps = [
499769
{
@@ -617,6 +887,7 @@ def test_prompt_too_long_for_current_tkv(model: str, backend: str,
617887
steps_add_reqs=steps_add_reqs,
618888
checked_steps=checked_steps,
619889
max_num_seqs=max_num_seqs,
890+
max_model_len=max_model_len,
620891
available_blocks=available_blocks,
621892
use_cb=True,
622893
)
@@ -642,6 +913,7 @@ def test_requested_tokens_not_fitting_remaining_space(
642913
steps_add_reqs = [0, 0, 0]
643914
available_blocks = -1 # no restriction
644915
max_num_seqs = 2
916+
max_model_len = 256
645917

646918
checked_steps = [
647919
{
@@ -802,6 +1074,7 @@ def test_requested_tokens_not_fitting_remaining_space(
8021074
steps_add_reqs=steps_add_reqs,
8031075
checked_steps=checked_steps,
8041076
max_num_seqs=max_num_seqs,
1077+
max_model_len=max_model_len,
8051078
available_blocks=available_blocks,
8061079
use_cb=True,
8071080
)
@@ -830,6 +1103,8 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
8301103
# total number of blocks needed if scheduled together : 4 * (1 + 1) = 8
8311104
available_blocks = 8
8321105
max_num_seqs = 4
1106+
max_model_len = 256
1107+
8331108
checked_steps = [
8341109
{
8351110
"step": 0,
@@ -933,6 +1208,7 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
9331208
steps_add_reqs=steps_add_reqs,
9341209
checked_steps=checked_steps,
9351210
max_num_seqs=max_num_seqs,
1211+
max_model_len=max_model_len,
9361212
available_blocks=available_blocks,
9371213
use_cb=True,
9381214
)
@@ -962,6 +1238,8 @@ def test_requests_use_more_than_available_blocks(
9621238
# total number of blocks needed if scheduled together : 4 * (1 + 1) = 8
9631239
available_blocks = 4
9641240
max_num_seqs = 4
1241+
max_model_len = 256
1242+
9651243
checked_steps = [
9661244
{
9671245
"step": 0,
@@ -1090,6 +1368,7 @@ def test_requests_use_more_than_available_blocks(
10901368
steps_add_reqs=steps_add_reqs,
10911369
checked_steps=checked_steps,
10921370
max_num_seqs=max_num_seqs,
1371+
max_model_len=max_model_len,
10931372
available_blocks=available_blocks,
10941373
use_cb=True,
10951374
)

tests/scheduling_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def check_scheduler_inference_steps(
4141
steps_add_reqs: list[int],
4242
checked_steps: list[dict[str, Any]],
4343
max_num_seqs: int,
44+
max_model_len: int,
4445
available_blocks: int,
4546
use_cb: bool = True,
4647
):

0 commit comments

Comments
 (0)