9
9
import pytest
10
10
import torch
11
11
import torch .nn as nn
12
+ import torch_npu
12
13
13
14
import vllm_ascend .platform # noqa: F401
14
15
@@ -196,3 +197,68 @@ def test_rotary_embedding_quant_with_leading_dim(
196
197
ref_key ,
197
198
atol = DEFAULT_ATOL ,
198
199
rtol = DEFAULT_RTOL )
200

# Test rope via the npu_mrope interface with a leading dimension, merging
# seq_len and batch_size into a single num_tokens dimension.
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_npu_mrope_quant_with_leading_dim(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Compare torch_npu.npu_mrope against RotaryEmbedding.forward_native.

    Query/key are laid out as (num_tokens, num_heads * head_size) with
    batch and sequence dims merged, and the NPU kernel's output must match
    the native reference within DEFAULT_ATOL/DEFAULT_RTOL.
    """
    # A None rotary_dim means "rotate the full head size".
    if rotary_dim is None:
        rotary_dim = head_size

    torch.set_default_device(device)
    rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                           is_neox_style, dtype)
    rope = rope.to(dtype=dtype)

    # Merge batch and sequence dims into one leading num_tokens dim.
    num_tokens = batch_size * seq_len
    positions = torch.randint(0, max_position, (num_tokens, ))
    qkv_tensor = torch.randn(num_tokens,
                             num_heads * head_size * 3,
                             dtype=dtype)
    query, key, _ = qkv_tensor.split(
        [num_heads * head_size, num_heads * head_size, num_heads * head_size],
        dim=-1,
    )

    # Reference result from the pure-PyTorch implementation.
    ref_query, ref_key = rope.forward_native(positions, query, key)

    # NOTE(review): mrope_section=[0, 0, 0] appears to select plain
    # (non-multimodal) RoPE in the npu_mrope kernel — confirm against the
    # torch_npu.npu_mrope documentation.
    query, key = torch_npu.npu_mrope(
        positions,
        query,
        key,
        rope.cos_sin_cache,
        rope.head_size,
        mrope_section=[0, 0, 0],
        rotary_mode='half' if rope.is_neox_style else 'interleave')

    # Compare the results.
    torch.testing.assert_close(query.view(ref_query.size()),
                               ref_query,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    torch.testing.assert_close(key.view(ref_key.size()),
                               ref_key,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
0 commit comments