     all_to_all_hp2sp, all_to_all_sp2hp)


-# 测试用的固定数据
 @pytest.fixture
 def test_tensor():
     return torch.randn(8, 16)
@@ -29,7 +28,6 @@ def mock_group():
     return MagicMock()


-# 模拟分布式环境
 @pytest.fixture(autouse=True)
 def mock_dist():
     with patch("torch.distributed") as mock:
@@ -39,12 +37,11 @@ def mock_dist():


 class TestDistributedCommunication:
-    """测试分布式通信函数"""

     @pytest.mark.parametrize("world_size", [1, 4])
     def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist,
                                     world_size):
-        """测试_gather_along_first_dim"""
+        """test _gather_along_first_dim"""
         mock_dist.get_world_size.return_value = world_size

         result = _gather_along_first_dim(test_tensor, mock_group)
@@ -56,7 +53,7 @@ def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist,

     def test_gather_along_first_dim_unequal_split(self, test_tensor,
                                                   mock_group):
-        """测试不等分分割情况"""
+        """test unequal split"""
         output_split_sizes = [5, 10, 15, 2]
         result = _gather_along_first_dim(test_tensor, mock_group,
                                          output_split_sizes)
@@ -65,7 +62,7 @@ def test_gather_along_first_dim_unequal_split(self, test_tensor,
     @pytest.mark.parametrize("world_size", [1, 4])
     def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group,
                                    mock_dist, world_size):
-        """测试_gather_along_last_dim"""
+        """test _gather_along_last_dim"""
         mock_dist.get_world_size.return_value = world_size

         result = _gather_along_last_dim(test_tensor_last_dim, mock_group)
@@ -100,7 +97,7 @@ def test_reduce_scatter_along_last_dim(self, mock_group):
     ])
     def test_wrapper_functions(self, mock_group, func, input_shape,
                                expected_shape):
-        """测试包装函数"""
+        """test wrapper funcs"""
         mod = importlib.import_module(
             'vllm_ascend.distributed.tensor_parallel')
         globals = mod.__dict__
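
Note: the tests above never initialize a real process group; the autouse mock_dist fixture replaces the torch.distributed module with a MagicMock, so collective helpers such as _gather_along_first_dim can query world size in a single CPU process. Below is a minimal, self-contained sketch of that pattern; the test name is illustrative and not part of the vllm_ascend suite.

from unittest.mock import MagicMock, patch

import pytest
import torch


@pytest.fixture
def mock_group():
    # Stand-in for a torch.distributed ProcessGroup; no backend is needed.
    return MagicMock()


@pytest.fixture(autouse=True)
def mock_dist():
    # Replace the torch.distributed module with a MagicMock for every test,
    # so collective calls resolve against the mock instead of a real backend.
    with patch("torch.distributed") as mock:
        mock.get_world_size.return_value = 4
        yield mock


def test_world_size_comes_from_mock(mock_group, mock_dist):
    # get_world_size() returns the mocked value; init_process_group is never called.
    assert torch.distributed.get_world_size(mock_group) == 4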