
Commit ac373c0

compare results naive
Signed-off-by: Sophie du Couédic <sop@zurich.ibm.com>
1 parent 96e3175 commit ac373c0

File tree

3 files changed: +193 -38 lines changed

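In brief: each continuous-batching scheduler-step test now collects the tokens and logprobs produced during the checked steps, regenerates the same prompts with plain HuggingFace, and naively compares the two outputs. The commented-out `check_output = backend == "sendnn"` lines suggest the comparison may later be restricted to the sendnn backend. Below is a sketch of the pattern that recurs in every test of the first file; all names are taken from the diff itself, and the ellipsis stands for the per-test scheduler configuration:

    # Sketch of the per-test pattern this commit introduces (per-test
    # arguments elided; helper internals live in spyre_util and
    # scheduling_utils, not shown here).
    cb_outputs, prompts = check_scheduler_inference_steps(
        model=model,
        backend=backend,
        monkeypatch=monkeypatch,
        # ... per-test scheduler configuration ...
        use_cb=True,
        collect_outputs=check_output,  # collect only when we intend to compare
    )

    if check_output:
        # Reference generation with HF; ignore_eos=True presumably keeps the
        # token count equal to the fixed number the scheduler test decodes.
        hf_outputs = generate_hf_output(model=model,
                                        prompts=prompts,
                                        max_new_tokens=seqs_max_tokens,
                                        ignore_eos=True)
        compare_results(model=model,
                        tensor_parallel_size=1,
                        backend=backend,
                        vllm_results=cb_outputs,
                        hf_results=hf_outputs)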
tests/e2e/test_spyre_cb_scheduler_steps.py

Lines changed: 138 additions & 10 deletions
@@ -8,7 +8,8 @@

 import pytest
 from scheduling_utils import check_scheduler_inference_steps
-from spyre_util import get_spyre_backend_list, get_spyre_model_list
+from spyre_util import (compare_results, generate_hf_output,
+                        get_spyre_backend_list, get_spyre_model_list)


 @pytest.mark.cb
@@ -34,6 +35,8 @@ def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
     available_blocks = -1  # no restriction
     max_num_seqs = 2
     max_model_len = 256
+    # check_output = backend == "sendnn"
+    check_output = True

     checked_steps = [
         {
@@ -162,7 +165,7 @@ def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
         },
     ]

-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -174,8 +177,22 @@ def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
         max_model_len=max_model_len,
         available_blocks=available_blocks,
         use_cb=True,
+        collect_outputs=check_output,
     )

+    if check_output:
+        hf_outputs = generate_hf_output(
+            model=model,
+            prompts=prompts,
+            max_new_tokens=seqs_max_tokens,
+            ignore_eos=True,
+        )
+        compare_results(model=model,
+                        tensor_parallel_size=1,
+                        backend=backend,
+                        vllm_results=cb_outputs,
+                        hf_results=hf_outputs)
+

 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
@@ -200,6 +217,8 @@ def test_prompts_misaligned_with_tkv_boundaries(
     available_blocks = -1  # no restriction
     max_num_seqs = 2
     max_model_len = 256
+    # check_output = backend == "sendnn"
+    check_output = True

     checked_steps = [
         {
@@ -326,7 +345,7 @@ def test_prompts_misaligned_with_tkv_boundaries(
         },
     ]

-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -338,8 +357,22 @@ def test_prompts_misaligned_with_tkv_boundaries(
         max_model_len=max_model_len,
         available_blocks=available_blocks,
         use_cb=True,
+        collect_outputs=check_output,
     )

+    if check_output:
+        hf_outputs = generate_hf_output(
+            model=model,
+            prompts=prompts,
+            max_new_tokens=seqs_max_tokens,
+            ignore_eos=True,
+        )
+        compare_results(model=model,
+                        tensor_parallel_size=1,
+                        backend=backend,
+                        vllm_results=cb_outputs,
+                        hf_results=hf_outputs)
+

 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
@@ -363,6 +396,8 @@ def test_two_sequences_finish_same_time_as_new_arrive(
     available_blocks = -1  # no restriction
     max_num_seqs = 2
     max_model_len = 256
+    # check_output = backend == "sendnn"
+    check_output = True

     checked_steps = [
         {
@@ -466,7 +501,7 @@ def test_two_sequences_finish_same_time_as_new_arrive(
         },
     ]

-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -478,8 +513,22 @@ def test_two_sequences_finish_same_time_as_new_arrive(
         max_model_len=max_model_len,
         available_blocks=available_blocks,
         use_cb=True,
+        collect_outputs=check_output,
     )

+    if check_output:
+        hf_outputs = generate_hf_output(
+            model=model,
+            prompts=prompts,
+            max_new_tokens=seqs_max_tokens,
+            ignore_eos=True,
+        )
+        compare_results(model=model,
+                        tensor_parallel_size=1,
+                        backend=backend,
+                        vllm_results=cb_outputs,
+                        hf_results=hf_outputs)
+

 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
@@ -504,6 +553,8 @@ def test_new_sequence_joins_during_decode(model: str, backend: str,
     available_blocks = -1  # no restriction
     max_num_seqs = 4
     max_model_len = 256
+    # check_output = backend == "sendnn"
+    check_output = True

     checked_steps = [
         {
@@ -729,7 +780,7 @@ def test_new_sequence_joins_during_decode(model: str, backend: str,
         # },
     ]

-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -741,8 +792,22 @@ def test_new_sequence_joins_during_decode(model: str, backend: str,
         max_model_len=max_model_len,
         available_blocks=available_blocks,
         use_cb=True,
+        collect_outputs=check_output,
     )

+    if check_output:
+        hf_outputs = generate_hf_output(
+            model=model,
+            prompts=prompts,
+            max_new_tokens=seqs_max_tokens,
+            ignore_eos=True,
+        )
+        compare_results(model=model,
+                        tensor_parallel_size=1,
+                        backend=backend,
+                        vllm_results=cb_outputs,
+                        hf_results=hf_outputs)
+

 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
@@ -764,6 +829,7 @@ def test_prompt_too_long_for_current_tkv(model: str, backend: str,
     available_blocks = -1  # no restriction
     max_num_seqs = 2
     max_model_len = 256
+    check_output = False

     checked_steps = [
         {
@@ -878,7 +944,7 @@ def test_prompt_too_long_for_current_tkv(model: str, backend: str,
         },
     ]

-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -890,15 +956,30 @@ def test_prompt_too_long_for_current_tkv(model: str, backend: str,
         max_model_len=max_model_len,
         available_blocks=available_blocks,
         use_cb=True,
+        collect_outputs=check_output,
     )

+    if check_output:
+        hf_outputs = generate_hf_output(
+            model=model,
+            prompts=prompts,
+            max_new_tokens=seqs_max_tokens,
+            ignore_eos=True,
+        )
+        compare_results(model=model,
+                        tensor_parallel_size=1,
+                        backend=backend,
+                        vllm_results=cb_outputs,
+                        hf_results=hf_outputs)
+

 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
 def test_requested_tokens_not_fitting_remaining_space(
         model: str, backend: str, monkeypatch: pytest.MonkeyPatch):
-    """ Scenario where the request goes beyond max_model_len
+    """ Scenario where the request goes beyond max_model_len and needs to wait
+    for a new batch.

     Configuration:
         * max_num_seqs: 2
@@ -914,6 +995,7 @@ def test_requested_tokens_not_fitting_remaining_space(
     available_blocks = -1  # no restriction
     max_num_seqs = 2
     max_model_len = 256
+    check_output = False

     checked_steps = [
         {
@@ -1065,7 +1147,7 @@ def test_requested_tokens_not_fitting_remaining_space(
         },
     ]

-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -1077,8 +1159,22 @@ def test_requested_tokens_not_fitting_remaining_space(
         max_model_len=max_model_len,
         available_blocks=available_blocks,
         use_cb=True,
+        collect_outputs=check_output,
     )

+    if check_output:
+        hf_outputs = generate_hf_output(
+            model=model,
+            prompts=prompts,
+            max_new_tokens=seqs_max_tokens,
+            ignore_eos=True,
+        )
+        compare_results(model=model,
+                        tensor_parallel_size=1,
+                        backend=backend,
+                        vllm_results=cb_outputs,
+                        hf_results=hf_outputs)
+

 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
@@ -1104,6 +1200,8 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
     available_blocks = 8
     max_num_seqs = 4
     max_model_len = 256
+    # check_output = backend == "sendnn"
+    check_output = True

     checked_steps = [
         {
@@ -1199,7 +1297,7 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
         },
     ]

-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -1211,8 +1309,22 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
         max_model_len=max_model_len,
         available_blocks=available_blocks,
         use_cb=True,
+        collect_outputs=check_output,
     )

+    if check_output:
+        hf_outputs = generate_hf_output(
+            model=model,
+            prompts=prompts,
+            max_new_tokens=seqs_max_tokens,
+            ignore_eos=True,
+        )
+        compare_results(model=model,
+                        tensor_parallel_size=1,
+                        backend=backend,
+                        vllm_results=cb_outputs,
+                        hf_results=hf_outputs)
+

 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
@@ -1239,6 +1351,8 @@ def test_requests_use_more_than_available_blocks(
     available_blocks = 4
     max_num_seqs = 4
     max_model_len = 256
+    # check_output = backend == "sendnn"
+    check_output = True

     checked_steps = [
         {
@@ -1359,7 +1473,7 @@ def test_requests_use_more_than_available_blocks(
         },
     ]

-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -1371,4 +1485,18 @@ def test_requests_use_more_than_available_blocks(
         max_model_len=max_model_len,
         available_blocks=available_blocks,
         use_cb=True,
+        collect_outputs=check_output,
     )
+
+    if check_output:
+        hf_outputs = generate_hf_output(
+            model=model,
+            prompts=prompts,
+            max_new_tokens=seqs_max_tokens,
+            ignore_eos=True,
+        )
+        compare_results(model=model,
+                        tensor_parallel_size=1,
+                        backend=backend,
+                        vllm_results=cb_outputs,
+                        hf_results=hf_outputs)
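The collect-and-compare block above is pasted verbatim into all eight tests, with `check_output = False` in the two scenarios that exercise scheduling edge cases rather than generation. A hypothetical follow-up could consolidate the repetition; the helper name below is mine, not the commit's, while `generate_hf_output` and `compare_results` are the repo helpers used above:

    # Hypothetical consolidation sketch -- not part of this commit.
    def compare_cb_outputs_with_hf(model, backend, prompts, cb_outputs,
                                   seqs_max_tokens, check_output):
        if not check_output:
            return
        hf_outputs = generate_hf_output(model=model,
                                        prompts=prompts,
                                        max_new_tokens=seqs_max_tokens,
                                        ignore_eos=True)
        compare_results(model=model,
                        tensor_parallel_size=1,
                        backend=backend,
                        vllm_results=cb_outputs,
                        hf_results=hf_outputs)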

tests/scheduling_utils.py

Lines changed: 14 additions & 7 deletions
@@ -87,9 +87,8 @@ def check_scheduler_inference_steps(
             "List of checked steps needs to be of increasing order of step")
     # ------

-    collected_outputs = defaultdict(lambda: {"tokens_ids": [], "logprobs": []})
+    collected_outputs = defaultdict(lambda: {"token_ids": [], "logprobs": []})
     generated_prompts = []
-    prompts_sampling_params = []

     # Setup the engine
     engine_args = EngineArgs(model=model,
@@ -122,7 +121,6 @@ def check_scheduler_inference_steps(
                                         model=model)
         requests.append((add_step, request))
         generated_prompts.append(request.prompt_token_ids)
-        prompts_sampling_params.append(sampling_params)

     # In-between steps are added as normal decode steps
     checked_steps = augment_checked_steps(checked_steps)
@@ -202,16 +200,25 @@ def check_scheduler_inference_steps(
             new_logprobs = output.new_logprobs.logprobs
             assert len(new_token_ids) == 1 and len(new_logprobs) == 1

-            collected_outputs[output.request_id]["tokens_ids"].append(
+            collected_outputs[output.request_id]["token_ids"].append(
                 new_token_ids[0])
             collected_outputs[output.request_id]["logprobs"].append(
                 new_logprobs[0][0])

     # Return collected outputs as list
     if not collected_outputs:
-        return [], generated_prompts, prompts_sampling_params
+        return [], generated_prompts
     else:
         output_keys = sorted(int(k) for k in collected_outputs)
         assert output_keys[0] == 0 and output_keys[-1] == len(output_keys) - 1
-        collected_outputs = [collected_outputs[str(k)] for k in output_keys]
-        return collected_outputs, generated_prompts, prompts_sampling_params
+
+        # convert dict of dicts to ordered list and make values immutable
+        collected_outputs_new = []
+        for k in output_keys:
+            output = collected_outputs[str(k)]
+            for k, list_values in output.items():
+                if isinstance(list_values, list):
+                    output[k] = tuple(list_values)
+            collected_outputs_new.append(output)
+
+        return collected_outputs_new, generated_prompts
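For readers tracing the new return shape: `check_scheduler_inference_steps` now returns a pair `(collected_outputs, generated_prompts)` instead of a triple, with outputs ordered by integer request id and their list values frozen into tuples. A minimal standalone sketch of that conversion on toy data (note the hunk above reuses `k` for both loop variables, which is harmless in Python; distinct names are used here for clarity):

    from collections import defaultdict

    # Toy reproduction of the post-processing: request ids are string keys,
    # sorted numerically and checked for contiguity, then every per-request
    # list is frozen into a tuple so later comparisons cannot mutate it.
    collected_outputs = defaultdict(lambda: {"token_ids": [], "logprobs": []})
    collected_outputs["1"]["token_ids"] += [11, 12]
    collected_outputs["0"]["token_ids"] += [7, 8]

    output_keys = sorted(int(k) for k in collected_outputs)
    assert output_keys[0] == 0 and output_keys[-1] == len(output_keys) - 1

    ordered = []
    for key in output_keys:
        entry = collected_outputs[str(key)]
        for field, values in entry.items():
            if isinstance(values, list):
                entry[field] = tuple(values)
        ordered.append(entry)

    # ordered == [{'token_ids': (7, 8), 'logprobs': ()},
    #             {'token_ids': (11, 12), 'logprobs': ()}]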
