 import math
 import unittest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, PropertyMock, patch
 
 import torch
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
 from tests.ut.base import TestBase
-from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled,
-                                              native_rope_deepseek_forward,
-                                              rope_forward_oot, rotate_half,
-                                              yarn_find_correction_dim,
-                                              yarn_get_mscale)
+from vllm_ascend.ops.rotary_embedding import custom_rotary_embedding_enabled
 
 
 class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
@@ -67,22 +65,28 @@ def test_custom_rotary_embedding_enabled(self):
         self.assertFalse(result)
 
 
-class TestRopeForwardOot(unittest.TestCase):
+class TestAscendRotaryEmbedding(unittest.TestCase):
 
     def setUp(self):
         # Common setup for tests
         self.positions = torch.tensor([1, 2, 3])
-        self.query = torch.randn(3, 4, dtype=torch.float16)
-        self.key = torch.randn(3, 4, dtype=torch.float16)
+        self.query = torch.randn(3, 1, 32, dtype=torch.float16)
+        self.key = torch.randn(3, 1, 32, dtype=torch.float16)
         self.head_size = 32
-        self.cos_sin_cache = torch.randn(3, 4)
+        self.rotary_dim = self.head_size
+        self.max_position = 16
+        self.rope_theta = 10000
+        self.is_neox_style = True
+        self.cos_sin_cache = torch.randn(3, 1, 32)
+        self.layer = RotaryEmbedding(self.head_size, self.rotary_dim,
+                                     self.max_position, self.rope_theta,
+                                     self.is_neox_style, torch.float16)
 
         # Mock self object for rope_forward_oot
         self.mock_self = MagicMock()
         self.mock_self.head_size = self.head_size
         self.mock_self.cos_sin_cache = self.cos_sin_cache
-        self.mock_self.is_neox_style = True
-        self.mock_self.forward_native.return_value = (self.query, self.key)
+        self.mock_self.is_neox_style = self.is_neox_style
 
     @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
     def test_rope_forward_oot_torchair_enabled_base(self,
@@ -91,12 +95,14 @@ def test_rope_forward_oot_torchair_enabled_base(self,
         mock_config = MagicMock()
         mock_config.torchair_graph_config.enabled = True
         mock_get_ascend_config.return_value = mock_config
-
-        result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
-                                              self.query, self.key)
-
-        self.mock_self.forward_native.assert_called_once_with(
-            self.positions, self.query, self.key, None)
+        with patch.object(self.layer,
+                          "forward_native",
+                          return_value=(self.query,
+                                        self.key)) as mock_forward_native:
+            result_q, result_k = self.layer.forward(self.positions, self.query,
+                                                    self.key)
+
+            mock_forward_native.assert_called_once()
         self.assertTrue(torch.equal(result_q, self.query))
         self.assertTrue(torch.equal(result_k, self.key))
 
@@ -117,9 +123,10 @@ def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
 
         mock__c.rotary_embedding.return_value = self.query, self.key
 
-        result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
-                                              self.query, self.key)
+        result_q, result_k = self.layer.forward(self.positions, self.query,
+                                                self.key)
 
+        mock__c.rotary_embedding.assert_called_once()
         self.assertEqual(result_q.shape, self.query.shape)
         self.assertEqual(result_k.shape, self.key.shape)
 
@@ -138,8 +145,9 @@ def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
         non_contig_query = self.query.transpose(0, 1)
         non_contig_key = self.key.transpose(0, 1)
 
-        result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
-                                              non_contig_query, non_contig_key)
+        result_q, result_k = self.layer.forward(self.positions,
+                                                non_contig_query,
+                                                non_contig_key)
 
         mock_npu_rotary.assert_called_once()
         self.assertEqual(result_q.shape, non_contig_query.shape)
@@ -154,8 +162,7 @@ def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
         # Test that NotImplementedError is raised when offsets is provided
         offsets = torch.tensor([1, 2, 3])
         with self.assertRaises(NotImplementedError):
-            rope_forward_oot(self.mock_self, self.positions, self.query,
-                             self.key, offsets)
+            self.layer.forward(self.positions, self.query, self.key, offsets)
 
     @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
     @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
