Skip to content

Commit e5eea64

Browse files
[CI/UT] Add ut for parallel_state.py (#1460)
### What this PR does / why we need it? Add ut for parallel_state.py ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? python -m unittest test_parallel_state.py --------- Signed-off-by: wangyanhui-cmss <wangyanhui_yewu@cmss.chinamobile.com>
1 parent 4e2daf5 commit e5eea64

File tree

1 file changed

+193
-0
lines changed

1 file changed

+193
-0
lines changed
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
import pytest
5+
from vllm.distributed.parallel_state import GroupCoordinator
6+
7+
import vllm_ascend
8+
from vllm_ascend.distributed.parallel_state import (
9+
destory_ascend_model_parallel, get_ep_group, get_etp_group,
10+
init_ascend_model_parallel, model_parallel_initialized)
11+
12+
13+
class TestParallelState(unittest.TestCase):
14+
15+
@patch('vllm_ascend.distributed.parallel_state._EP',
16+
new_callable=lambda: MagicMock(spec=GroupCoordinator))
17+
def test_get_ep_group_when_initialized(self, mock_ep):
18+
# Act
19+
result = get_ep_group()
20+
21+
# Assert
22+
assert isinstance(result, GroupCoordinator)
23+
24+
@patch('vllm_ascend.distributed.parallel_state._EP', None)
25+
def test_get_ep_group_when_not_initialized(self):
26+
# Act & Assert
27+
with pytest.raises(AssertionError) as excinfo:
28+
get_ep_group()
29+
assert "expert model parallel group is not initialized" in str(
30+
excinfo.value)
31+
32+
@patch('vllm_ascend.distributed.parallel_state._ETP',
33+
new_callable=lambda: MagicMock(spec=GroupCoordinator))
34+
def test_get_etp_group_when_initialized(self, mock_etp):
35+
# Act
36+
result = get_etp_group()
37+
38+
# Assert
39+
assert isinstance(result, GroupCoordinator)
40+
41+
@patch('vllm_ascend.distributed.parallel_state._ETP', None)
42+
def test_get_etp_group_when_not_initialized(self):
43+
# Act & Assert
44+
with pytest.raises(AssertionError) as excinfo:
45+
get_etp_group()
46+
assert "expert tensor parallel group is not initialized" in str(
47+
excinfo.value)
48+
49+
@patch('vllm_ascend.distributed.parallel_state._ETP', None)
50+
@patch('vllm_ascend.distributed.parallel_state._EP', None)
51+
def test_model_parallel_initialized_when_both_none(self):
52+
# Act & Assert
53+
assert not model_parallel_initialized()
54+
55+
@patch('vllm_ascend.distributed.parallel_state._ETP',
56+
new_callable=lambda: MagicMock(spec=GroupCoordinator))
57+
@patch('vllm_ascend.distributed.parallel_state._EP', None)
58+
def test_model_parallel_initialized_when_ep_none(self, mock_etp):
59+
# Act & Assert
60+
assert not model_parallel_initialized()
61+
62+
@patch('vllm_ascend.distributed.parallel_state._ETP', None)
63+
@patch('vllm_ascend.distributed.parallel_state._EP',
64+
new_callable=lambda: MagicMock(spec=GroupCoordinator))
65+
def test_model_parallel_initialized_when_etp_none(self, mock_ep):
66+
# Act & Assert
67+
assert not model_parallel_initialized()
68+
69+
@patch('vllm_ascend.distributed.parallel_state._ETP',
70+
new_callable=lambda: MagicMock(spec=GroupCoordinator))
71+
@patch('vllm_ascend.distributed.parallel_state._EP',
72+
new_callable=lambda: MagicMock(spec=GroupCoordinator))
73+
def test_model_parallel_initialized_when_etp_initialized(
74+
self, mock_ep, mock_etp):
75+
# Act & Assert
76+
assert model_parallel_initialized()
77+
78+
@patch('vllm_ascend.distributed.parallel_state._ETP',
79+
new_callable=lambda: MagicMock(spec=GroupCoordinator))
80+
@patch('vllm_ascend.distributed.parallel_state._EP',
81+
new_callable=lambda: MagicMock(spec=GroupCoordinator))
82+
def test_destroy_when_both_exist(self, mock_ep, mock_etp):
83+
# Act
84+
destory_ascend_model_parallel()
85+
# Assert
86+
mock_ep.destroy.assert_called_once()
87+
mock_etp.destroy.assert_called_once()
88+
assert vllm_ascend.distributed.parallel_state._ETP is None
89+
assert vllm_ascend.distributed.parallel_state._EP is None
90+
91+
@patch('vllm_ascend.distributed.parallel_state._ETP', None)
92+
@patch('vllm_ascend.distributed.parallel_state._EP',
93+
new_callable=lambda: MagicMock())
94+
def test_destory_ascend_model_parallel_when_etp_none(self, mock_ep):
95+
# Act
96+
destory_ascend_model_parallel()
97+
# Assert
98+
mock_ep.destroy.assert_called_once()
99+
assert vllm_ascend.distributed.parallel_state._EP is None
100+
assert vllm_ascend.distributed.parallel_state._ETP is None
101+
102+
@patch('vllm_ascend.distributed.parallel_state._ETP',
103+
new_callable=lambda: MagicMock())
104+
@patch('vllm_ascend.distributed.parallel_state._EP', None)
105+
def test_destory_ascend_model_parallel_when_ep_none(self, mock_etp):
106+
# Act
107+
destory_ascend_model_parallel()
108+
# Assert
109+
mock_etp.destroy.assert_called_once()
110+
assert vllm_ascend.distributed.parallel_state._ETP is None
111+
assert vllm_ascend.distributed.parallel_state._EP is None
112+
113+
@patch('vllm_ascend.distributed.parallel_state._ETP', None)
114+
@patch('vllm_ascend.distributed.parallel_state._EP', None)
115+
def test_destory_ascend_model_parallel_when_both_none(self):
116+
# Act
117+
destory_ascend_model_parallel()
118+
# Assert
119+
assert vllm_ascend.distributed.parallel_state._ETP is None
120+
assert vllm_ascend.distributed.parallel_state._EP is None
121+
122+
@patch('torch.distributed.is_initialized', return_value=True)
123+
@patch('torch.distributed.get_world_size', return_value=8)
124+
@patch('vllm_ascend.distributed.parallel_state.get_world_group',
125+
return_value=MagicMock(device_group='npu:0', local_rank=0))
126+
@patch('torch.distributed.get_backend', return_value='hccl')
127+
@patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group')
128+
@patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized',
129+
return_value=False)
130+
def test_init_ascend_model_parallel_normal_case(
131+
self, mock_mp_init, mock_init_group, mock_get_backend,
132+
mock_world_group, mock_get_world_size, mock_is_init):
133+
"""Test normal initialization with default parameters"""
134+
# Act
135+
init_ascend_model_parallel()
136+
# Assert
137+
mock_init_group.assert_any_call([[0, 1, 2, 3, 4, 5, 6, 7]],
138+
0,
139+
'hccl',
140+
group_name="ep")
141+
mock_init_group.assert_any_call([[0]], 0, 'hccl', group_name="etp")
142+
self.assertIsNotNone(vllm_ascend.distributed.parallel_state._EP)
143+
self.assertIsNotNone(vllm_ascend.distributed.parallel_state._ETP)
144+
145+
@patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized',
146+
return_value=True)
147+
def test_init_ascend_model_parallel_skip_if_initialized(
148+
self, mock_mp_init):
149+
"""Test skipping when model parallel already initialized"""
150+
with patch.object(vllm_ascend.distributed.parallel_state,
151+
'_EP') as mock_ep, patch.object(
152+
vllm_ascend.distributed.parallel_state,
153+
'_ETP') as mock_etp:
154+
# Act
155+
init_ascend_model_parallel()
156+
# Assert
157+
mock_ep.assert_not_called()
158+
mock_etp.assert_not_called()
159+
160+
@patch('torch.distributed.is_initialized', return_value=False)
161+
def test_init_ascend_model_parallel_assert_dist_not_init(
162+
self, mock_is_init):
163+
"""Test assertion when distributed not initialized"""
164+
# Act & Assert
165+
with self.assertRaises(AssertionError):
166+
init_ascend_model_parallel()
167+
168+
@patch('torch.distributed.is_initialized', return_value=True)
169+
@patch('torch.distributed.get_world_size', return_value=8)
170+
@patch('vllm_ascend.distributed.parallel_state.get_world_group',
171+
return_value=MagicMock(device_group='npu:0', local_rank=1))
172+
@patch('torch.distributed.get_backend', return_value='hccl')
173+
@patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group')
174+
@patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized',
175+
return_value=False)
176+
def test_init_ascend_model_parallel_custom_params(
177+
self, mock_mp_init, mock_init_group, mock_get_backend,
178+
mock_world_group, mock_get_world_size, mock_is_init):
179+
"""Test initialization with custom parallel sizes"""
180+
# Act
181+
init_ascend_model_parallel(expert_parallel_size=2,
182+
expert_tensor_parallel_size=4,
183+
world_size=8,
184+
backend='hccl')
185+
#Assert
186+
mock_init_group.assert_any_call([[0, 4], [1, 5], [2, 6], [3, 7]],
187+
1,
188+
'hccl',
189+
group_name="ep")
190+
mock_init_group.assert_any_call([[0, 1, 2, 3], [4, 5, 6, 7]],
191+
1,
192+
'hccl',
193+
group_name="etp")

0 commit comments

Comments
 (0)