Skip to content

Commit d922fb9

Browse files
authored
[3/N][CI/UT] add spec decode e2e UT && [BUGFIX] fix init_logger bug (#487)
### What this PR does / why we need it? add spec decode e2e UT: 1. add `test_eagle_correctness.py`; 2. add `test_mtp_correctness.py`; use bf16 model weights; fix init_logger bug: 1. slove OOM probelm in `camem.py`; 2. replace `from vllm.logger import init_logger` to `from vllm.logger import logger` overall situation ### Does this PR introduce _any_ user-facing change? None ### How was this patch tested? Local verification passed Signed-off-by: mengwei805 <mengwei25@huawei.com>
1 parent 2f15503 commit d922fb9

File tree

15 files changed

+853
-35
lines changed

15 files changed

+853
-35
lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ jobs:
148148
- name: Run vllm-project/vllm-ascend key feature test
149149
if: steps.filter.outputs.speculative_tests_changed
150150
run: |
151-
pytest -sv tests/spec_decode
151+
pytest -sv tests/spec_decode/e2e/test_mtp_correctness.py
152+
pytest -sv tests/spec_decode --ignore=tests/spec_decode/e2e/test_mtp_correctness.py
152153
153154
- name: Run vllm-project/vllm test
154155
run: |

tests/conftest.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,18 @@
1717
# limitations under the License.
1818
#
1919

20+
import contextlib
21+
import gc
2022
from typing import List, Optional, Tuple, TypeVar, Union
2123

2224
import numpy as np
2325
import pytest
26+
import torch
2427
from PIL import Image
2528
from vllm import LLM, SamplingParams
2629
from vllm.config import TaskOption
27-
from vllm.distributed import cleanup_dist_env_and_memory
30+
from vllm.distributed import (destroy_distributed_environment,
31+
destroy_model_parallel)
2832
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
2933
from vllm.logger import init_logger
3034
from vllm.outputs import RequestOutput
@@ -44,6 +48,15 @@
4448
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
4549

4650

51+
def cleanup_dist_env_and_memory():
52+
destroy_model_parallel()
53+
destroy_distributed_environment()
54+
with contextlib.suppress(AssertionError):
55+
torch.distributed.destroy_process_group()
56+
gc.collect()
57+
torch.npu.empty_cache()
58+
59+
4760
class VllmRunner:
4861

4962
def __init__(

0 commit comments

Comments
 (0)