@@ -169,11 +176,10 @@ def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
         mock_get_ascend_config.return_value = mock_config
 
         # Test neox_style override
-        result_q, result_k = rope_forward_oot(self.mock_self,
-                                              self.positions,
-                                              self.query,
-                                              self.key,
-                                              is_neox_style_override=False)
+        result_q, result_k = self.layer.forward(self.positions,
+                                                self.query,
+                                                self.key,
+                                                is_neox_style_override=False)
 
         # Check that neox_style=False was passed to the NPU function
         args, kwargs = mock_npu_rotary.call_args
@@ -191,98 +197,108 @@ def __init__(self, max_seq_len=2048, is_neox_style=True):
         self.base = 1
 
 
-class TestNativeRopeDeepseekForward(TestBase):
+class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
 
-    @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
-    def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
-        module = MockRopeModule()
-        positions = torch.tensor([1, 2, 3])
-        query = torch.randn(1, 8, 128)
-        key = torch.randn(1, 8, 128)
-
-        mock_rope_forward_oot.return_value = (query, key)
-
-        q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
-                                                  key)
-
-        assert q_pe.shape == query.shape
-        assert k_pe.shape == key.shape
+    def setUp(self):
+        # Common setup for tests
+        self.positions = torch.tensor([1, 2, 3])
+        self.query = torch.randn(3, 1, 32, dtype=torch.float16)
+        self.key = torch.randn(3, 1, 32, dtype=torch.float16)
+        self.head_size = 32
+        self.rotary_dim = self.head_size
+        self.max_position = 16
+        self.rope_theta = 10000
+        self.is_neox_style = True
+        self.scaling_factor = 1
+        self.layer = None
+
+    def _create_layer(self):
+        self.layer = DeepseekScalingRotaryEmbedding(
+            self.head_size, self.rotary_dim, self.max_position,
+            self.rope_theta, self.is_neox_style, self.scaling_factor,
+            torch.float16)
+        return self.layer
+
+    @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
+    def test_native_rope_deepseek_forward_base(self, mock_current_platform):
+        mock_current_platform.device_type = torch.device("cpu")
+        self.layer = self._create_layer()
+        with patch("vllm_ascend.ops.rotary_embedding.rope_forward_oot",
+                   return_value=(self.query,
+                                 self.key)) as mock_rope_forward_oot:
+            q_pe, k_pe = self.layer.forward(self.positions, self.query,
+                                            self.key)
+        mock_rope_forward_oot.assert_called_once()
+        assert q_pe.shape == self.query.shape
+        assert k_pe.shape == self.key.shape
 
-    @patch('vllm_ascend.ops.rotary_embedding._set_cos_sin_cache')
     @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
+    @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
     def test_native_rope_deepseek_forward_cache_handling(
-            self, mock_rope_forward_oot, mock_set_cache):
+            self, mock_current_platform, mock_rope_forward_oot):
+        mock_current_platform.device_type = torch.device("cpu")
+        self.layer = self._create_layer()
+        self.layer.max_seq_len = 1024
         # Test cache situation is true
-        module = MockRopeModule(max_seq_len=1024)
-        positions = torch.tensor([1, 2, 3])
-        query = torch.randn(1, 8, 128)
-        key = torch.randn(1, 8, 128)
+        with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache:
+            mock_rope_forward_oot.return_value = (self.query, self.key)
 
-        mock_rope_forward_oot.return_value = (query, key)
-
-        q_pe, k_pe = native_rope_deepseek_forward(module,
-                                                  positions,
-                                                  query,
-                                                  key,
-                                                  max_seq_len=2048)
-
-        assert q_pe.shape == query.shape
-        assert k_pe.shape == key.shape
+            q_pe, k_pe = self.layer.forward(self.positions,
+                                            self.query,
+                                            self.key,
+                                            max_seq_len=2048)
+            mock_set_cache.assert_called_once()
+            assert q_pe.shape == self.query.shape
+            assert k_pe.shape == self.key.shape
 
     @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
