
Commit 641a4e6

[CI] Cache sampled token ids in model runner to fix CI error (#1573)
### What this PR does / why we need it?
The vllm change vllm-project/vllm@7f280d6 breaks vllm-ascend: the scheduler no longer sends the sampled token ids back to the worker. This PR caches the sampled token ids in the model runner and fixes the broken CI.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

Closes: #1572

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent 0e43813 commit 641a4e6
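The core of the fix follows the upstream change: after each decoding step the worker writes the freshly sampled token ids into its own CPU-side batch buffers, instead of relying on the scheduler to echo them back on the next step. Below is a minimal, self-contained sketch of that bookkeeping; `MiniInputBatch` and `cache_sampled_tokens` are hypothetical stand-ins for illustration, not the actual vllm-ascend classes.

```python
import numpy as np


class MiniInputBatch:
    """Hypothetical stand-in for the runner's CPU-side batch buffers, loosely
    mirroring input_batch.token_ids_cpu / num_tokens_no_spec from the diff."""

    def __init__(self, max_num_reqs: int, max_model_len: int):
        self.token_ids_cpu = np.zeros((max_num_reqs, max_model_len), dtype=np.int64)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int64)
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int64)
        self.max_model_len = max_model_len

    def cache_sampled_tokens(self, valid_sampled_token_ids: list[list[int]]) -> None:
        """Append each request's newly sampled ids to the cached buffers so the
        scheduler does not have to send them back on the next step."""
        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
            if not sampled_ids:
                continue
            start = int(self.num_tokens_no_spec[req_idx])
            end = start + len(sampled_ids)
            assert end <= self.max_model_len, "sampled ids exceed max_model_len"
            self.token_ids_cpu[req_idx, start:end] = sampled_ids
            self.num_tokens_no_spec[req_idx] = end
            self.num_tokens[req_idx] = end


# Toy usage: two requests; only the first sampled a token this step.
batch = MiniInputBatch(max_num_reqs=2, max_model_len=16)
batch.num_tokens_no_spec[:] = [3, 5]
batch.num_tokens[:] = [3, 5]
batch.cache_sampled_tokens([[101], []])
print(batch.token_ids_cpu[0, :4])  # [0 0 0 101]
print(int(batch.num_tokens[0]))    # 4
```

With buffers kept this way, the scheduler only has to report how many tokens were computed; the token ids themselves stay on the worker, which is what the diff below implements.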

2 files changed: +57 −29 lines

tests/e2e/singlecard/test_ilama_lora.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-
 import vllm
 from vllm.lora.request import LoRARequest
 

vllm_ascend/worker/model_runner_v1.py

Lines changed: 57 additions & 28 deletions
@@ -527,24 +527,27 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 self.input_batch.num_tokens[req_index] = end_token_index
         else:
             req_data = scheduler_output.scheduled_cached_reqs
+            is_last_rank = get_pp_group().is_last_rank
             for i, req_id in enumerate(req_data.req_ids):
                 req_state = self.requests[req_id]
                 num_computed_tokens = req_data.num_computed_tokens[i]
-                new_token_ids = req_data.new_token_ids[i]
                 new_block_ids = req_data.new_block_ids[i]
                 resumed_from_preemption = req_data.resumed_from_preemption[i]
 
                 req_state.num_computed_tokens = num_computed_tokens
-                # Add the sampled token(s) from the previous step (if any).
-                # This doesn't include "unverified" tokens like spec decode tokens.
-                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
-                                  req_state.num_tokens)
-                if num_new_tokens == 1:
-                    # Avoid slicing list in most common case.
-                    req_state.output_token_ids.append(new_token_ids[-1])
-                elif num_new_tokens > 0:
-                    req_state.output_token_ids.extend(
-                        new_token_ids[-num_new_tokens:])
+                if not is_last_rank:
+                    new_token_ids = req_data.new_token_ids[i]
+                    # Add the sampled token(s) from the previous step (if any).
+                    # This doesn't include "unverified" tokens like spec decode tokens.
+                    num_new_tokens = (num_computed_tokens +
+                                      len(new_token_ids) -
+                                      req_state.num_tokens)
+                    if num_new_tokens == 1:
+                        # Avoid slicing list in most common case.
+                        req_state.output_token_ids.append(new_token_ids[-1])
+                    elif num_new_tokens > 0:
+                        req_state.output_token_ids.extend(
+                            new_token_ids[-num_new_tokens:])
                 # Update the block IDs.
                 if not resumed_from_preemption:
                     # Append the new blocks to the existing block IDs.
@@ -570,25 +573,27 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
 
                     self.input_batch.block_table.append_row(
                         new_block_ids, req_index)
-                # Add new_token_ids to token_ids_cpu.
-                start_token_index = num_computed_tokens
-                end_token_index = num_computed_tokens + len(new_token_ids)
-                self.input_batch.token_ids_cpu[
-                    req_index,
-                    start_token_index:end_token_index] = new_token_ids
-                self.input_batch.num_tokens_no_spec[
-                    req_index] = end_token_index
-                # Add spec_token_ids to token_ids_cpu.
-                spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                    req_id, ())
-                if spec_token_ids:
-                    start_index = end_token_index
-                    end_token_index += len(spec_token_ids)
+
+                if not is_last_rank:
+                    # Add new_token_ids to token_ids_cpu.
+                    start_token_index = num_computed_tokens
+                    end_token_index = num_computed_tokens + len(new_token_ids)
                     self.input_batch.token_ids_cpu[
                         req_index,
-                        start_index:end_token_index] = spec_token_ids
-                # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
-                self.input_batch.num_tokens[req_index] = end_token_index
+                        start_token_index:end_token_index] = new_token_ids
+                    self.input_batch.num_tokens_no_spec[
+                        req_index] = end_token_index
+                    # Add spec_token_ids to token_ids_cpu.
+                    spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+                        req_id, ())
+                    if spec_token_ids:
+                        start_index = end_token_index
+                        end_token_index += len(spec_token_ids)
+                        self.input_batch.token_ids_cpu[
+                            req_index,
+                            start_index:end_token_index] = spec_token_ids
+                    # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
+                    self.input_batch.num_tokens[req_index] = end_token_index
 
             # Check if the batch has changed. If not, we can skip copying the
             # sampling metadata from CPU to GPU.
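Taken together, the two hunks above make the scheduler-provided `new_token_ids` a non-last-rank concern: only pipeline-parallel ranks that do not run sampling still append tokens received from the scheduler, while the last rank relies on the ids it caches itself in `execute_model` (next hunk). Here is a condensed sketch of the resulting control flow, with simplified stand-in names and a plain dict instead of the real request state:

```python
def update_cached_request(req_state: dict, new_token_ids: list[int],
                          num_computed_tokens: int, is_last_rank: bool) -> None:
    """Condensed, illustrative stand-in for the _update_states logic above."""
    req_state["num_computed_tokens"] = num_computed_tokens
    if not is_last_rank:
        # Non-last PP ranks never sampled these tokens themselves, so they
        # still take them from the scheduler payload.
        num_new = (num_computed_tokens + len(new_token_ids)
                   - req_state["num_tokens"])
        if num_new > 0:
            req_state["output_token_ids"].extend(new_token_ids[-num_new:])
    # On the last rank nothing is appended here: execute_model already cached
    # the sampled ids (see the next hunk).


# Toy example: the request tracked 11 tokens, the scheduler reports 10 computed
# tokens plus two ids, so exactly one id is genuinely new on a non-last rank.
state = {"num_computed_tokens": 0, "num_tokens": 11, "output_token_ids": [1] * 11}
update_cached_request(state, new_token_ids=[42, 7], num_computed_tokens=10,
                      is_last_rank=False)
print(state["output_token_ids"][-1])  # 7
```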
@@ -1641,6 +1646,30 @@ def execute_model(
 
         for i in discard_sampled_tokens_req_indices:
             valid_sampled_token_ids[i].clear()
+        if not vllm_version_is("0.9.1"):
+            # Cache the sampled tokens in the model runner, so that the scheduler
+            # doesn't need to send them back.
+            # NOTE(woosuk): As an exception, when using PP, the scheduler sends
+            # the sampled tokens back, because there's no direct communication
+            # between the first-stage worker and the last-stage worker.
+            for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+                if not sampled_ids:
+                    continue
+
+                start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+                end_idx = start_idx + len(sampled_ids)
+                assert end_idx <= self.model_config.max_model_len, (
+                    "Sampled token IDs exceed the max model length. "
+                    f"Total number of tokens: {end_idx} > max_model_len: "
+                    f"{self.model_config.max_model_len}")
+
+                self.input_batch.token_ids_cpu[
+                    req_idx, start_idx:end_idx] = sampled_ids
+                self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+                self.input_batch.num_tokens[req_idx] = end_idx
+                req_id = self.input_batch.req_ids[req_idx]
+                req_state = self.requests[req_id]
+                req_state.output_token_ids.extend(sampled_ids)
 
         spec_token_ids = self._get_spec_token_ids(
             valid_sampled_token_ids,
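The new block only runs when the installed vllm is not 0.9.1, via the plugin's `vllm_version_is` helper; on 0.9.1 the older scheduler behavior is kept unchanged. A rough sketch of what such a version check can look like is below, using `importlib.metadata`; this is an assumption for illustration, not the actual vllm-ascend implementation.

```python
from importlib.metadata import PackageNotFoundError, version


def vllm_version_is(target: str) -> bool:
    """Illustrative version check; the real helper ships with vllm-ascend
    and may resolve the installed version differently."""
    try:
        return version("vllm") == target
    except PackageNotFoundError:
        return False


# Usage mirroring the diff: only take the new caching path when the installed
# vllm is not 0.9.1, i.e. when the scheduler no longer echoes sampled tokens.
if not vllm_version_is("0.9.1"):
    pass  # cache sampled token ids in the model runner (see hunk above)
```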
