Skip to content

Commit c92d4c9

Browse files
committed
[FOLLOWUP] Use base test
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
1 parent 0c1d239 commit c92d4c9

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

tests/ut/ops/test_rotary_embedding.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
yarn_get_mscale)
1313

1414

15-
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
15+
class TestCustomRotaryEmbeddingEnabled(TestBase):
1616

1717
def setUp(self):
1818
# Common setup for tests
@@ -67,7 +67,7 @@ def test_custom_rotary_embedding_enabled(self):
6767
self.assertFalse(result)
6868

6969

70-
class TestRopeForwardOot(unittest.TestCase):
70+
class TestRopeForwardOot(TestBase):
7171

7272
def setUp(self):
7373
# Common setup for tests
@@ -262,7 +262,7 @@ def test_native_rope_deepseek_forward_non_neox_style(
262262
assert k_pe.shape == key.shape
263263

264264

265-
class TestRotateHalf(unittest.TestCase):
265+
class TestRotateHalf(TestBase):
266266

267267
def test_rotate_half_even_dim(self):
268268
# Test with even dimension
@@ -272,7 +272,7 @@ def test_rotate_half_even_dim(self):
272272
self.assertTrue(torch.allclose(result, expected))
273273

274274

275-
class TestYarnFindCorrectionDim(unittest.TestCase):
275+
class TestYarnFindCorrectionDim(TestBase):
276276

277277
def test_basic_case(self):
278278
# Test with standard values
@@ -293,7 +293,7 @@ def test_basic_case(self):
293293
self.assertTrue(torch.allclose(result, expected))
294294

295295

296-
class TestYarnGetMscale(unittest.TestCase):
296+
class TestYarnGetMscale(TestBase):
297297

298298
def test_scale_less_than_or_equal_1(self):
299299
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)

tests/ut/test_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from threading import Lock
2020
from unittest import mock
2121

22+
from tests.ut.base import TestBase
2223
import torch
2324
from vllm.config import (CompilationConfig, ModelConfig, ParallelConfig,
2425
VllmConfig)
@@ -281,7 +282,7 @@ def test_update_aclgraph_sizes(self):
281282
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
282283

283284

284-
class TestProfileExecuteDuration(unittest.TestCase):
285+
class TestProfileExecuteDuration(TestBase):
285286

286287
def setUp(self):
287288
utils.ProfileExecuteDuration._instance = None

tests/ut/worker/test_input_batch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from vllm.v1.sample.metadata import SamplingMetadata
77
from vllm.v1.worker.block_table import MultiGroupBlockTable
88

9+
from tests.ut.base import TestBase
910
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
1011

1112

@@ -24,7 +25,7 @@ def mock_cached_request_state(req_id="1", prompt=[1, 2, 3], output=[4, 5, 6]):
2425
)
2526

2627

27-
class TestInputBatch(unittest.TestCase):
28+
class TestInputBatch(TestBase):
2829

2930
def setUp(self):
3031
self.max_num_reqs = 10

0 commit comments

Comments
 (0)