@@ -221,8 +221,13 @@ def _create_layer(self):
221
221
222
222
@patch ("vllm.model_executor.layers.rotary_embedding.current_platform" ,
223
223
new_callable = PropertyMock )
224
- def test_native_rope_deepseek_forward_base (self , mock_current_platform ):
224
+ @patch ("vllm_ascend.ops.rotary_embedding.current_platform" ,
225
+ new_callable = PropertyMock )
226
+ def test_native_rope_deepseek_forward_base (self ,
227
+ mock_current_platform_ascend ,
228
+ mock_current_platform ):
225
229
mock_current_platform .device_type = torch .device ("cpu" )
230
+ mock_current_platform_ascend .device_type = torch .device ("cpu" )
226
231
self .layer = self ._create_layer ()
227
232
with patch ("vllm_ascend.ops.rotary_embedding.rope_forward_oot" ,
228
233
return_value = (self .query ,
@@ -236,9 +241,13 @@ def test_native_rope_deepseek_forward_base(self, mock_current_platform):
236
241
@patch ('vllm_ascend.ops.rotary_embedding.rope_forward_oot' )
237
242
@patch ("vllm.model_executor.layers.rotary_embedding.current_platform" ,
238
243
new_callable = PropertyMock )
244
+ @patch ("vllm_ascend.ops.rotary_embedding.current_platform" ,
245
+ new_callable = PropertyMock )
239
246
def test_native_rope_deepseek_forward_cache_handling (
240
- self , mock_current_platform , mock_rope_forward_oot ):
247
+ self , mock_current_platform_ascend , mock_current_platform ,
248
+ mock_rope_forward_oot ):
241
249
mock_current_platform .device_type = torch .device ("cpu" )
250
+ mock_current_platform_ascend .device_type = torch .device ("cpu" )
242
251
self .layer = self ._create_layer ()
243
252
self .layer .max_seq_len = 1024
244
253
# Test cache situation is true
@@ -256,9 +265,13 @@ def test_native_rope_deepseek_forward_cache_handling(
256
265
@patch ('vllm_ascend.ops.rotary_embedding.rope_forward_oot' )
257
266
@patch ("vllm.model_executor.layers.rotary_embedding.current_platform" ,
258
267
new_callable = PropertyMock )
268
+ @patch ("vllm_ascend.ops.rotary_embedding.current_platform" ,
269
+ new_callable = PropertyMock )
259
270
def test_native_rope_deepseek_forward_key_reshaping (
260
- self , mock_current_platform , mock_rope_forward_oot ):
271
+ self , mock_current_platform_ascend , mock_current_platform ,
272
+ mock_rope_forward_oot ):
261
273
mock_current_platform .device_type = torch .device ("cpu" )
274
+ mock_current_platform_ascend .device_type = torch .device ("cpu" )
262
275
self .layer = self ._create_layer ()
263
276
264
277
key = torch .randn (1 , 32 )
@@ -273,9 +286,13 @@ def test_native_rope_deepseek_forward_key_reshaping(
273
286
@patch ('vllm_ascend.ops.rotary_embedding.rope_forward_oot' )
274
287
@patch ("vllm.model_executor.layers.rotary_embedding.current_platform" ,
275
288
new_callable = PropertyMock )
289
+ @patch ("vllm_ascend.ops.rotary_embedding.current_platform" ,
290
+ new_callable = PropertyMock )
276
291
def test_native_rope_deepseek_forward_non_neox_style (
277
- self , mock_current_platform , mock_rope_forward_oot ):
292
+ self , mock_current_platform_ascend , mock_current_platform ,
293
+ mock_rope_forward_oot ):
278
294
mock_current_platform .device_type = torch .device ("cpu" )
295
+ mock_current_platform_ascend .device_type = torch .device ("cpu" )
279
296
self .layer = self ._create_layer ()
280
297
281
298
mock_rope_forward_oot .return_value = (self .query , self .key )
@@ -288,9 +305,13 @@ def test_native_rope_deepseek_forward_non_neox_style(
288
305
289
306
@patch ("vllm.model_executor.layers.rotary_embedding.current_platform" ,
290
307
new_callable = PropertyMock )
291
- def test_basic_case (self , mock_current_platform ):
308
+ @patch ("vllm_ascend.ops.rotary_embedding.current_platform" ,
309
+ new_callable = PropertyMock )
310
+ def test_basic_case (self , mock_current_platform_ascend ,
311
+ mock_current_platform ):
292
312
# Test with standard values
293
313
mock_current_platform .device_type = torch .device ("cpu" )
314
+ mock_current_platform_ascend .device_type = torch .device ("cpu" )
294
315
self .layer = self ._create_layer ()
295
316
num_rotations = 100
296
317
dim = 512
@@ -310,8 +331,12 @@ def test_basic_case(self, mock_current_platform):
310
331
311
332
@patch ("vllm.model_executor.layers.rotary_embedding.current_platform" ,
312
333
new_callable = PropertyMock )
313
- def test_yarn_get_mscale (self , mock_current_platform ):
334
+ @patch ("vllm_ascend.ops.rotary_embedding.current_platform" ,
335
+ new_callable = PropertyMock )
336
+ def test_yarn_get_mscale (self , mock_current_platform_ascend ,
337
+ mock_current_platform ):
314
338
mock_current_platform .device_type = torch .device ("cpu" )
339
+ mock_current_platform_ascend .device_type = torch .device ("cpu" )
315
340
self .layer = self ._create_layer ()
316
341
317
342
# test_scale_less_than_or_equal_1
0 commit comments