Skip to content

Commit 3b7269a

Browse files
committed
fix comment
Signed-off-by: duyangkai <duyangkai@huawei.com>
1 parent 847d52d commit 3b7269a

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

tests/ut/test_distributed_tensor_parallel.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
all_to_all_hp2sp, all_to_all_sp2hp)
1414

1515

16-
# 测试用的固定数据
1716
@pytest.fixture
1817
def test_tensor():
1918
return torch.randn(8, 16)
@@ -29,7 +28,6 @@ def mock_group():
2928
return MagicMock()
3029

3130

32-
# 模拟分布式环境
3331
@pytest.fixture(autouse=True)
3432
def mock_dist():
3533
with patch("torch.distributed") as mock:
@@ -39,12 +37,11 @@ def mock_dist():
3937

4038

4139
class TestDistributedCommunication:
42-
"""测试分布式通信函数"""
4340

4441
@pytest.mark.parametrize("world_size", [1, 4])
4542
def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist,
4643
world_size):
47-
"""测试_gather_along_first_dim"""
44+
"""test _gather_along_first_dim"""
4845
mock_dist.get_world_size.return_value = world_size
4946

5047
result = _gather_along_first_dim(test_tensor, mock_group)
@@ -56,7 +53,7 @@ def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist,
5653

5754
def test_gather_along_first_dim_unequal_split(self, test_tensor,
5855
mock_group):
59-
"""测试不等分分割情况"""
56+
"""test unequal split"""
6057
output_split_sizes = [5, 10, 15, 2]
6158
result = _gather_along_first_dim(test_tensor, mock_group,
6259
output_split_sizes)
@@ -65,7 +62,7 @@ def test_gather_along_first_dim_unequal_split(self, test_tensor,
6562
@pytest.mark.parametrize("world_size", [1, 4])
6663
def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group,
6764
mock_dist, world_size):
68-
"""测试_gather_along_last_dim"""
65+
"""test _gather_along_last_dim"""
6966
mock_dist.get_world_size.return_value = world_size
7067

7168
result = _gather_along_last_dim(test_tensor_last_dim, mock_group)
@@ -100,7 +97,7 @@ def test_reduce_scatter_along_last_dim(self, mock_group):
10097
])
10198
def test_wrapper_functions(self, mock_group, func, input_shape,
10299
expected_shape):
103-
"""测试包装函数"""
100+
"""test wrapper funcs"""
104101
mod = importlib.import_module(
105102
'vllm_ascend.distributed.tensor_parallel')
106103
globals = mod.__dict__

0 commit comments

Comments
 (0)