Skip to content

[WIP][perf] Replace _npu_rotary_embedding with npu_mrope #1195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions tests/singlecard/ops/test_rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest
import torch
import torch.nn as nn
import torch_npu

import vllm_ascend.platform # noqa: F401

Expand Down Expand Up @@ -196,3 +197,68 @@ def test_rotary_embedding_quant_with_leading_dim(
ref_key,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)


# test rope with npu_mrope interface with leading dimension and merge seqlen and batch_size as num_tokens
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_npu_mrope_quant_with_leading_dim(
is_neox_style: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
max_position: int = 8192,
base: int = 10000,
) -> None:
if rotary_dim is None:
rotary_dim = head_size

torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
rope = rope.to(dtype=dtype)
num_tokens = batch_size * seq_len
positions = torch.randint(0, max_position, (batch_size * seq_len, ))
qkv_tensor = torch.randn(num_tokens,
num_heads * head_size * 3,
dtype=dtype)
query, key, _ = qkv_tensor.split(
[num_heads * head_size, num_heads * head_size, num_heads * head_size],
dim=-1,
)

ref_query, ref_key = rope.forward_native(positions, query, key)

query, key = torch_npu.npu_mrope(
positions,
query,
key,
rope.cos_sin_cache,
rope.head_size,
mrope_section=[0, 0, 0],
rotary_mode='half' if rope.is_neox_style else 'interleave')

# Compare the results.
torch.testing.assert_close(query.view(ref_query.size()),
ref_query,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(key.view(ref_key.size()),
ref_key,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
8 changes: 4 additions & 4 deletions vllm_ascend/ops/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ def rope_forward_oot(
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
query, key = torch_npu.npu_mrope(
Copy link
Collaborator

@Yikun Yikun Jun 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quick question: which torch_npu version supports the npu_mrope operator?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for graph mode, no release version supports npu_mrope yet

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you combine #1231 together

positions,
query,
key,
self.head_size,
self.cos_sin_cache,
neox_style,
)
self.head_size,
mrope_section=[0, 0, 0],
rotary_mode='half' if neox_style else 'interleave')
return query.view(query_shape), key.view(key_shape)


Expand Down
Loading