Commit 834babe

fix rope ut
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 1d5ec90 commit 834babe

File tree

1 file changed (+123 −104 lines)

tests/ut/ops/test_rotary_embedding.py

Lines changed: 123 additions & 104 deletions
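The diff below reworks the rope unit tests to drive vLLM's real rotary-embedding layers (RotaryEmbedding and DeepseekScalingRotaryEmbedding) through their forward path, patching internals such as forward_native and rope_forward_oot, instead of calling rope_forward_oot directly on a MagicMock. For orientation, here is a condensed sketch of that pattern; it is not part of the commit, it reuses the constructor arguments from the updated setUp(), and it assumes an environment where the vllm-ascend plugin routes RotaryEmbedding.forward through rope_forward_oot.

# Illustrative sketch only -- mirrors the pattern of the updated tests below.
from unittest.mock import patch

import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

# Same positional arguments as setUp() in the new test:
# head_size, rotary_dim, max_position, rope_theta, is_neox_style, dtype
layer = RotaryEmbedding(32, 32, 16, 10000, True, torch.float16)

positions = torch.tensor([1, 2, 3])
query = torch.randn(3, 1, 32, dtype=torch.float16)
key = torch.randn(3, 1, 32, dtype=torch.float16)

# Patch forward_native on the real layer so the torchair-enabled code path
# can be exercised without an NPU; the test then asserts the mock was called.
with patch.object(layer, "forward_native",
                  return_value=(query, key)) as mock_forward_native:
    q_out, k_out = layer.forward(positions, query, key)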
@@ -1,15 +1,13 @@
 import math
 import unittest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, PropertyMock, patch

 import torch
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)

 from tests.ut.base import TestBase
-from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled,
-                                              native_rope_deepseek_forward,
-                                              rope_forward_oot, rotate_half,
-                                              yarn_find_correction_dim,
-                                              yarn_get_mscale)
+from vllm_ascend.ops.rotary_embedding import custom_rotary_embedding_enabled


 class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
@@ -67,22 +65,28 @@ def test_custom_rotary_embedding_enabled(self):
         self.assertFalse(result)


-class TestRopeForwardOot(unittest.TestCase):
+class TestAscendRotaryEmbedding(unittest.TestCase):

     def setUp(self):
         # Common setup for tests
         self.positions = torch.tensor([1, 2, 3])
-        self.query = torch.randn(3, 4, dtype=torch.float16)
-        self.key = torch.randn(3, 4, dtype=torch.float16)
+        self.query = torch.randn(3, 1, 32, dtype=torch.float16)
+        self.key = torch.randn(3, 1, 32, dtype=torch.float16)
         self.head_size = 32
-        self.cos_sin_cache = torch.randn(3, 4)
+        self.rotary_dim = self.head_size
+        self.max_position = 16
+        self.rope_theta = 10000
+        self.is_neox_style = True
+        self.cos_sin_cache = torch.randn(3, 1, 32)
+        self.layer = RotaryEmbedding(self.head_size, self.rotary_dim,
+                                     self.max_position, self.rope_theta,
+                                     self.is_neox_style, torch.float16)

         # Mock self object for rope_forward_oot
         self.mock_self = MagicMock()
         self.mock_self.head_size = self.head_size
         self.mock_self.cos_sin_cache = self.cos_sin_cache
-        self.mock_self.is_neox_style = True
-        self.mock_self.forward_native.return_value = (self.query, self.key)
+        self.mock_self.is_neox_style = self.is_neox_style

     @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
     def test_rope_forward_oot_torchair_enabled_base(self,
@@ -91,12 +95,14 @@ def test_rope_forward_oot_torchair_enabled_base(self,
         mock_config = MagicMock()
         mock_config.torchair_graph_config.enabled = True
         mock_get_ascend_config.return_value = mock_config
-
-        result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
-                                              self.query, self.key)
-
-        self.mock_self.forward_native.assert_called_once_with(
-            self.positions, self.query, self.key, None)
+        with patch.object(self.layer,
+                          "forward_native",
+                          return_value=(self.query,
+                                        self.key)) as mock_forward_native:
+            result_q, result_k = self.layer.forward(self.positions, self.query,
+                                                    self.key)
+
+        mock_forward_native.assert_called_once()
         self.assertTrue(torch.equal(result_q, self.query))
         self.assertTrue(torch.equal(result_k, self.key))

@@ -117,9 +123,10 @@ def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,

         mock__c.rotary_embedding.return_value = self.query, self.key

-        result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
-                                              self.query, self.key)
+        result_q, result_k = self.layer.forward(self.positions, self.query,
+                                                self.key)

+        mock__c.rotary_embedding.assert_called_once()
         self.assertEqual(result_q.shape, self.query.shape)
         self.assertEqual(result_k.shape, self.key.shape)

@@ -138,8 +145,9 @@ def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
         non_contig_query = self.query.transpose(0, 1)
         non_contig_key = self.key.transpose(0, 1)

-        result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
-                                              non_contig_query, non_contig_key)
+        result_q, result_k = self.layer.forward(self.positions,
+                                                non_contig_query,
+                                                non_contig_key)

         mock_npu_rotary.assert_called_once()
         self.assertEqual(result_q.shape, non_contig_query.shape)
@@ -154,8 +162,7 @@ def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
         # Test that NotImplementedError is raised when offsets is provided
         offsets = torch.tensor([1, 2, 3])
         with self.assertRaises(NotImplementedError):
-            rope_forward_oot(self.mock_self, self.positions, self.query,
-                             self.key, offsets)
+            self.layer.forward(self.positions, self.query, self.key, offsets)

     @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
     @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
@@ -169,11 +176,10 @@ def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
         mock_get_ascend_config.return_value = mock_config

         # Test neox_style override
-        result_q, result_k = rope_forward_oot(self.mock_self,
-                                              self.positions,
-                                              self.query,
-                                              self.key,
-                                              is_neox_style_override=False)
+        result_q, result_k = self.layer.forward(self.positions,
+                                                self.query,
+                                                self.key,
+                                                is_neox_style_override=False)

         # Check that neox_style=False was passed to the NPU function
         args, kwargs = mock_npu_rotary.call_args
@@ -191,98 +197,108 @@ def __init__(self, max_seq_len=2048, is_neox_style=True):
         self.base = 1


-class TestNativeRopeDeepseekForward(TestBase):
+class TestAscendDeepseekScalingRotaryEmbedding(TestBase):

-    @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
-    def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
-        module = MockRopeModule()
-        positions = torch.tensor([1, 2, 3])
-        query = torch.randn(1, 8, 128)
-        key = torch.randn(1, 8, 128)
-
-        mock_rope_forward_oot.return_value = (query, key)
-
-        q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
-                                                  key)
-
-        assert q_pe.shape == query.shape
-        assert k_pe.shape == key.shape
+    def setUp(self):
+        # Common setup for tests
+        self.positions = torch.tensor([1, 2, 3])
+        self.query = torch.randn(3, 1, 32, dtype=torch.float16)
+        self.key = torch.randn(3, 1, 32, dtype=torch.float16)
+        self.head_size = 32
+        self.rotary_dim = self.head_size
+        self.max_position = 16
+        self.rope_theta = 10000
+        self.is_neox_style = True
+        self.scaling_factor = 1
+        self.layer = None
+
+    def _create_layer(self):
+        self.layer = DeepseekScalingRotaryEmbedding(
+            self.head_size, self.rotary_dim, self.max_position,
+            self.rope_theta, self.is_neox_style, self.scaling_factor,
+            torch.float16)
+        return self.layer
+
+    @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
+    def test_native_rope_deepseek_forward_base(self, mock_current_platform):
+        mock_current_platform.device_type = torch.device("cpu")
+        self.layer = self._create_layer()
+        with patch("vllm_ascend.ops.rotary_embedding.rope_forward_oot",
+                   return_value=(self.query,
+                                 self.key)) as mock_rope_forward_oot:
+            q_pe, k_pe = self.layer.forward(self.positions, self.query,
+                                            self.key)
+        mock_rope_forward_oot.assert_called_once()
+        assert q_pe.shape == self.query.shape
+        assert k_pe.shape == self.key.shape

-    @patch('vllm_ascend.ops.rotary_embedding._set_cos_sin_cache')
     @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
+    @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
     def test_native_rope_deepseek_forward_cache_handling(
-            self, mock_rope_forward_oot, mock_set_cache):
+            self, mock_current_platform, mock_rope_forward_oot):
+        mock_current_platform.device_type = torch.device("cpu")
+        self.layer = self._create_layer()
+        self.layer.max_seq_len = 1024
         # Test cache situation is true
-        module = MockRopeModule(max_seq_len=1024)
-        positions = torch.tensor([1, 2, 3])
-        query = torch.randn(1, 8, 128)
-        key = torch.randn(1, 8, 128)
+        with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache:
+            mock_rope_forward_oot.return_value = (self.query, self.key)

-        mock_rope_forward_oot.return_value = (query, key)
-
-        q_pe, k_pe = native_rope_deepseek_forward(module,
-                                                  positions,
-                                                  query,
-                                                  key,
-                                                  max_seq_len=2048)
-
-        assert q_pe.shape == query.shape
-        assert k_pe.shape == key.shape
+            q_pe, k_pe = self.layer.forward(self.positions,
+                                            self.query,
+                                            self.key,
+                                            max_seq_len=2048)
+        mock_set_cache.assert_called_once()
+        assert q_pe.shape == self.query.shape
+        assert k_pe.shape == self.key.shape

     @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
+    @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
     def test_native_rope_deepseek_forward_key_reshaping(
-            self, mock_rope_forward_oot):
-        module = MockRopeModule()
-        positions = torch.tensor([1, 2, 3])
-        query = torch.randn(1, 8, 128)
-        key = torch.randn(1, 128)
+            self, mock_current_platform, mock_rope_forward_oot):
+        mock_current_platform.device_type = torch.device("cpu")
+        self.layer = self._create_layer()

-        mock_rope_forward_oot.return_value = (query, key)
+        key = torch.randn(1, 32)

-        q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
-                                                  key)
+        mock_rope_forward_oot.return_value = (self.query, key)

-        assert q_pe.shape == query.shape
-        assert k_pe.shape == (1, 128)
+        q_pe, k_pe = self.layer.forward(self.positions, self.query, key)
+        mock_rope_forward_oot.assert_called_once()
+        assert q_pe.shape == self.query.shape
+        assert k_pe.shape == key.shape

     @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
+    @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
     def test_native_rope_deepseek_forward_non_neox_style(
-            self, mock_rope_forward_oot):
-        module = MockRopeModule(is_neox_style=False)
-        positions = torch.tensor([1, 2, 3])
-        query = torch.randn(1, 8, 128)
-        key = torch.randn(1, 8, 128)
-
-        mock_rope_forward_oot.return_value = (query, key)
+            self, mock_current_platform, mock_rope_forward_oot):
+        mock_current_platform.device_type = torch.device("cpu")
+        self.layer = self._create_layer()

-        q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
-                                                  key)
-
-        assert q_pe.shape == query.shape
-        assert k_pe.shape == key.shape
+        mock_rope_forward_oot.return_value = (self.query, self.key)

+        q_pe, k_pe = self.layer.forward(self.positions, self.query, self.key)

-class TestRotateHalf(unittest.TestCase):
-
-    def test_rotate_half_even_dim(self):
-        # Test with even dimension
-        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
-        expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
-        result = rotate_half(x)
-        self.assertTrue(torch.allclose(result, expected))
+        mock_rope_forward_oot.assert_called_once()
+        assert q_pe.shape == self.query.shape
+        assert k_pe.shape == self.key.shape

-
-class TestYarnFindCorrectionDim(unittest.TestCase):
-
-    def test_basic_case(self):
+    @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
+    def test_basic_case(self, mock_current_platform):
         # Test with standard values
+        mock_current_platform.device_type = torch.device("cpu")
+        self.layer = self._create_layer()
         num_rotations = 100
         dim = 512
         base = 10000
         max_position_embeddings = 2048

-        result = yarn_find_correction_dim(num_rotations, dim, base,
-                                          max_position_embeddings)
+        result = self.layer._yarn_find_correction_dim(num_rotations, dim, base,
+                                                      max_position_embeddings)

         # Calculate expected value manually
         expected = (dim * torch.log(
@@ -292,22 +308,25 @@ def test_basic_case(self):

         self.assertTrue(torch.allclose(result, expected))

+    @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
+    def test_yarn_get_mscale(self, mock_current_platform):
+        mock_current_platform.device_type = torch.device("cpu")
+        self.layer = self._create_layer()

-class TestYarnGetMscale(unittest.TestCase):
-
-    def test_scale_less_than_or_equal_1(self):
-        self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
-        self.assertEqual(yarn_get_mscale(scale=1.0), 1.0)
-        self.assertEqual(yarn_get_mscale(scale=0.999), 1.0)
+        # test_scale_less_than_or_equal_1
+        self.assertEqual(self.layer._yarn_get_mscale(scale=0.5), 1.0)
+        self.assertEqual(self.layer._yarn_get_mscale(scale=1.0), 1.0)
+        self.assertEqual(self.layer._yarn_get_mscale(scale=0.999), 1.0)

-    def test_scale_greater_than_1(self):
+        # test_scale_greater_than_1:
         test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
                       (10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
                       (5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
                       (math.e, 1.0, 1.0 + 0.1)]

         for scale, mscale, expected in test_cases:
-            result = yarn_get_mscale(scale, mscale)
+            result = self.layer._yarn_get_mscale(scale, mscale)
             self.assertAlmostEqual(
                 result,
                 expected,
