import pytest
import torch
import importlib
from unittest.mock import MagicMock, patch
from vllm_ascend.distributed.tensor_parallel import (
    _gather_along_first_dim, _gather_along_last_dim,
    _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
    all_to_all_sp2hp, all_to_all_hp2sp
)
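
# Unit tests for the tensor-parallel communication helpers in
# vllm_ascend.distributed.tensor_parallel. torch.distributed is mocked
# throughout, so these tests mainly verify the output shape of each helper.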


# Fixed test data
@pytest.fixture
def test_tensor():
    return torch.randn(8, 16)


@pytest.fixture
def test_tensor_last_dim():
    return torch.randn(8, 16, 32)


@pytest.fixture
def mock_group():
    return MagicMock()


# Mock the distributed environment
@pytest.fixture(autouse=True)
def mock_dist():
    with patch("torch.distributed") as mock:
        mock.get_world_size.return_value = 4
        mock.get_rank.return_value = 0
        yield mock

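# NOTE: torch.distributed is replaced by a MagicMock above, so no real
# collectives run in these tests. The expected shapes below assume the helpers
# size their output tensors locally from the mocked world size of 4.
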
class TestDistributedCommunication:
    """Tests for the distributed communication helpers."""

    @pytest.mark.parametrize("world_size", [1, 4])
    def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist,
                                    world_size):
        """Test _gather_along_first_dim."""
        mock_dist.get_world_size.return_value = world_size

        result = _gather_along_first_dim(test_tensor, mock_group)

        if world_size == 1:
            assert torch.equal(result, test_tensor)
        else:
            assert result.shape == (32, 16)  # 8 * 4 = 32

    def test_gather_along_first_dim_unequal_split(self, test_tensor,
                                                  mock_group):
        """Test gathering with unequal output split sizes."""
        output_split_sizes = [5, 10, 15, 2]
        result = _gather_along_first_dim(test_tensor, mock_group,
                                         output_split_sizes)
        assert result.shape == (32, 16)  # 5 + 10 + 15 + 2 = 32

    @pytest.mark.parametrize("world_size", [1, 4])
    def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group,
                                   mock_dist, world_size):
        """Test _gather_along_last_dim."""
        mock_dist.get_world_size.return_value = world_size

        result = _gather_along_last_dim(test_tensor_last_dim, mock_group)

        if world_size == 1:
            assert torch.equal(result, test_tensor_last_dim)
        else:
            assert result.shape == (8, 16, 32 * world_size)  # 32 * 4 = 128

    @pytest.mark.parametrize("input_shape,expected_shape", [
        ((32, 16), (8, 16)),
        ((40, 10), (10, 10)),
    ])
    def test_reduce_scatter_along_first_dim(self, mock_group, input_shape,
                                            expected_shape):
        input_tensor = torch.randn(*input_shape)
        result = _reduce_scatter_along_first_dim(input_tensor, mock_group)
        assert result.shape == expected_shape

    def test_reduce_scatter_along_last_dim(self, mock_group):
        input_tensor = torch.randn(8, 16, 32)
        result = _reduce_scatter_along_last_dim(input_tensor, mock_group)
        assert result.shape == (8, 16, 8)  # 32 / 4 = 8

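    # The public wrapper functions are resolved by name from the module, so a
    # single parametrized test covers all four wrappers and their expected
    # output shapes.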
    @pytest.mark.parametrize("func,input_shape,expected_shape", [
        ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
         (8, 16, 128)),
        ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
        ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
         (8, 16, 8)),
        ("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
    ])
    def test_wrapper_functions(self, mock_group, func, input_shape,
                               expected_shape):
        """Test the module-level wrapper functions."""
        mod = importlib.import_module('vllm_ascend.distributed.tensor_parallel')
        test_func = getattr(mod, func)
        input_tensor = torch.randn(*input_shape)
        result = test_func(input_tensor, mock_group)
        assert result.shape == expected_shape

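    # all_to_all_sp2hp redistributes from the sequence-parallel layout
    # [num_tokens/TP, H] to the hidden-parallel layout [num_tokens, H/TP];
    # all_to_all_hp2sp performs the inverse mapping.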
    @pytest.mark.parametrize("input_shape,output_shape", [
        ((8, 16), (32, 4)),  # [num_tokens/TP, H] -> [num_tokens, H/TP]
    ])
    def test_all_to_all_sp2hp(self, mock_group, input_shape, output_shape):
        input_tensor = torch.randn(*input_shape)
        result = all_to_all_sp2hp(input_tensor, mock_group)
        assert result.shape == output_shape

    @pytest.mark.parametrize("input_shape,output_shape", [
        ((32, 4), (8, 16)),  # [num_tokens, H/TP] -> [num_tokens/TP, H]
    ])
    def test_all_to_all_hp2sp(self, mock_group, input_shape, output_shape):
        input_tensor = torch.randn(*input_shape)
        result = all_to_all_hp2sp(input_tensor, mock_group)
        assert result.shape == output_shape