Skip to content

Commit 2fda604

Browse files
[Perf] Use fused ops npu_top_k_top_p (#1308)
### What this PR does / why we need it? Use fused ops torch_npu.npu_top_k_top_p(logits, p, k) when p and k are not None, otherwise fallback to the original one. The replacement will take place automatically when `VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE=1` . This patch are using `npu_top_k_top_p` which required torch_npu>=2.5.1.post1.dev20250619 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Tested by DeepSeek R1 and UT passed Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
1 parent e7efc7e commit 2fda604

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import importlib
2+
import os
3+
import unittest
4+
from unittest import mock
5+
6+
import torch
7+
from vllm.v1.sample.ops import topk_topp_sampler
8+
9+
10+
class TestTopKTopPSamplerOptimize(unittest.TestCase):
11+
12+
@mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"})
13+
@mock.patch("torch_npu.npu_top_k_top_p")
14+
def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
15+
import vllm_ascend.patch.worker.patch_common.patch_sampler
16+
importlib.reload(vllm_ascend.patch.worker.patch_common.patch_sampler)
17+
18+
mock_npu_op.return_value = (torch.randn(1, 3))
19+
sampler = topk_topp_sampler.TopKTopPSampler()
20+
21+
logits = torch.tensor([[1.0, 2.0, 3.0]])
22+
k = torch.tensor([2])
23+
p = torch.tensor([0.9])
24+
generators = {0: torch.Generator()}
25+
generators[0].manual_seed(42)
26+
27+
sampler.forward_native(logits, generators, k, p)
28+
mock_npu_op.assert_called_once_with(logits, p, k)

vllm_ascend/patch/worker/patch_common/patch_sampler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Optional
2020

2121
import torch
22+
import torch_npu
2223
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
2324
from vllm.v1.sample.sampler import Sampler
2425

@@ -48,9 +49,13 @@ def apply_min_p(
4849

4950
def _apply_top_k_top_p(
5051
logits: torch.Tensor,
51-
p: torch.Tensor,
5252
k: torch.Tensor,
53+
p: torch.Tensor,
5354
) -> torch.Tensor:
55+
if p is not None and k is not None:
56+
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
57+
return torch_npu.npu_top_k_top_p(logits, p, k)
58+
5459
probs = logits.softmax(dim=-1)
5560
probs_sort, _ = probs.sort(dim=-1, descending=False)
5661

0 commit comments

Comments
 (0)