Commit cc1588b

[Misc] Code clean up (#1674)
Remove useless function

- vLLM version: v0.9.2
- vLLM main: vllm-project/vllm@b942c09

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent 830332e commit cc1588b

File tree

2 files changed: +1 -81 lines changed


vllm_ascend/worker/eagle_proposer_v1.py

Lines changed: 0 additions & 43 deletions
@@ -384,46 +384,3 @@ def prepare_eagle_input_sequential(out_tensor: torch.Tensor,
                (target_indices < end_pos) & \
                (offset_tensor < num_tokens)
        out_tensor[target_indices[mask]] = values_to_store[mask]
-
-
-# NOTE(woosuk): Currently, the below code is not used and we always use argmax
-# to sample the draft tokens. We will use this after we find a way to manage
-# the draft prob tensor.
-# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
-# FIXME(woosuk): The logic here is duplicated with the main sampling code.
-# We should refactor this to reuse the same sampling implementation.
-def compute_probs_and_sample_next_token(
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    if sampling_metadata.all_greedy:
-        # For greedy requests, draft_probs is not used in rejection sampling.
-        # Therefore, we can just return the logits.
-        probs = logits
-        next_token_ids = logits.argmax(dim=-1)
-        return next_token_ids, probs
-
-    is_greedy = sampling_metadata.temperature == -1
-    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
-    logits.div_(temperature.view(-1, 1))
-    probs = logits.softmax(dim=-1, dtype=torch.float32)
-
-    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
-    # generating the draft tokens. We only use the temperature. While this
-    # could degrade the acceptance rate, it does not affect the distribution
-    # of the generated tokens after rejection sampling.
-
-    # TODO(woosuk): Consider seeds.
-    q = torch.empty_like(probs)
-    q.exponential_()
-    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
-    # will be used later for rejection sampling.
-    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
-    if not sampling_metadata.all_random:
-        greedy_token_ids = probs.argmax(dim=-1)
-        next_token_ids = torch.where(
-            is_greedy,
-            greedy_token_ids,
-            next_token_ids,
-        )
-    return next_token_ids, probs
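For context (an editorial annotation, not part of the commit): the deleted compute_probs_and_sample_next_token drew draft tokens with an exponential-noise trick. Dividing the temperature-scaled probabilities by i.i.d. Exponential(1) noise and taking the argmax selects index i with probability probs[i], so no call to torch.multinomial is needed. A minimal standalone sketch of that equivalence:

import torch

# Sketch only: empirically check that argmax(probs / q), with q ~ Exp(1),
# samples index i with probability probs[i] -- the trick used in the deleted code.
torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])

counts = torch.zeros(3)
for _ in range(10_000):
    q = torch.empty_like(probs).exponential_()  # i.i.d. Exponential(1) noise
    counts[probs.div(q).argmax()] += 1

print(counts / counts.sum())  # empirically close to tensor([0.1, 0.2, 0.7])

This is the same idea as the Gumbel-max trick applied to log-probabilities, which is why the helper could draw a proper sample while (in this eagle copy) leaving probs intact for later rejection sampling.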

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 1 addition & 38 deletions
@@ -12,43 +12,6 @@
 from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
 
 
-# FIXME(woosuk): The logic here is duplicated with the main sampling code.
-# We should refactor this to reuse the same sampling implementation.
-def compute_probs_and_sample_next_token(
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    if sampling_metadata.all_greedy:
-        # For greedy requests, draft_probs is not used in rejection sampling.
-        # Therefore, we can just return the logits.
-        probs = logits
-        next_token_ids = logits.argmax(dim=-1)
-        return next_token_ids, probs
-
-    is_greedy = sampling_metadata.temperature == -1
-    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
-    logits.div_(temperature.view(-1, 1))
-    probs = logits.softmax(dim=-1, dtype=torch.float32)
-
-    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
-    # generating the draft tokens. We only use the temperature. While this
-    # could degrade the acceptance rate, it does not affect the distribution
-    # of the generated tokens after rejection sampling.
-
-    # TODO(woosuk): Consider seeds.
-    q = torch.empty_like(probs)
-    q.exponential_()
-    next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
-    if not sampling_metadata.all_random:
-        greedy_token_ids = probs.argmax(dim=-1)
-        next_token_ids = torch.where(
-            is_greedy,
-            greedy_token_ids,
-            next_token_ids,
-        )
-    return next_token_ids, probs
-
-
 class MtpProposer:
 
     def __init__(
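One detail worth noting (again an annotation, not part of the commit): this MTP copy divided in place with probs.div_(q), whereas the eagle copy above deliberately used the out-of-place probs.div(q) so that the returned draft_probs remain unmodified for rejection sampling. A tiny sketch of the difference:

import torch

probs = torch.tensor([0.1, 0.2, 0.7])
q = torch.tensor([2.0, 2.0, 2.0])

scores = probs.div(q)   # out-of-place: new tensor, probs is untouched
print(probs)            # tensor([0.1000, 0.2000, 0.7000])
print(scores)           # tensor([0.0500, 0.1000, 0.3500])

probs.div_(q)           # in-place: probs itself is overwritten
print(probs)            # tensor([0.0500, 0.1000, 0.3500])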
@@ -121,7 +84,7 @@ def propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
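With the helper gone, MtpProposer.propose is annotated to return only the draft token ids rather than a (token_ids, probs) tuple, which matches the argmax-only drafting mentioned in the deleted NOTE. A hypothetical sketch (names assumed, not taken from this repo) of what that reduces to:

import torch

def sample_draft_tokens_greedy(logits: torch.Tensor) -> torch.Tensor:
    # Greedy drafting: one highest-probability token id per draft position.
    return logits.argmax(dim=-1)

# Example: two draft positions over a 4-token vocabulary.
logits = torch.tensor([[0.1, 2.0, -1.0, 0.3],
                       [1.5, 0.0, 0.2, 3.0]])
print(sample_draft_tokens_greedy(logits))  # tensor([1, 3])

A caller that previously unpacked draft_token_ids, draft_probs = proposer.propose(...) would now read a single tensor.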
