Skip to content

Commit 4e29c5a

Browse files
Add ut for test_pooling_model_runner.py (#1640)
### What this PR does / why we need it?
Add ut for test_pooling_model_runner.py

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
python -m unittest test_pooling_model_runner.py

- vLLM version: v0.9.1
- vLLM main: vllm-project/vllm@2e610de

---------

Signed-off-by: wangyanhui-cmss <wangyanhui_yewu@cmss.chinamobile.com>
1 parent 493768e commit 4e29c5a

File tree

1 file changed

+355
-0
lines changed

1 file changed

+355
-0
lines changed
Lines changed: 355 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,355 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
import torch
5+
from vllm.distributed.parallel_state import GroupCoordinator
6+
from vllm.engine.arg_utils import EngineArgs
7+
from vllm.pooling_params import PoolingParams
8+
from vllm.sequence import SequenceData, SequenceGroupMetadata
9+
10+
from vllm_ascend.worker.pooling_model_runner import (
11+
ModelInputForNPUWithPoolingMetadata, NPUPoolingModelRunner)
12+
13+
14+
class TestPoolingModelRunner(unittest.TestCase):
    """Unit tests for the NPUPoolingModelRunner class.

    ``setUp`` builds one runner per test from the fake checkpoint directory
    ``tests/ut/fake_weight`` and swaps its ``model`` attribute for a
    ``MagicMock``. Individual tests then patch the collaborators they care
    about and assert on how the runner calls them.

    NOTE(review): several tests patch
    ``vllm.forward_context.set_forward_context``. If the runner module binds
    that function with ``from vllm.forward_context import ...``, the patch
    does not replace the reference the runner actually calls -- confirm the
    intended patch target against ``pooling_model_runner``.
    """

    # ------------------------------------------------------------------
    # Fixture helpers
    # ------------------------------------------------------------------

    def _create_model_runner(self, model: str, *args,
                             **kwargs) -> NPUPoolingModelRunner:
        """Build a runner from ``EngineArgs``-style arguments.

        Args:
            model: Model name or path, passed as the first ``EngineArgs``
                argument.
            *args: Additional positional arguments for ``EngineArgs``.
            **kwargs: Additional keyword arguments for ``EngineArgs``.

        Returns:
            An ``NPUPoolingModelRunner`` constructed with the engine config
            created from those arguments.
        """
        engine_config = EngineArgs(model, *args,
                                   **kwargs).create_engine_config()
        return NPUPoolingModelRunner(vllm_config=engine_config)

    @staticmethod
    def _make_seq_group(seq_data, pooling_params) -> MagicMock:
        """Return a ``SequenceGroupMetadata``-spec'd mock.

        Args:
            seq_data: Mapping assigned to the mock's ``seq_data`` attribute.
            pooling_params: Object assigned to ``pooling_params``.
        """
        group = MagicMock(spec=SequenceGroupMetadata)
        group.seq_data = seq_data
        group.pooling_params = pooling_params
        return group

    @staticmethod
    def _make_model_input() -> MagicMock:
        """Return a mock model input carrying one token at position 0."""
        return MagicMock(input_tokens=torch.tensor([1]),
                         input_positions=torch.tensor([0]),
                         multi_modal_kwargs={})

    def setUp(self):
        """Initialize test fixtures and common mocks."""
        self.attn_backend = "npu"

        self.runner = self._create_model_runner(
            "tests/ut/fake_weight",
            trust_remote_code=True,
            enable_chunked_prefill=False,
        )
        self.runner.attn_backend = self.attn_backend
        # Stand-in for the model so tests can assert on calls made to it.
        self.runner.model = MagicMock()

        # Sample test data
        self.sample_tensor_dict = {"tensor1": torch.randn(3, 4)}
        self.sample_seq_group = [MagicMock(spec=SequenceGroupMetadata)]
        self.sample_finished_ids = ["req1", "req2"]

    # ------------------------------------------------------------------
    # make_model_input_from_broadcasted_tensor_dict
    # ------------------------------------------------------------------

    @patch(
        'vllm_ascend.worker.pooling_model_runner.ModelInputForNPUWithPoolingMetadata.from_broadcasted_tensor_dict'
    )
    def test_make_model_input_from_broadcasted_tensor_dict(
            self, mock_from_dict):
        """Test tensor dictionary conversion to model input."""
        # Setup mock return
        expected_output = MagicMock()
        mock_from_dict.return_value = expected_output

        # Execute
        result = self.runner.make_model_input_from_broadcasted_tensor_dict(
            self.sample_tensor_dict)

        # Verify: the dict is forwarded together with the attention backend
        mock_from_dict.assert_called_once_with(
            self.sample_tensor_dict, attn_backend=self.attn_backend)
        self.assertEqual(result, expected_output)

    # ------------------------------------------------------------------
    # prepare_model_input
    # ------------------------------------------------------------------

    @patch.object(NPUPoolingModelRunner, '_prepare_pooling')
    @patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors')
    def test_prepare_model_input_normal_case(self, mock_prepare_tensors,
                                             mock_prepare_pooling):
        """Test normal flow of model input preparation."""
        # Setup mocks
        mock_model_input = ModelInputForNPUWithPoolingMetadata(
            seq_lens=[1, 2, 3])
        mock_prepare_tensors.return_value = mock_model_input

        mock_pooling_metadata = MagicMock()
        mock_prepare_pooling.return_value = mock_pooling_metadata

        # Execute
        result = self.runner.prepare_model_input(
            seq_group_metadata_list=self.sample_seq_group,
            finished_requests_ids=self.sample_finished_ids)

        # Verify
        mock_prepare_tensors.assert_called_once_with(
            self.sample_seq_group, self.sample_finished_ids)
        mock_prepare_pooling.assert_called_once_with(
            self.sample_seq_group, mock_model_input.seq_lens)
        self.assertEqual(result.pooling_metadata, mock_pooling_metadata)

    def test_prepare_model_input_null_sequence_group(self):
        """Test assertion when seq_group_metadata_list is None."""
        with self.assertRaises(AssertionError):
            self.runner.prepare_model_input(
                seq_group_metadata_list=None,
                finished_requests_ids=self.sample_finished_ids)

    @patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors')
    def test_prepare_model_input_null_seq_lens(self, mock_prepare_tensors):
        """Test assertion when seq_lens is None in model input."""
        # Setup mock with None seq_lens
        mock_model_input = MagicMock()
        mock_model_input.seq_lens = None
        mock_prepare_tensors.return_value = mock_model_input

        with self.assertRaises(AssertionError):
            self.runner.prepare_model_input(
                seq_group_metadata_list=self.sample_seq_group,
                finished_requests_ids=self.sample_finished_ids)

    @patch.object(NPUPoolingModelRunner, '_prepare_pooling')
    @patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors')
    def test_prepare_model_input_with_virtual_engine(self,
                                                     mock_prepare_tensors,
                                                     mock_prepare_pooling):
        """Test virtual engine parameter is properly handled."""
        # Setup mocks
        mock_prepare_tensors.return_value = (
            ModelInputForNPUWithPoolingMetadata(seq_lens=[1, 2, 3]))

        # Execute with virtual_engine parameter
        result = self.runner.prepare_model_input(
            seq_group_metadata_list=self.sample_seq_group,
            virtual_engine=1,
            finished_requests_ids=self.sample_finished_ids)

        # Verify virtual_engine doesn't affect the flow
        self.assertIsNotNone(result)

    @patch.object(NPUPoolingModelRunner, '_prepare_pooling')
    @patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors')
    def test_prepare_model_input_with_null_finished_ids(
            self, mock_prepare_tensors, mock_prepare_pooling):
        """Test case when finished_requests_ids is None."""
        # Setup mocks
        mock_prepare_tensors.return_value = (
            ModelInputForNPUWithPoolingMetadata(seq_lens=[1, 2, 3]))

        # Execute with None finished_ids
        result = self.runner.prepare_model_input(
            seq_group_metadata_list=self.sample_seq_group,
            finished_requests_ids=None)

        # Verify
        mock_prepare_tensors.assert_called_once_with(self.sample_seq_group,
                                                     None)
        self.assertIsNotNone(result)

    # ------------------------------------------------------------------
    # _prepare_pooling
    #
    # PoolingMetadata.__init__ is patched in each test so the arguments the
    # runner passes to the constructor can be asserted directly; the mock
    # returns None because __init__ must not return a value.
    # ------------------------------------------------------------------

    @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
    def test_prepare_pooling_normal_case(self, mock_pooling_metadata):
        """Test normal case with multiple sequences in group."""
        # Setup test data
        mock_pooling_metadata.return_value = None
        seq_data = {
            1: MagicMock(spec=SequenceData),
            2: MagicMock(spec=SequenceData),
        }
        pooling_params = MagicMock(spec=PoolingParams)
        seq_group = self._make_seq_group(seq_data, pooling_params)

        # Call the function
        self.runner._prepare_pooling([seq_group], [10, 20])

        # Verify results
        mock_pooling_metadata.assert_called_once_with(
            seq_groups=[([1, 2], pooling_params)],
            seq_data=seq_data,
            prompt_lens=[10, 20])

    @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
    def test_prepare_pooling_empty_group(self, mock_pooling_metadata):
        """Test case with empty sequence group."""
        # Setup empty group
        mock_pooling_metadata.return_value = None
        empty_seq_data: dict[int, SequenceData] = {}
        pooling_params = MagicMock(spec=PoolingParams)
        empty_group = self._make_seq_group(empty_seq_data, pooling_params)

        # Call the function
        self.runner._prepare_pooling([empty_group], [])

        # Verify results
        mock_pooling_metadata.assert_called_once_with(
            seq_groups=[([], pooling_params)],
            seq_data={},
            prompt_lens=[])

    @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
    def test_prepare_pooling_single_sequence(self, mock_pooling_metadata):
        """Test case with single sequence in group."""
        # Setup single sequence
        mock_pooling_metadata.return_value = None
        single_seq_data = {3: MagicMock(spec=SequenceData)}
        pooling_params = MagicMock(spec=PoolingParams)
        single_group = self._make_seq_group(single_seq_data, pooling_params)

        # Call the function
        self.runner._prepare_pooling([single_group], [5])

        # Verify results
        mock_pooling_metadata.assert_called_once_with(
            seq_groups=[([3], pooling_params)],
            seq_data=single_seq_data,
            prompt_lens=[5])

    @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
    def test_prepare_pooling_multiple_groups(self, mock_pooling_metadata):
        """Test case with multiple sequence groups."""
        # Setup multiple groups
        mock_pooling_metadata.return_value = None
        seq_data1 = {1: MagicMock(spec=SequenceData)}
        seq_data2 = {2: MagicMock(spec=SequenceData)}
        params1 = MagicMock(spec=PoolingParams)
        params2 = MagicMock(spec=PoolingParams)

        group1 = self._make_seq_group(seq_data1, params1)
        group2 = self._make_seq_group(seq_data2, params2)

        # Call the function
        self.runner._prepare_pooling([group1, group2], [10, 20])

        # Verify results: per-group entries plus the merged seq_data mapping
        mock_pooling_metadata.assert_called_once_with(
            seq_groups=[([1], params1), ([2], params2)],
            seq_data={
                **seq_data1,
                **seq_data2
            },
            prompt_lens=[10, 20])

    @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
    def test_prepare_pooling_empty_input(self, mock_pooling_metadata):
        """Test case with empty input lists."""
        mock_pooling_metadata.return_value = None

        # Call the function with empty inputs
        self.runner._prepare_pooling([], [])

        # Verify results
        mock_pooling_metadata.assert_called_once_with(seq_groups=[],
                                                      seq_data={},
                                                      prompt_lens=[])

    # ------------------------------------------------------------------
    # execute_model
    #
    # Stacked @patch decorators inject mocks bottom-up: the decorator
    # closest to the ``def`` supplies the first mock argument.
    # ------------------------------------------------------------------

    @patch('vllm.forward_context.set_forward_context')
    @patch('vllm.distributed.parallel_state._PP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator,
                                          is_last_rank=True))
    @patch('torch.npu.Event')
    @patch.object(NPUPoolingModelRunner, 'set_active_loras')
    @patch.object(NPUPoolingModelRunner, 'set_active_prompt_adapters')
    def test_execute_model_normal_flow(self, mock_set_adapters, mock_set_loras,
                                       mock_event, mock_pp, mock_set_forward):
        """Test normal execution path with all dependencies mocked."""
        # Setup model input mock
        mock_input = self._make_model_input()
        self.runner.is_driver_worker = True

        # Execute
        self.runner.execute_model(model_input=mock_input,
                                  kv_caches=[],
                                  num_steps=1)

        # Verify core calls
        self.runner.model.pooler.assert_called_once()

    @patch('vllm.forward_context.set_forward_context')
    def test_execute_model_invalid_steps(self, mock_set_forward):
        """Test ValueError when num_steps != 1."""
        with self.assertRaises(ValueError):
            self.runner.execute_model(model_input=MagicMock(),
                                      kv_caches=[],
                                      num_steps=2)
        # Checked after the ``with`` block: a statement placed after the
        # raising call inside the block would never be reached.
        mock_set_forward.assert_not_called()

    @patch('vllm.forward_context.set_forward_context')
    @patch('vllm.distributed.parallel_state._PP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator,
                                          is_last_rank=False))
    @patch('torch.npu.Event')
    def test_execute_model_perf_monitoring(self, mock_event, mock_pp,
                                           mock_set_forward):
        """Test performance monitoring with timing mocks."""
        # Setup mocks
        mock_event.return_value.elapsed_time.return_value = 15.0
        self.runner.observability_config = MagicMock(
            collect_model_forward_time=True)

        # Execute
        self.runner.execute_model(model_input=self._make_model_input(),
                                  kv_caches=[],
                                  num_steps=1)

        # Verify timing calls: two Event objects are expected to be created
        self.assertEqual(mock_event.call_count, 2)

    @patch('vllm.forward_context.set_forward_context')
    @patch.object(NPUPoolingModelRunner, 'set_active_loras')
    @patch('vllm.distributed.parallel_state._PP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator,
                                          is_last_rank=False))
    def test_execute_model_lora_config(self, mock_pp, set_active_loras,
                                       mock_set_forward):
        """Test LoRA configuration handling."""
        # Setup
        self.runner.lora_config = True
        mock_input = MagicMock()
        mock_input.lora_requests = ["req1"]
        mock_input.lora_mapping = {"map": 1}

        # Execute
        self.runner.execute_model(model_input=mock_input,
                                  kv_caches=[],
                                  num_steps=1)

        # Verify LoRA call
        set_active_loras.assert_called_once_with(["req1"], {"map": 1})

    @patch('vllm.forward_context.set_forward_context')
    @patch('vllm.distributed.parallel_state._PP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator,
                                          is_last_rank=False))
    def test_execute_model_not_last_rank(self, mock_pp, mock_set_forward):
        """Test behavior when not the last pipeline rank."""
        # Execute
        self.runner.execute_model(model_input=self._make_model_input(),
                                  kv_caches=[],
                                  num_steps=1)

        # Verify pooler not called
        self.runner.model.pooler.assert_not_called()

0 commit comments

Comments
 (0)