
Commit 2204e4d

[0.7.3] Optimize apply_penalties & topKtopP for both V0/V1 Engine (#525)
This PR optimizes the apply_penalties & topKtopP implementations in both the V0 and V1 engines by avoiding torch.scatter and matrix indexing operations. We verified the functionality of this PR using Qwen2.5-72B-Instruct. At a concurrency of 40, with post-processing parameters set to "temperature": 0.3, "top_k": 100, "top_p": 0.9, "repetition_penalty": 1.01, the average decoding time was reduced from 300 ms to 50 ms.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
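For reference, the post-processing parameters quoted above map directly onto vLLM's standard request parameters. A minimal sketch of such a request (the prompt is a placeholder, the model ID assumes the Hugging Face naming of the model mentioned above, and the benchmark's 40-way concurrency is not reproduced here):

from vllm import LLM, SamplingParams

# Post-processing parameters from the benchmark description above.
params = SamplingParams(temperature=0.3,
                        top_k=100,
                        top_p=0.9,
                        repetition_penalty=1.01)

llm = LLM(model="Qwen/Qwen2.5-72B-Instruct")  # model used for verification
outputs = llm.generate(["An example prompt"], params)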
1 parent e1d13fc commit 2204e4d

File tree

8 files changed, +311 -1 lines changed

vllm_ascend/sample/__init__.py

Whitespace-only changes.

vllm_ascend/sample/ops/__init__.py

Whitespace-only changes.
vllm_ascend/sample/ops/ascend_topk_topp_sampler.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
from typing import Dict, Optional

import torch
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample


class AscendTopKTopPSampler(TopKTopPSampler):

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Optimized implementation of top-k and top-p sampling on NPU."""
        logits = apply_top_k_top_p_npu(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)


def apply_top_k_top_p_npu(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and/or top-p optimized for NPU."""
    if k is None and p is None:
        return logits

    batch_size, vocab_size = logits.shape
    device = logits.device
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
    if k is not None:
        safe_k = torch.clamp(k, min=1, max=vocab_size)
        boundary_idx = (vocab_size - safe_k).unsqueeze(1)
        boundary = logits_sort.gather(1, boundary_idx)
        top_k_mask = logits_sort < boundary
        logits_sort = logits_sort.masked_fill(top_k_mask, -float("inf"))
    else:
        top_k_mask = torch.zeros_like(logits_sort, dtype=torch.bool)

    cutoffs = top_k_mask.sum(dim=-1)
    strides = torch.arange(0,
                           batch_size * vocab_size,
                           vocab_size,
                           device=device).unsqueeze(1)
    if p is not None:
        global_cutoff = cutoffs.min()
        active_part = logits_idx[:, global_cutoff:]
        probs_sort = logits_sort[:, global_cutoff:].softmax(dim=-1)
        cumprob = probs_sort.cumsum(dim=-1)
        top_p_mask = (cumprob <= (1 - p.unsqueeze(1))) | (torch.arange(
            probs_sort.size(1), device=device) == probs_sort.size(1) - 1)
    else:
        active_part = logits_idx
        top_p_mask = torch.arange(vocab_size, device=device).expand(
            batch_size, -1) >= cutoffs.unsqueeze(1)

    valid_idx = (active_part + strides).masked_select(top_p_mask)
    logits_flatten = logits.flatten()
    output = torch.full_like(logits_flatten, -float('inf'))
    output[valid_idx] = logits_flatten[valid_idx]
    return output.reshape(batch_size, vocab_size)
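For intuition, here is a small CPU-side walkthrough of the top-k path of apply_top_k_top_p_npu; the import path assumes this file lands at vllm_ascend/sample/ops/ascend_topk_topp_sampler.py as shown above. The surviving logits are gathered through flat indices rather than scattered back, which is the point of the rewrite.

import torch

from vllm_ascend.sample.ops.ascend_topk_topp_sampler import \
    apply_top_k_top_p_npu

# Keep only the top-2 logits per row; everything else becomes -inf.
logits = torch.tensor([[1.0, 4.0, 2.0, 5.0, 3.0]])
out = apply_top_k_top_p_npu(logits, k=torch.tensor([2]), p=None)
print(out)  # tensor([[-inf, 4., -inf, 5., -inf]])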

vllm_ascend/sample/ops/penalties.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0

import torch
from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask
from vllm.v1.sample.ops.penalties import _convert_to_tensors


def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                    output_tokens_tensor: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Optimized implementation of repetition penalties on NPU.

    Applies penalties in place to the logits tensor.
    logits: The input logits tensor of shape [num_seqs, vocab_size].
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, )
    frequency_penalties: The frequency penalties of shape (num_seqs, )
    repetition_penalties: The repetition penalties of shape (num_seqs, )
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
        1, vocab_size)

    # Avoid the IndexPut operations in the original apply_penalties function,
    # which are extremely time-consuming on NPU.
    sequence_mask = prompt_mask | output_mask
    logits = torch.where(sequence_mask & torch.lt(logits, 0),
                         logits * repetition_penalties,
                         logits).to(logits.dtype)
    logits = torch.where(sequence_mask & torch.ge(logits, 0),
                         logits / repetition_penalties,
                         logits).to(logits.dtype)

    # We follow the definition in the OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits


def apply_all_penalties(
    logits: torch.Tensor,
    prompt_token_ids: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.
    """
    _, vocab_size = logits.shape
    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                          logits.device)
    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
                           presence_penalties, frequency_penalties,
                           repetition_penalties)
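The core trick in apply_penalties is replacing per-token IndexPut updates with two dense torch.where selects over a membership mask. A standalone sketch of just the repetition-penalty step, with toy values and no vLLM imports:

import torch

# Tokens 0 and 1 were seen in the prompt/output; penalty factor 2.0.
logits = torch.tensor([[2.0, -1.0, 0.5, -3.0]])
seen_mask = torch.tensor([[True, True, False, False]])
rep_pen = torch.tensor([2.0]).unsqueeze(dim=1)  # shape (num_seqs, 1)

# Negative seen logits are multiplied, non-negative ones divided,
# mirroring apply_penalties above, with no scatter/IndexPut involved.
logits = torch.where(seen_mask & torch.lt(logits, 0),
                     logits * rep_pen, logits)
logits = torch.where(seen_mask & torch.ge(logits, 0),
                     logits / rep_pen, logits)
print(logits)  # tensor([[ 1.0000, -2.0000,  0.5000, -3.0000]])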

vllm_ascend/sample/sampler.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
from typing import Optional

import torch
from vllm.model_executor.layers.sampler import (Sampler, SampleResultArgsType,
                                                SamplerOutput, _apply_min_p,
                                                _apply_min_tokens_penalty,
                                                _build_sampler_output, _sample,
                                                get_logprobs)
from vllm.model_executor.sampling_metadata import SamplingMetadata

from vllm_ascend.sample.ops.penalties import apply_penalties


class AscendSampler(Sampler):

    def __init__(self):
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        assert logits is not None
        _, vocab_size = logits.shape

        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        if do_top_p_top_k:
            logits = _apply_top_k_top_p_npu(logits, sampling_tensors.top_ps,
                                            sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)


def _apply_top_k_top_p_npu(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """Apply top-k and top-p optimized for NPU.

    This algorithm avoids using torch.scatter which is time-consuming on NPU.
    """
    batch_size, vocab_size = logits.shape
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
    top_k_mask = logits_sort < boundary
    logits_sort.masked_fill_(top_k_mask, -float("inf"))
    cutoff = top_k_mask.sum(dim=-1).min()
    probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
    top_p_mask[:, -1] = True
    strides = torch.arange(0,
                           batch_size * vocab_size,
                           vocab_size,
                           device=logits.device)
    flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
    valid_idx = torch.masked_select(flatten_idx, top_p_mask)
    logits_flatten = logits.flatten()
    valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
    logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
    logits[valid_idx] = valid_logits
    return logits.reshape(batch_size, vocab_size)
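Both the V0 and V1 variants rely on the same flat-index trick: adding a per-row stride of vocab_size turns row-local token indices into indices into logits.flatten(), so surviving logits can be picked out with index_select instead of torch.scatter. A toy illustration of just that step:

import torch

batch_size, vocab_size = 2, 4
logits = torch.arange(8, dtype=torch.float32).reshape(batch_size, vocab_size)
row_local_idx = torch.tensor([[1, 3], [0, 2]])  # kept token ids per row
strides = torch.arange(0, batch_size * vocab_size, vocab_size)  # [0, 4]
flat_idx = (row_local_idx + strides.unsqueeze(dim=1)).flatten()  # [1, 3, 4, 6]
print(torch.index_select(logits.flatten(), 0, flat_idx))
# tensor([1., 3., 4., 6.])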

vllm_ascend/sample/sampler_v1.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
import torch
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import apply_min_token_penalties
from vllm.v1.sample.sampler import Sampler

from vllm_ascend.sample.ops.ascend_topk_topp_sampler import \
    AscendTopKTopPSampler
from vllm_ascend.sample.ops.penalties import apply_all_penalties


class AscendSampler(Sampler):

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = AscendTopKTopPSampler()

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.min_tokens:
            apply_min_token_penalties(logits,
                                      sampling_metadata.output_token_ids,
                                      sampling_metadata.min_tokens)
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

vllm_ascend/worker/model_runner.py

Lines changed: 3 additions & 1 deletion
@@ -61,6 +61,8 @@
     _init_attn_metadata_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
 
+from vllm_ascend.sample.sampler import AscendSampler
+
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
 
@@ -820,7 +822,7 @@ def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
             self.model = get_model(vllm_config=self.vllm_config)
-
+            self.model.sampler = AscendSampler()
             self.model_memory_usage = m.consumed_memory
             logger.info("Loading model weights took %.4f GB",
                         self.model_memory_usage / float(2**30))

vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,7 @@
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                 AscendMetadata)
+from vllm_ascend.sample.sampler_v1 import AscendSampler
 
 if TYPE_CHECKING:
     from vllm.v1.core.scheduler_output import SchedulerOutput
 
@@ -810,6 +811,8 @@ def load_model(self) -> None:
 
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
+            self.model.sampler = AscendSampler()
+
         if self.lora_config:
             raise ValueError("LoRA model is not supported on NPU now.")