+    @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
     def test_native_rope_deepseek_forward_key_reshaping(
-            self, mock_rope_forward_oot):
-        module = MockRopeModule()
-        positions = torch.tensor([1, 2, 3])
-        query = torch.randn(1, 8, 128)
-        key = torch.randn(1, 128)
+            self, mock_current_platform, mock_rope_forward_oot):
+        mock_current_platform.device_type = torch.device("cpu")
+        self.layer = self._create_layer()
 
-        mock_rope_forward_oot.return_value = (query, key)
+        key = torch.randn(1, 32)
 
-        q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
-                                                  key)
+        mock_rope_forward_oot.return_value = (self.query, key)
 
-        assert q_pe.shape == query.shape
-        assert k_pe.shape == (1, 128)
+        q_pe, k_pe = self.layer.forward(self.positions, self.query, key)
+        mock_rope_forward_oot.assert_called_once()
+        assert q_pe.shape == self.query.shape
+        assert k_pe.shape == key.shape
 
     @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
+    @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
     def test_native_rope_deepseek_forward_non_neox_style(
-            self, mock_rope_forward_oot):
-        module = MockRopeModule(is_neox_style=False)
-        positions = torch.tensor([1, 2, 3])
-        query = torch.randn(1, 8, 128)
-        key = torch.randn(1, 8, 128)
-
-        mock_rope_forward_oot.return_value = (query, key)
+            self, mock_current_platform, mock_rope_forward_oot):
+        mock_current_platform.device_type = torch.device("cpu")
+        self.layer = self._create_layer()
 
-        q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
-                                                  key)
-
-        assert q_pe.shape == query.shape
-        assert k_pe.shape == key.shape
+        mock_rope_forward_oot.return_value = (self.query, self.key)
 
+        q_pe, k_pe = self.layer.forward(self.positions, self.query, self.key)
 
-class TestRotateHalf(unittest.TestCase):
-
-    def test_rotate_half_even_dim(self):
-        # Test with even dimension
-        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
-        expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
-        result = rotate_half(x)
-        self.assertTrue(torch.allclose(result, expected))
+        mock_rope_forward_oot.assert_called_once()
+        assert q_pe.shape == self.query.shape
+        assert k_pe.shape == self.key.shape
 
-
-class TestYarnFindCorrectionDim(unittest.TestCase):
-
-    def test_basic_case(self):
+    @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
+    def test_basic_case(self, mock_current_platform):
         # Test with standard values
+        mock_current_platform.device_type = torch.device("cpu")
+        self.layer = self._create_layer()
         num_rotations = 100
         dim = 512
         base = 10000
         max_position_embeddings = 2048
 
-        result = yarn_find_correction_dim(num_rotations, dim, base,
-                                          max_position_embeddings)
+        result = self.layer._yarn_find_correction_dim(num_rotations, dim, base,
+                                                      max_position_embeddings)
 
         # Calculate expected value manually
         expected = (dim * torch.log(
@@ -292,22 +308,25 @@ def test_basic_case(self):
 
         self.assertTrue(torch.allclose(result, expected))
 
+    @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
+    def test_yarn_get_mscale(self, mock_current_platform):
+        mock_current_platform.device_type = torch.device("cpu")
+        self.layer = self._create_layer()
 
-class TestYarnGetMscale(unittest.TestCase):
-
-    def test_scale_less_than_or_equal_1(self):
-        self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
-        self.assertEqual(yarn_get_mscale(scale=1.0), 1.0)
-        self.assertEqual(yarn_get_mscale(scale=0.999), 1.0)
+        # test_scale_less_than_or_equal_1
+        self.assertEqual(self.layer._yarn_get_mscale(scale=0.5), 1.0)
+        self.assertEqual(self.layer._yarn_get_mscale(scale=1.0), 1.0)
+        self.assertEqual(self.layer._yarn_get_mscale(scale=0.999), 1.0)
 
-    def test_scale_greater_than_1(self):
+        # test_scale_greater_than_1:
         test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
                       (10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
                       (5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
                       (math.e, 1.0, 1.0 + 0.1)]
 
         for scale, mscale, expected in test_cases:
-            result = yarn_get_mscale(scale, mscale)
+            result = self.layer._yarn_get_mscale(scale, mscale)
             self.assertAlmostEqual(
                 result,
                 expected,