Commit d470373

ut test
1 parent fbe0c59 commit d470373

File tree

1 file changed: +114 −0 lines changed

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
import pytest
import torch
import importlib
from unittest.mock import MagicMock, patch
from vllm_ascend.distributed.tensor_parallel import (
    _gather_along_first_dim, _gather_along_last_dim,
    _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
    all_to_all_sp2hp, all_to_all_hp2sp
)


# Fixed test data
@pytest.fixture
def test_tensor():
    return torch.randn(8, 16)


@pytest.fixture
def test_tensor_last_dim():
    return torch.randn(8, 16, 32)


@pytest.fixture
def mock_group():
    return MagicMock()


# Mock the distributed environment so the collective helpers can run
# without an initialized process group.
@pytest.fixture(autouse=True)
def mock_dist():
    with patch("torch.distributed") as mock:
        mock.get_world_size.return_value = 4
        mock.get_rank.return_value = 0
        yield mock


class TestDistributedCommunication:
    """Tests for the distributed communication helpers."""

    @pytest.mark.parametrize("world_size", [1, 4])
    def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist,
                                    world_size):
        """Test _gather_along_first_dim."""
        mock_dist.get_world_size.return_value = world_size

        result = _gather_along_first_dim(test_tensor, mock_group)

        if world_size == 1:
            assert torch.equal(result, test_tensor)
        else:
            assert result.shape == (32, 16)  # 8 * 4 = 32

    def test_gather_along_first_dim_unequal_split(self, test_tensor,
                                                  mock_group):
        """Test gathering with unequal output split sizes."""
        output_split_sizes = [5, 10, 15, 2]
        result = _gather_along_first_dim(test_tensor, mock_group,
                                         output_split_sizes)
        assert result.shape == (32, 16)  # 5 + 10 + 15 + 2 = 32
@pytest.mark.parametrize("world_size", [1, 4])
58+
def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group, mock_dist, world_size):
59+
"""测试_gather_along_last_dim"""
60+
mock_dist.get_world_size.return_value = world_size
61+
62+
result = _gather_along_last_dim(test_tensor_last_dim, mock_group)
63+
64+
if world_size == 1:
65+
assert torch.equal(result, test_tensor_last_dim)
66+
else:
67+
assert result.shape == (8, 16, 32*world_size) # 8*4=32
68+

    @pytest.mark.parametrize("input_shape,expected_shape", [
        ((32, 16), (8, 16)),
        ((40, 10), (10, 10)),
    ])
    def test_reduce_scatter_along_first_dim(self, mock_group, input_shape,
                                            expected_shape):
        """Test _reduce_scatter_along_first_dim."""
        input_tensor = torch.randn(*input_shape)
        result = _reduce_scatter_along_first_dim(input_tensor, mock_group)
        assert result.shape == expected_shape

    def test_reduce_scatter_along_last_dim(self, mock_group):
        """Test _reduce_scatter_along_last_dim."""
        input_tensor = torch.randn(8, 16, 32)
        result = _reduce_scatter_along_last_dim(input_tensor, mock_group)
        assert result.shape == (8, 16, 8)  # 32 / 4 = 8

    @pytest.mark.parametrize("func,input_shape,expected_shape", [
        ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32), (8, 16, 128)),
        ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
        ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32), (8, 16, 8)),
        ("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
    ])
    def test_wrapper_functions(self, mock_group, func, input_shape,
                               expected_shape):
        """Test the public wrapper functions."""
        # Resolve the wrapper by name so one test covers all four wrappers.
        mod = importlib.import_module('vllm_ascend.distributed.tensor_parallel')
        test_func = getattr(mod, func)
        input_tensor = torch.randn(*input_shape)
        result = test_func(input_tensor, mock_group)
        assert result.shape == expected_shape

    @pytest.mark.parametrize("input_shape,output_shape", [
        ((8, 16), (32, 4)),  # [num_tokens/TP, H] -> [num_tokens, H/TP]
    ])
    def test_all_to_all_sp2hp(self, mock_group, input_shape, output_shape):
        """Test all_to_all_sp2hp."""
        input_tensor = torch.randn(*input_shape)
        result = all_to_all_sp2hp(input_tensor, mock_group)
        assert result.shape == output_shape

    @pytest.mark.parametrize("input_shape,output_shape", [
        ((32, 4), (8, 16)),  # [num_tokens, H/TP] -> [num_tokens/TP, H]
    ])
    def test_all_to_all_hp2sp(self, mock_group, input_shape, output_shape):
        """Test all_to_all_hp2sp."""
        input_tensor = torch.randn(*input_shape)
        result = all_to_all_hp2sp(input_tensor, mock_group)
        assert result.shape == output_shape
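
Note (not part of the diff above): these tests rely on the autouse mock_dist fixture replacing the torch.distributed module attribute with a MagicMock, so the tensor-parallel helpers can be exercised without an initialized process group or any real devices, and only output shapes are asserted. Below is a minimal standalone sketch of that patching pattern; the _demo_patched_world_size name is illustrative and does not appear in the commit.

from unittest.mock import patch

import torch


def _demo_patched_world_size():
    # patch("torch.distributed") swaps the submodule for a MagicMock, so any
    # attribute access or call on torch.distributed hits the mock instead of
    # a real communication backend.
    with patch("torch.distributed") as mock_dist:
        mock_dist.get_world_size.return_value = 4
        mock_dist.get_rank.return_value = 0
        assert torch.distributed.get_world_size() == 4
        assert torch.distributed.get_rank() == 0


_demo_patched_world_size()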

0 commit comments
