diff --git a/tests/singlecard/ops/test_rotary_embedding.py b/tests/singlecard/ops/test_rotary_embedding.py
index 2d5ec18daf..a3e8d7f657 100644
--- a/tests/singlecard/ops/test_rotary_embedding.py
+++ b/tests/singlecard/ops/test_rotary_embedding.py
@@ -9,6 +9,7 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch_npu
 
 import vllm_ascend.platform  # noqa: F401
 
@@ -196,3 +197,65 @@ def test_rotary_embedding_quant_with_leading_dim(
                                ref_key,
                                atol=DEFAULT_ATOL,
                                rtol=DEFAULT_RTOL)
+
+
+# Test RoPE via the npu_mrope interface with a leading dimension; batch_size and seq_len are merged into num_tokens
+@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_npu_mrope_quant_with_leading_dim(
+    is_neox_style: bool,
+    batch_size: int,
+    seq_len: int,
+    num_heads: int,
+    head_size: int,
+    rotary_dim: Optional[int],
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    max_position: int = 8192,
+    base: int = 10000,
+) -> None:
+    torch.set_default_device(device)
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                           is_neox_style, dtype)
+    rope = rope.to(dtype=dtype)
+    num_tokens = batch_size * seq_len
+    positions = torch.randint(0, max_position, (batch_size * seq_len, ))
+    qkv_tensor = torch.randn(num_tokens,
+                             num_heads * head_size * 3,
+                             dtype=dtype)
+    query, key, _ = qkv_tensor.split(
+        [num_heads * head_size, num_heads * head_size, num_heads * head_size],
+        dim=-1,
+    )
+
+    ref_query, ref_key = rope.forward_native(positions, query, key)
+
+    query, key = torch_npu.npu_mrope(
+        positions,
+        query,
+        key,
+        rope.cos_sin_cache,
+        rope.head_size,
+        mrope_section=[0, 0, 0],
+        rotary_mode='half' if rope.is_neox_style else 'interleave')
+
+    # Compare the results.
+    torch.testing.assert_close(query.view(ref_query.size()),
+                               ref_query,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+    torch.testing.assert_close(key.view(ref_key.size()),
+                               ref_key,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 0c2a00afb6..81dde8b8bf 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -64,14 +64,14 @@ def rope_forward_oot(
     # TODO: Remove the contiguous in the future.
     query = query.contiguous().view(query.shape[0], -1)
     key = key.contiguous().view(key.shape[0], -1)
-    torch_npu._npu_rotary_embedding(
+    query, key = torch_npu.npu_mrope(
         positions,
         query,
         key,
-        self.head_size,
         self.cos_sin_cache,
-        neox_style,
-    )
+        self.head_size,
+        mrope_section=[0, 0, 0],
+        rotary_mode='half' if neox_style else 'interleave')
     return query.view(query_shape), key.view(key_shape)
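
Appended usage sketch (not part of the patch): the snippet below restates the npu_mrope call pattern introduced above as a standalone script, for trying the interface in isolation. It assumes an Ascend NPU with torch_npu and vllm-ascend installed; the helper name apply_rope_via_npu_mrope, the "npu" device string, and the tensor sizes are illustrative, while the argument order, mrope_section=[0, 0, 0], and the rotary_mode mapping are taken directly from the diff.

# Usage sketch only -- assumes an Ascend NPU environment with torch_npu and
# vllm-ascend installed; helper name, device string, and sizes are illustrative.
import torch
import torch_npu
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

import vllm_ascend.platform  # noqa: F401


def apply_rope_via_npu_mrope(positions, query, key, rope):
    # query/key come in flattened as [num_tokens, num_heads * head_size],
    # matching what rope_forward_oot produces via .contiguous().view(...).
    return torch_npu.npu_mrope(
        positions,
        query,
        key,
        rope.cos_sin_cache,
        rope.head_size,
        # All-zero mrope_section: the patch's test treats this as plain RoPE
        # and checks the result against RotaryEmbedding.forward_native.
        mrope_section=[0, 0, 0],
        rotary_mode='half' if rope.is_neox_style else 'interleave')


if __name__ == "__main__":
    torch.set_default_device("npu")  # illustrative device string
    dtype = torch.float16
    head_size, num_heads, num_tokens = 128, 8, 16
    rope = RotaryEmbedding(head_size, head_size, 8192, 10000, True, dtype)
    rope = rope.to(dtype=dtype)
    positions = torch.randint(0, 8192, (num_tokens, ))
    query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype)
    key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype)
    query, key = apply_rope_via_npu_mrope(positions, query, key, rope)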