
Commit 4c41672

mengwei805 and XWFAlone authored
[0.7.3][5/N][CI/UT]add spec decode e2e UT && [BUGFIX]fix chunk prefill bug (#560)
### What this PR does / why we need it?

Add spec decode e2e UT:
1. add `test_multistep_correctness.py`;
2. open 2 cases in `tests/spec_decode/e2e/test_eagle_correctness.py` by using modelscope weights.

Fix chunked prefill bug:
1. add support for the case where `atten_mask` has only 1 element.

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

Tested by CI; local tests passed.

Signed-off-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: XWFAlone <xuewenfei2@huawei.com>
1 parent 320e823 commit 4c41672
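
The chunked-prefill fix itself lands in a file not shown in this diff, so the following is only a rough illustration: "support for an `atten_mask` with only 1 element" presumably means broadcasting a degenerate single-value mask to the full shape the attention kernel expects instead of erroring. The sketch below is a hypothetical stand-alone helper; `expand_attn_mask` and its shapes are assumptions, not the PR's actual code.

```python
import torch


def expand_attn_mask(atten_mask: torch.Tensor, q_len: int,
                     kv_len: int) -> torch.Tensor:
    """Hypothetical helper, not the PR's code.

    Chunked prefill normally hands the attention kernel a (q_len, kv_len)
    mask; if the caller provides a single-element mask, broadcast that one
    value to the full shape instead of failing.
    """
    if atten_mask.numel() == 1:
        # One element means one allow/deny value for every position.
        return atten_mask.reshape(1, 1).expand(q_len, kv_len).contiguous()
    return atten_mask


# A scalar mask is expanded to the (4, 8) shape a 4-token chunk attending
# over an 8-token KV cache would need.
mask = torch.zeros(1)
print(expand_attn_mask(mask, q_len=4, kv_len=8).shape)  # torch.Size([4, 8])
```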

File tree

4 files changed: +837 -27 lines changed


.github/workflows/vllm_ascend_test.yaml

Lines changed: 3 additions & 2 deletions

@@ -125,8 +125,9 @@ jobs:
       - name: Run vllm-project/vllm-ascend key feature test
         if: steps.filter.outputs.speculative_tests_changed == 'true'
         run: |
-          pytest -sv tests/spec_decode/e2e/test_mtp_correctness.py
-          pytest -sv tests/spec_decode --ignore=tests/spec_decode/e2e/test_mtp_correctness.py
+          pytest -sv tests/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process
+          pytest -sv tests/spec_decode/e2e/test_multistep_correctness.py # it needs a clean process
+          pytest -sv tests/spec_decode --ignore=tests/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/spec_decode/e2e/test_multistep_correctness.py
 
       - name: Run vllm-project/vllm test
         run: |

tests/spec_decode/e2e/test_eagle_correctness.py

Lines changed: 32 additions & 24 deletions

@@ -350,13 +350,13 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
         "dtype": "float16",
 
         # Main model
-        "model_name": "meta-llama/Llama-2-7b-chat-hf"
+        "model_name": "vllm-ascend/Llama-2-7b-chat-hf"
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
+        "speculative_model": "vllm-ascend/EAGLE-llama2-chat-7B",
         "num_speculative_tokens": MAX_SPEC_TOKENS,
     },
 ])
@@ -368,21 +368,25 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
 ])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize("seed", [1])
-def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
+def test_llama2_eagle_e2e_greedy_correctness(monkeypatch: pytest.MonkeyPatch,
+                                             vllm_runner, common_llm_kwargs,
                                              per_test_common_llm_kwargs,
                                              baseline_llm_kwargs,
                                              test_llm_kwargs, batch_size: int,
                                              output_len: int, seed: int):
 
-    run_equality_correctness_test(vllm_runner,
-                                  common_llm_kwargs,
-                                  per_test_common_llm_kwargs,
-                                  baseline_llm_kwargs,
-                                  test_llm_kwargs,
-                                  batch_size,
-                                  output_len,
-                                  seed,
-                                  temperature=0.0)
+    # TODO: it is a wrong way to use modelscope.
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_MODELSCOPE", "True")
+        run_equality_correctness_test(vllm_runner,
+                                      common_llm_kwargs,
+                                      per_test_common_llm_kwargs,
+                                      baseline_llm_kwargs,
+                                      test_llm_kwargs,
+                                      batch_size,
+                                      output_len,
+                                      seed,
+                                      temperature=0.0)
 
 
 @pytest.mark.skipif(True, reason="Open it when CI could use modelscope")
@@ -399,13 +403,13 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
         "dtype": "float16",
 
         # Main model
-        "model_name": "meta-llama/Meta-Llama-3-8B-Instruct"
+        "model_name": "vllm-ascend/Meta-Llama-3-8B-Instruct"
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
+        "speculative_model": "vllm-ascend/EAGLE-LLaMA3-Instruct-8B",
         "num_speculative_tokens": MAX_SPEC_TOKENS,
     },
 ])
@@ -417,21 +421,25 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 ])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize("seed", [1])
-def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
+def test_llama3_eagle_e2e_greedy_correctness(monkeypatch: pytest.MonkeyPatch,
+                                             vllm_runner, common_llm_kwargs,
                                              per_test_common_llm_kwargs,
                                              baseline_llm_kwargs,
                                              test_llm_kwargs, batch_size: int,
                                              output_len: int, seed: int):
 
-    run_equality_correctness_test(vllm_runner,
-                                  common_llm_kwargs,
-                                  per_test_common_llm_kwargs,
-                                  baseline_llm_kwargs,
-                                  test_llm_kwargs,
-                                  batch_size,
-                                  output_len,
-                                  seed,
-                                  temperature=0.0)
+    # TODO: it is a wrong way to use modelscope.
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_MODELSCOPE", "True")
+        run_equality_correctness_test(vllm_runner,
+                                      common_llm_kwargs,
+                                      per_test_common_llm_kwargs,
+                                      baseline_llm_kwargs,
+                                      test_llm_kwargs,
+                                      batch_size,
+                                      output_len,
+                                      seed,
+                                      temperature=0.0)
 
 
 @pytest.mark.parametrize(
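
For readers unfamiliar with the `monkeypatch` pattern these tests now use: `pytest.MonkeyPatch.context()` scopes the `VLLM_USE_MODELSCOPE` change to the `with` block and restores the previous value afterwards, so the ModelScope setting cannot leak into other tests in the same process. Below is a minimal self-contained sketch of that pattern; `load_weights` and the test name are made up for illustration, only the environment variable and the pytest APIs come from the diff.

```python
import os

import pytest


def load_weights(model_name: str) -> str:
    """Hypothetical stand-in for the code path that picks the weight source."""
    if os.environ.get("VLLM_USE_MODELSCOPE", "False") == "True":
        return f"modelscope://{model_name}"
    return f"huggingface://{model_name}"


def test_weights_come_from_modelscope(monkeypatch: pytest.MonkeyPatch):
    before = os.environ.get("VLLM_USE_MODELSCOPE")
    # The context manager limits the env change to this block.
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_MODELSCOPE", "True")
        assert load_weights("vllm-ascend/Llama-2-7b-chat-hf").startswith(
            "modelscope://")
    # After the block the variable is back to its previous value.
    assert os.environ.get("VLLM_USE_MODELSCOPE") == before
```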
