
Commit 578d238

ut: add ut for npu_mrope
Signed-off-by: David9857 <985700846@qq.com>
1 parent defaf22

File tree

1 file changed: +66 -0


tests/singlecard/ops/test_rotary_embedding.py

Lines changed: 66 additions & 0 deletions
@@ -9,6 +9,7 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch_npu

 import vllm_ascend.platform  # noqa: F401

@@ -196,3 +197,68 @@ def test_rotary_embedding_quant_with_leading_dim(
                                ref_key,
                                atol=DEFAULT_ATOL,
                                rtol=DEFAULT_RTOL)
+
+# Test RoPE through the npu_mrope interface with a leading dimension, merging seq_len and batch_size into num_tokens.
+@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_npu_mrope_quant_with_leading_dim(
+    is_neox_style: bool,
+    batch_size: int,
+    seq_len: int,
+    num_heads: int,
+    head_size: int,
+    rotary_dim: Optional[int],
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    max_position: int = 8192,
+    base: int = 10000,
+) -> None:
+    if rotary_dim is None:
+        rotary_dim = head_size
+
+    torch.set_default_device(device)
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                           is_neox_style, dtype)
+    rope = rope.to(dtype=dtype)
+    num_tokens = batch_size * seq_len
+    positions = torch.randint(0, max_position, (batch_size * seq_len, ))
+    qkv_tensor = torch.randn(num_tokens,
+                             num_heads * head_size * 3,
+                             dtype=dtype)
+    query, key, _ = qkv_tensor.split(
+        [num_heads * head_size, num_heads * head_size, num_heads * head_size],
+        dim=-1,
+    )
+
+    ref_query, ref_key = rope.forward_native(positions, query, key)
+
+    query, key = torch_npu.npu_mrope(
+        positions,
+        query,
+        key,
+        rope.cos_sin_cache,
+        rope.head_size,
+        mrope_section=[0, 0, 0],
+        rotary_mode='half' if rope.is_neox_style else 'interleave'
+    )
+
+    # Compare the results.
+    torch.testing.assert_close(query.view(ref_query.size()),
+                               ref_query,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+    torch.testing.assert_close(key.view(ref_key.size()),
+                               ref_key,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
