Skip to content

Commit 7fc1a98

Browse files
AgonixiaoxiaolixudongMengqingCao
authored
add ut for kv tansfer module (#1531)
### What this PR does / why we need it? test kv data transfer contains connect,pipe,buffer ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? CI passed with new added test. --------- Signed-off-by: lixudong <lixudong@cmss.chinamobile.com> Signed-off-by: MengqingCao <cmq0113@163.com> Co-authored-by: lixudong <lixudong@cmss.chinamobile.com> Co-authored-by: MengqingCao <cmq0113@163.com>
1 parent aa5fa07 commit 7fc1a98

File tree

3 files changed

+362
-0
lines changed

3 files changed

+362
-0
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import unittest
2+
import zlib
3+
from unittest.mock import MagicMock
4+
5+
import torch
6+
7+
from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer,
8+
int32_hash)
9+
10+
11+
class MockSimplePipe:
12+
13+
def __init__(self):
14+
self.cluster_id = 0
15+
self.send_tensor = MagicMock()
16+
self.recv_tensor = MagicMock()
17+
self.deallocate_buffer = MagicMock()
18+
19+
20+
class TestSimpleBuffer(unittest.TestCase):
21+
22+
def setUp(self):
23+
self.pipe = MockSimplePipe()
24+
self.buffer = SimpleBuffer(self.pipe)
25+
26+
def test_int32_hash(self):
27+
self.assertEqual(int32_hash("test"), zlib.adler32(b"test"))
28+
29+
def test_insert(self):
30+
input_tokens = torch.tensor([1, 2, 3])
31+
roi = torch.tensor([1, 0, 1])
32+
key = torch.randn(2, 3, 4, 5)
33+
value = torch.randn(2, 3, 4, 5)
34+
hidden = torch.randn(3, 6)
35+
36+
self.buffer.num_layers = 2
37+
self.buffer.num_heads = 4
38+
self.buffer.head_size = 5
39+
self.buffer.hidden_size = 6
40+
self.buffer.dtype = torch.float32
41+
42+
self.buffer.insert(input_tokens, roi, key, value, hidden, "req1")
43+
44+
self.pipe.send_tensor.assert_called()
45+
46+
def test_drop_select(self):
47+
input_tokens = torch.tensor([1, 2, 3])
48+
roi = None
49+
50+
self.buffer.num_layers = 2
51+
self.buffer.num_heads = 4
52+
self.buffer.head_size = 5
53+
self.buffer.hidden_size = 6
54+
self.buffer.dtype = torch.float32
55+
56+
self.pipe.recv_tensor.side_effect = [
57+
(MagicMock(), torch.randn(1, 2, 3 * 4 * 5)),
58+
(MagicMock(), torch.randn(1, 2, 3 * 4 * 5)),
59+
(MagicMock(), torch.randn(1, 3, 6))
60+
]
61+
62+
result = self.buffer.drop_select(input_tokens, roi, "req1")
63+
self.assertEqual(len(result), 4)
64+
self.assertIsInstance(result[0], torch.Tensor)
65+
self.assertIsInstance(result[1], torch.Tensor)
66+
self.assertIsInstance(result[2], torch.Tensor)
67+
self.assertIsNone(result[3])
68+
self.assertEqual(result[0].shape, (2, 3, 4, 5))
69+
70+
def test_close(self):
71+
self.buffer.close()
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
import torch
5+
from vllm.config import VllmConfig
6+
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
7+
8+
from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
9+
from vllm_ascend.distributed.kv_transfer.simple_connector import \
10+
SimpleConnector
11+
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
12+
13+
14+
class TestSimpleConnector(unittest.TestCase):
15+
16+
def setUp(self):
17+
self.mock_pipe = MagicMock(spec=SimplePipe)
18+
self.mock_buffer = MagicMock(spec=SimpleBuffer)
19+
20+
patcher = patch(
21+
'vllm_ascend.distributed.kv_transfer.simple_buffer.SimpleBuffer')
22+
self.addCleanup(patcher.stop)
23+
self.MockSimpleBuffer = patcher.start()
24+
self.MockSimpleBuffer.return_value = self.mock_buffer
25+
26+
def _create_mock_config(self, kv_role):
27+
mock_config = MagicMock()
28+
mock_config.kv_role = "kv_producer"
29+
mock_config.kv_connector_extra_config = {
30+
"prefill_device_ips": ["127.0.0.1"],
31+
"decode_device_ips": ["127.0.0.1"],
32+
"llmdatadist_comm_port": 26000,
33+
"http_port": 8000,
34+
"proxy_ip": "127.0.0.1",
35+
"proxy_port": "8000",
36+
"port": 5500
37+
}
38+
mock_config.kv_port = 5500
39+
self.mock_config = MagicMock(spec=VllmConfig)
40+
self.mock_config.kv_transfer_config.is_kv_producer = True
41+
self.mock_config.model_config.hf_config.hidden_size = 128
42+
self.mock_config.model_config.hf_config.num_attention_heads = 8
43+
self.mock_config.model_config.hf_config.num_key_value_heads = 8
44+
self.mock_config.model_config.hf_config.qk_rope_head_dim = 16
45+
self.mock_config.model_config.hf_config.kv_lora_rank = 16
46+
self.mock_config.model_config.is_deepseek_mla = True
47+
# 模拟 parallel_config
48+
self.mock_config.parallel_config = MagicMock()
49+
self.mock_config.parallel_config.tensor_parallel_size = 1
50+
self.mock_config.parallel_config.get_num_layers.return_value = 4
51+
52+
if kv_role == "kv_producer":
53+
self.mock_config.kv_transfer_config.kv_role = "kv_producer"
54+
else:
55+
self.mock_config.kv_transfer_config.kv_role = "kv_consumer"
56+
return mock_config
57+
58+
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
59+
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
60+
@patch('llm_datadist.LLMDataDist')
61+
def test_select_init(self, mock_pipe, mock_buffer, MockLLMDataDist):
62+
"""Test select method when buffer retrieval succeeds."""
63+
connector = SimpleConnector(
64+
rank=0,
65+
local_rank=0,
66+
config=self._create_mock_config("kv_producer"))
67+
assert connector.producer_data_pipe is not None
68+
assert connector.producer_buffer is not None
69+
mock_data_dist = MockLLMDataDist.return_value
70+
mock_data_dist.init.return_value = None
71+
72+
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
73+
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
74+
@patch('llm_datadist.LLMDataDist')
75+
def test_select_select(self, mock_pipe, mock_buffer, MockLLMDataDist):
76+
77+
connector = SimpleConnector(
78+
rank=0,
79+
local_rank=0,
80+
config=self._create_mock_config("kv_consumer"))
81+
connector.consumer_data_pipe = mock_pipe
82+
connector.consumer_buffer = mock_buffer
83+
assert connector.consumer_data_pipe is not None
84+
assert connector.consumer_buffer is not None
85+
input_tokens = torch.tensor([1, 2, 3])
86+
roi = torch.tensor([True, True, True])
87+
req_id = "test_req"
88+
connector.select(input_tokens, roi, req_id)
89+
90+
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
91+
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
92+
@patch('llm_datadist.LLMDataDist')
93+
def test_insert(self, mock_pipe, mock_buffer, MockLLMDataDist):
94+
"""Test insert operation"""
95+
connector = SimpleConnector(
96+
rank=0,
97+
local_rank=0,
98+
config=self._create_mock_config("kv_producer"))
99+
100+
connector.producer_buffer = mock_buffer
101+
102+
input_tokens = torch.randint(0, 1000, (5, ))
103+
roi = torch.ones_like(input_tokens, dtype=torch.bool)
104+
keys = torch.randn(3, 5, 1, 96)
105+
values = torch.randn(3, 5, 1, 96)
106+
hidden = torch.randn(5, 768)
107+
req_id = "test_req"
108+
109+
connector.insert(input_tokens, roi, keys, values, hidden, req_id)
110+
111+
mock_buffer.insert.assert_called_once_with(input_tokens, roi, keys,
112+
values, hidden, req_id)
113+
114+
@patch.object(SimpleConnector, 'insert')
115+
@patch('torch.distributed.get_rank', return_value=0)
116+
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
117+
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
118+
@patch('llm_datadist.LLMDataDist')
119+
def test_send_kv_caches_and_hidden_states(self, mock_pipe, mock_buffer,
120+
MockLLMDataDist, mock_insert,
121+
mock_rank):
122+
"""Test sending KV caches and hidden states"""
123+
connector = SimpleConnector(
124+
rank=0,
125+
local_rank=0,
126+
config=self._create_mock_config("kv_producer"))
127+
128+
mock_model_executable = MagicMock()
129+
mock_model_executable.model.start_layer = 0
130+
mock_model_executable.model.end_layer = 3
131+
132+
mock_model_input = MagicMock(spec=ModelInputForGPUWithSamplingMetadata)
133+
mock_model_input.input_tokens = torch.randint(0, 1000, (10, ))
134+
mock_model_input.attn_metadata.seq_lens = [5, 5]
135+
mock_model_input.attn_metadata.slot_mapping = torch.randint(
136+
0, 100, (10, ))
137+
mock_model_input.attn_metadata.num_prefill_tokens = 10
138+
mock_model_input.request_ids_to_seq_ids = {"req1": [0], "req2": [1]}
139+
140+
kv_caches = [torch.randn(2, 100, 1, 96) for _ in range(3)]
141+
142+
hidden_states = torch.randn(10, 768)
143+
144+
connector.send_kv_caches_and_hidden_states(mock_model_executable,
145+
mock_model_input, kv_caches,
146+
hidden_states)
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
import torch
5+
6+
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
7+
8+
9+
class TestSimplePipe(unittest.TestCase):
10+
11+
@classmethod
12+
def _create_mock_config(self):
13+
mock_config = MagicMock()
14+
mock_config.kv_role = "kv_producer"
15+
mock_config.kv_connector_extra_config = {
16+
"prefill_device_ips": ["127.0.0.1"],
17+
"decode_device_ips": ["127.0.0.1"],
18+
"llmdatadist_comm_port": 26000,
19+
"http_port": 8000,
20+
"proxy_ip": "127.0.0.1",
21+
"proxy_port": "8000",
22+
"port": 5500
23+
}
24+
mock_config.kv_port = 5500
25+
return mock_config
26+
27+
@patch('threading.Thread')
28+
@patch('llm_datadist.LLMDataDist')
29+
def test_init_success(self, mock_thread, MockLLMDataDist):
30+
31+
mock_config = self._create_mock_config()
32+
33+
self.pipe = SimplePipe(rank=5,
34+
local_rank=0,
35+
kv_transfer_config=mock_config,
36+
hostname="127.0.0.1",
37+
port_offset=0)
38+
39+
self.pipe.router_socket.close()
40+
41+
@patch('threading.Thread')
42+
@patch('llm_datadist.LLMDataDist')
43+
def test_prepare_data_dist(self, mock_thread, MockLLMDataDist):
44+
self.pipe = SimplePipe(rank=5,
45+
local_rank=0,
46+
kv_transfer_config=self._create_mock_config(),
47+
hostname="127.0.0.1",
48+
port_offset=0)
49+
mock_data_dist = MockLLMDataDist.return_value
50+
mock_data_dist.init.return_value = None
51+
self.pipe.router_socket.close()
52+
53+
def test_init_with_invalid_kv_role(self):
54+
with self.assertRaises(NotImplementedError):
55+
mock_config = MagicMock()
56+
mock_config.kv_role = "err_role"
57+
mock_config.kv_connector_extra_config = {
58+
"prefill_device_ips": ["127.0.0.1"],
59+
"decode_device_ips": ["127.0.0.1"],
60+
"llmdatadist_comm_port": 26000,
61+
"http_port": 8000,
62+
"proxy_ip": "127.0.0.1",
63+
"proxy_port": "8000",
64+
"port": 5500
65+
}
66+
pipe = SimplePipe(rank=5,
67+
local_rank=0,
68+
kv_transfer_config=mock_config,
69+
hostname="127.0.0.1",
70+
port_offset=0)
71+
pipe.router_socket.close()
72+
73+
def test_init_with_missing_device_ips(self):
74+
with self.assertRaises(ValueError):
75+
mock_config = MagicMock()
76+
mock_config.kv_role = "kv_producer"
77+
mock_config.kv_connector_extra_config = {
78+
"llmdatadist_comm_port": 26000,
79+
"http_port": 8000,
80+
"proxy_ip": "127.0.0.1",
81+
"proxy_port": "8000",
82+
"port": 5500
83+
}
84+
pipe = SimplePipe(rank=0,
85+
local_rank=0,
86+
kv_transfer_config=mock_config,
87+
hostname="127.0.0.1",
88+
port_offset=0)
89+
pipe.router_socket.close()
90+
91+
@patch('threading.Thread')
92+
@patch('llm_datadist.LLMDataDist')
93+
def test_create_register_thread_address_is_empty(self, MockThread,
94+
MockLLMDataDist):
95+
96+
mock_config = self._create_mock_config()
97+
pipe = SimplePipe(rank=5,
98+
local_rank=0,
99+
kv_transfer_config=mock_config,
100+
hostname="127.0.0.1",
101+
port_offset=0)
102+
self.assertIsNotNone(pipe._register_thread)
103+
mock_data_dist = MockLLMDataDist.return_value
104+
mock_data_dist.init.return_value = None
105+
pipe.router_socket.close()
106+
107+
@patch('threading.Thread')
108+
@patch('llm_datadist.LLMDataDist')
109+
def test_create_register_thread_address_is_not_empty(
110+
self, MockThread, MockLLMDataDist):
111+
mock_config = MagicMock()
112+
mock_config.kv_role = "kv_producer"
113+
mock_config.kv_connector_extra_config = {
114+
"prefill_device_ips": [""],
115+
"decode_device_ips": [""],
116+
"llmdatadist_comm_port": 26000,
117+
"http_port": 8000,
118+
"proxy_ip": "127.0.0.1",
119+
"proxy_port": "8000",
120+
"port": 5500
121+
}
122+
pipe = SimplePipe(rank=5,
123+
local_rank=0,
124+
kv_transfer_config=mock_config,
125+
hostname="127.0.0.1",
126+
port_offset=0)
127+
self.assertIsNotNone(pipe._register_thread)
128+
mock_data_dist = MockLLMDataDist.return_value
129+
mock_data_dist.init.return_value = None
130+
pipe.router_socket.close()
131+
132+
@patch('vllm_ascend.distributed.kv_transfer.simple_pipe.SimplePipe')
133+
@patch('llm_datadist.LLMDataDist')
134+
def test_should_send_tensor_when_valid_input(self, MockSimplePipe,
135+
MockLLMDataDist):
136+
pipe = MockSimplePipe()
137+
tensor = torch.randn(3, 3)
138+
tensor_desc = MockLLMDataDist.CacheDesc(
139+
num_tensors=1,
140+
shape=(3, 3),
141+
data_type=MockLLMDataDist.DataType.DT_FLOAT,
142+
seq_len_dim_index=1)
143+
tensor_key = MockLLMDataDist.CacheKey(1, 0, 1)
144+
result = pipe.send_tensor(tensor, tensor_desc, tensor_key)
145+
self.assertIsNotNone(result)

0 commit comments

Comments
 (0)