Skip to content

Commit fd3cdbe

Browse files
committed
[FOLLOWUP] Use base test
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
1 parent 0c1d239 commit fd3cdbe

File tree

6 files changed

+14
-17
lines changed

6 files changed

+14
-17
lines changed

tests/ut/distributed/kv_transfer/test_simple_buffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
import unittest
21
import zlib
32
from unittest.mock import MagicMock
43

54
import torch
65

6+
from tests.ut.base import TestBase
77
from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer,
88
int32_hash)
99

@@ -17,7 +17,7 @@ def __init__(self):
1717
self.deallocate_buffer = MagicMock()
1818

1919

20-
class TestSimpleBuffer(unittest.TestCase):
20+
class TestSimpleBuffer(TestBase):
2121

2222
def setUp(self):
2323
self.pipe = MockSimplePipe()

tests/ut/distributed/kv_transfer/test_simple_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
import unittest
21
from unittest.mock import MagicMock, patch
32

43
import torch
54
from vllm.config import VllmConfig
65
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
76

7+
from tests.ut.base import TestBase
88
from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
99
from vllm_ascend.distributed.kv_transfer.simple_connector import \
1010
SimpleConnector
1111
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
1212

1313

14-
class TestSimpleConnector(unittest.TestCase):
14+
class TestSimpleConnector(TestBase):
1515

1616
def setUp(self):
1717
self.mock_pipe = MagicMock(spec=SimplePipe)

tests/ut/distributed/kv_transfer/test_simple_pipe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
import unittest
21
from unittest.mock import MagicMock, patch
32

43
import torch
54

5+
from tests.ut.base import TestBase
66
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
77

88

9-
class TestSimplePipe(unittest.TestCase):
9+
class TestSimplePipe(TestBase):
1010

1111
@classmethod
1212
def _create_mock_config(self):

tests/ut/ops/test_rotary_embedding.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import math
2-
import unittest
32
from unittest.mock import MagicMock, patch
43

54
import torch
@@ -12,7 +11,7 @@
1211
yarn_get_mscale)
1312

1413

15-
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
14+
class TestCustomRotaryEmbeddingEnabled(TestBase):
1615

1716
def setUp(self):
1817
# Common setup for tests
@@ -67,7 +66,7 @@ def test_custom_rotary_embedding_enabled(self):
6766
self.assertFalse(result)
6867

6968

70-
class TestRopeForwardOot(unittest.TestCase):
69+
class TestRopeForwardOot(TestBase):
7170

7271
def setUp(self):
7372
# Common setup for tests
@@ -262,7 +261,7 @@ def test_native_rope_deepseek_forward_non_neox_style(
262261
assert k_pe.shape == key.shape
263262

264263

265-
class TestRotateHalf(unittest.TestCase):
264+
class TestRotateHalf(TestBase):
266265

267266
def test_rotate_half_even_dim(self):
268267
# Test with even dimension
@@ -272,7 +271,7 @@ def test_rotate_half_even_dim(self):
272271
self.assertTrue(torch.allclose(result, expected))
273272

274273

275-
class TestYarnFindCorrectionDim(unittest.TestCase):
274+
class TestYarnFindCorrectionDim(TestBase):
276275

277276
def test_basic_case(self):
278277
# Test with standard values
@@ -293,7 +292,7 @@ def test_basic_case(self):
293292
self.assertTrue(torch.allclose(result, expected))
294293

295294

296-
class TestYarnGetMscale(unittest.TestCase):
295+
class TestYarnGetMscale(TestBase):
297296

298297
def test_scale_less_than_or_equal_1(self):
299298
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)

tests/ut/test_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import math
1717
import os
18-
import unittest
1918
from threading import Lock
2019
from unittest import mock
2120

@@ -281,7 +280,7 @@ def test_update_aclgraph_sizes(self):
281280
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
282281

283282

284-
class TestProfileExecuteDuration(unittest.TestCase):
283+
class TestProfileExecuteDuration(TestBase):
285284

286285
def setUp(self):
287286
utils.ProfileExecuteDuration._instance = None

tests/ut/worker/test_input_batch.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import unittest
2-
31
import numpy as np
42
import torch
53
from vllm.sampling_params import SamplingParams
64
from vllm.v1.sample.metadata import SamplingMetadata
75
from vllm.v1.worker.block_table import MultiGroupBlockTable
86

7+
from tests.ut.base import TestBase
98
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
109

1110

@@ -24,7 +23,7 @@ def mock_cached_request_state(req_id="1", prompt=[1, 2, 3], output=[4, 5, 6]):
2423
)
2524

2625

27-
class TestInputBatch(unittest.TestCase):
26+
class TestInputBatch(TestBase):
2827

2928
def setUp(self):
3029
self.max_num_reqs = 10

0 commit comments

Comments
 (0)