
Commit 0e43813

ganyi1996ppo and Yikun authored
[ModelRunner] Use shared CachedRequestData cross request to fix ci (#1546)
### What this PR does / why we need it?

This PR (adapted from vllm-project/vllm@2863bef) updates the CachedRequestData definition to use a single instance shared across all requests in a batch, instead of creating a new instance per request. CI was broken by vLLM's model-runner change: `ERROR 07-01 09:53:53 [core.py:521] TypeError: 'CachedRequestData' object is not iterable`. This PR modifies the model runner to fix it.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Passing CI will verify this.

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
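For readers skimming the diff, the core of the change is the CachedRequestData layout. Below is a rough sketch of the before/after shape, inferred from the attribute accesses in model_runner_v1.py further down; the class and field names here are assumptions for illustration, not the verbatim upstream vLLM definition.

from dataclasses import dataclass, field


@dataclass
class CachedRequestDataPerRequest:
    # vLLM 0.9.1 style: the scheduler emits one instance per cached request,
    # and scheduled_cached_reqs is a plain list of these objects.
    req_id: str
    resumed_from_preemption: bool
    new_token_ids: list[int] = field(default_factory=list)
    new_block_ids: list[list[int]] = field(default_factory=list)
    num_computed_tokens: int = 0


@dataclass
class CachedRequestDataShared:
    # Newer style: one instance for the whole batch; every field becomes a
    # parallel per-request list, and the object itself is not iterable.
    req_ids: list[str] = field(default_factory=list)
    resumed_from_preemption: list[bool] = field(default_factory=list)
    new_token_ids: list[list[int]] = field(default_factory=list)
    new_block_ids: list[list[list[int]]] = field(default_factory=list)
    num_computed_tokens: list[int] = field(default_factory=list)

    @property
    def num_reqs(self) -> int:
        return len(self.req_ids)

This is why the tests below switch from len(scheduled_cached_reqs) to scheduled_cached_reqs.num_reqs on newer vLLM.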
1 parent 6db7dc2 commit 0e43813

File tree

4 files changed: +181 -87 lines changed

tests/e2e/singlecard/core/test_ascend_scheduler.py

Lines changed: 16 additions & 5 deletions
@@ -201,7 +201,10 @@ def test_schedule(enable_prefix_caching: Optional[bool],
     # Test initial scheduling
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == len(requests)
-    assert len(output.scheduled_cached_reqs) == 0
+    if vllm_version_is("0.9.1"):
+        assert len(output.scheduled_cached_reqs) == 0
+    else:
+        assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
     # Verify all requests are scheduled.
     for req_id, num_tokens in output.num_scheduled_tokens.items():
@@ -238,7 +241,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):

     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == 3
-    assert len(output.scheduled_cached_reqs) == 0
+    if vllm_version_is("0.9.1"):
+        assert len(output.scheduled_cached_reqs) == 0
+    else:
+        assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0

     # The first request is scheduled partially - 400.
@@ -268,7 +274,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     output1 = scheduler.schedule()
     assert len(scheduler.running) == 3
     assert len(output1.scheduled_new_reqs) == 0
-    assert len(output1.scheduled_cached_reqs) == 3
+    if vllm_version_is("0.9.1"):
+        assert len(output1.scheduled_cached_reqs) == 3
+    else:
+        assert output1.scheduled_cached_reqs.num_reqs == 3
     assert len(output1.finished_req_ids) == 0
     assert output1.num_scheduled_tokens[requests[0].request_id] == 400
     assert output1.num_scheduled_tokens[requests[1].request_id] == 400
@@ -292,7 +301,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     output2 = scheduler.schedule()
     assert len(scheduler.running) == 3
     assert len(output2.scheduled_new_reqs) == 0
-    assert len(output2.scheduled_cached_reqs) == 3
+    if vllm_version_is("0.9.1"):
+        assert len(output2.scheduled_cached_reqs) == 3
+    else:
+        assert output2.scheduled_cached_reqs.num_reqs == 3
     assert len(output2.finished_req_ids) == 0
     assert output2.num_scheduled_tokens[requests[0].request_id] == 1
     assert output2.num_scheduled_tokens[requests[1].request_id] == 1
@@ -762,7 +774,6 @@ def assert_scheduler_empty(scheduler: AscendScheduler):
     assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 0
     assert len(scheduler.finished_req_ids) == 0
-    assert len(scheduler._cached_reqs_data) == 0

     # EncoderCacheManager.
     assert len(scheduler.encoder_cache_manager.freed) == 0
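The last hunk above drops the scheduler._cached_reqs_data check: with a single shared CachedRequestData built each step, there is presumably no per-request object pool left on the scheduler to assert on. Abridged from the hunk above, the helper now only verifies the queues and the encoder cache:

def assert_scheduler_empty(scheduler):
    # Queue bookkeeping (the _cached_reqs_data pool check is gone).
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 0
    assert len(scheduler.finished_req_ids) == 0
    # EncoderCacheManager.
    assert len(scheduler.encoder_cache_manager.freed) == 0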

tests/e2e/singlecard/test_scheduler.py

Lines changed: 4 additions & 1 deletion
@@ -192,7 +192,10 @@ def test_schedule(enable_prefix_caching: Optional[bool],
     # Test initial scheduling
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == len(requests)
-    assert len(output.scheduled_cached_reqs) == 0
+    if vllm_version_is("0.9.1"):
+        assert len(output.scheduled_cached_reqs) == 0
+    else:
+        assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
     # Verify all requests are scheduled.
     for req_id, num_tokens in output.num_scheduled_tokens.items():
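The same version gate now appears in both test files. A small helper along these lines could collapse the repeated check; it is hypothetical and not part of this commit, with num_reqs being the attribute the tests rely on for newer vLLM:

from vllm_ascend.utils import vllm_version_is


def num_cached_reqs(output) -> int:
    # Count cached requests in a SchedulerOutput across vLLM versions:
    # 0.9.1 exposes a list of per-request objects, newer versions expose a
    # single shared object with a num_reqs attribute.
    cached = output.scheduled_cached_reqs
    return len(cached) if vllm_version_is("0.9.1") else cached.num_reqs

With that, each four-line version check above becomes a single assert num_cached_reqs(output) == 0 (or == 3).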

vllm_ascend/core/scheduler.py

Lines changed: 30 additions & 19 deletions
@@ -32,6 +32,8 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager

+from vllm_ascend.utils import vllm_version_is
+

 class AscendScheduler(Scheduler):
     """This Scheduler extends vllm's original v1 scheduler
@@ -364,27 +366,36 @@ def skip_cur_request():
                 req_to_new_block_ids[req.request_id])
             for req in scheduled_new_reqs
         ]
-        resumed_reqs_data = [
-            self._make_cached_request_data(
-                req,
-                num_scheduled_tokens[req.request_id],
-                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
-                req_to_new_block_ids[req.request_id],
-                resumed_from_preemption=True,
-            ) for req in scheduled_resumed_reqs
-        ]
-        running_reqs_data = [
-            self._make_cached_request_data(
-                req,
-                num_scheduled_tokens[req.request_id],
-                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
-                req_to_new_block_ids[req.request_id],
-                resumed_from_preemption=False,
-            ) for req in scheduled_running_reqs
-        ]
+        if vllm_version_is("0.9.1"):
+            resumed_reqs_data = [
+                self._make_cached_request_data(
+                    req,
+                    num_scheduled_tokens[req.request_id],
+                    len(scheduled_spec_decode_tokens.get(req.request_id, ())),
+                    req_to_new_block_ids[req.request_id],
+                    resumed_from_preemption=True,
+                ) for req in scheduled_resumed_reqs
+            ]
+            running_reqs_data = [
+                self._make_cached_request_data(
+                    req,
+                    num_scheduled_tokens[req.request_id],
+                    len(scheduled_spec_decode_tokens.get(req.request_id, ())),
+                    req_to_new_block_ids[req.request_id],
+                    resumed_from_preemption=False,
+                ) for req in scheduled_running_reqs
+            ]
+            scheduled_cached_reqs = resumed_reqs_data + running_reqs_data
+        else:
+            cached_reqs_data = self._make_cached_request_data(
+                scheduled_running_reqs, scheduled_resumed_reqs,
+                num_scheduled_tokens, scheduled_spec_decode_tokens,
+                req_to_new_block_ids)
+            scheduled_cached_reqs = cached_reqs_data
+
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
-            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
+            scheduled_cached_reqs=scheduled_cached_reqs,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
             scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
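On the producer side, the else branch above hands the whole running and resumed request lists to vLLM's batched _make_cached_request_data, which returns the single shared object. The sketch below is a simplified, hypothetical illustration of that batched construction; the real implementation lives in upstream vLLM, and the field names follow the attributes read back in model_runner_v1.py further down.

from types import SimpleNamespace


def make_shared_cached_request_data(running_reqs, resumed_reqs,
                                    req_to_new_block_ids):
    # One object for the whole batch: parallel per-request lists instead of
    # a list of per-request objects. Running and resumed requests land in the
    # same instance, distinguished only by the resumed_from_preemption flag.
    # Sampled-token bookkeeping is omitted; the real scheduler also fills
    # new_token_ids for each request.
    data = SimpleNamespace(req_ids=[], resumed_from_preemption=[],
                           new_token_ids=[], new_block_ids=[],
                           num_computed_tokens=[])
    for resumed, reqs in ((False, running_reqs), (True, resumed_reqs)):
        for req in reqs:
            data.req_ids.append(req.request_id)
            data.resumed_from_preemption.append(resumed)
            data.new_token_ids.append([])
            data.new_block_ids.append(req_to_new_block_ids[req.request_id])
            data.num_computed_tokens.append(req.num_computed_tokens)
    data.num_reqs = len(data.req_ids)
    return data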

vllm_ascend/worker/model_runner_v1.py

Lines changed: 131 additions & 62 deletions
@@ -456,78 +456,147 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             req_ids_to_add.append(req_id)

         # Update the states of the running/resumed requests.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            req_id = req_data.req_id
-            req_state = self.requests[req_id]
+        if vllm_version_is("0.9.1"):
+            for req_data in scheduler_output.scheduled_cached_reqs:
+                req_id = req_data.req_id
+                req_state = self.requests[req_id]

-            # Update the cached states.
-            num_computed_tokens = req_data.num_computed_tokens
-            req_state.num_computed_tokens = num_computed_tokens
-            # Add the sampled token(s) from the previous step (if any).
-            # This doesn't include "unverified" tokens like spec decode tokens.
-            num_new_tokens = (num_computed_tokens +
-                              len(req_data.new_token_ids) -
-                              req_state.num_tokens)
-            if num_new_tokens == 1:
-                # Avoid slicing list in most common case.
-                req_state.output_token_ids.append(req_data.new_token_ids[-1])
-            elif num_new_tokens > 0:
-                req_state.output_token_ids.extend(
-                    req_data.new_token_ids[-num_new_tokens:])
-            # Update the block IDs.
-            if not req_data.resumed_from_preemption:
-                # Append the new blocks to the existing block IDs.
-                for block_ids, new_block_ids in zip(  # type: ignore[call-overload]
-                        req_state.block_ids,
-                        req_data.new_block_ids,
-                        strict=True):
-                    block_ids.extend(new_block_ids)
-            else:
-                # The request is resumed from preemption.
-                # Replace the existing block IDs with the new ones.
-                req_state.block_ids = req_data.new_block_ids
-
-            req_index = self.input_batch.req_id_to_index.get(req_id)
-            if req_index is None:
-                # The request is not in the persistent batch.
-                # The request was either preempted and resumed later, or was not
-                # scheduled in the previous step and needs to be added again.
-                req_ids_to_add.append(req_id)
-                continue
+                # Update the cached states.
+                num_computed_tokens = req_data.num_computed_tokens
+                req_state.num_computed_tokens = num_computed_tokens
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec decode tokens.
+                num_new_tokens = (num_computed_tokens +
+                                  len(req_data.new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    # Avoid slicing list in most common case.
+                    req_state.output_token_ids.append(
+                        req_data.new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        req_data.new_token_ids[-num_new_tokens:])
+                # Update the block IDs.
+                if not req_data.resumed_from_preemption:
+                    # Append the new blocks to the existing block IDs.
+                    for block_ids, new_block_ids in zip(  # type: ignore[call-overload]
+                            req_state.block_ids,
+                            req_data.new_block_ids,
+                            strict=True):
+                        block_ids.extend(new_block_ids)
+                else:
+                    # The request is resumed from preemption.
+                    # Replace the existing block IDs with the new ones.
+                    req_state.block_ids = req_data.new_block_ids
+
+                req_index = self.input_batch.req_id_to_index.get(req_id)
+                if req_index is None:
+                    # The request is not in the persistent batch.
+                    # The request was either preempted and resumed later, or was not
+                    # scheduled in the previous step and needs to be added again.
+                    req_ids_to_add.append(req_id)
+                    continue
+
+                # Update the persistent batch.
+                self.input_batch.num_computed_tokens_cpu[req_index] = (
+                    num_computed_tokens)
+
+                start_index = (len(req_state.block_ids) -
+                               len(req_data.new_block_ids))
+                self.input_batch.block_table.append_row(
+                    req_data.new_block_ids, req_index)
+                # Add new_token_ids to token_ids_cpu.
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(
+                    req_data.new_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index,
+                    start_token_index:end_token_index] = req_data.new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                # Add spec_token_ids to token_ids_cpu.
+                spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+                    req_id, ())
+                if spec_token_ids:
+                    start_index = end_token_index
+                    end_token_index += len(spec_token_ids)
+                    self.input_batch.token_ids_cpu[
+                        req_index,
+                        start_index:end_token_index] = spec_token_ids
+                # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
+                self.input_batch.num_tokens[req_index] = end_token_index
+        else:
+            req_data = scheduler_output.scheduled_cached_reqs
+            for i, req_id in enumerate(req_data.req_ids):
+                req_state = self.requests[req_id]
+                num_computed_tokens = req_data.num_computed_tokens[i]
+                new_token_ids = req_data.new_token_ids[i]
+                new_block_ids = req_data.new_block_ids[i]
+                resumed_from_preemption = req_data.resumed_from_preemption[i]
+
+                req_state.num_computed_tokens = num_computed_tokens
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec decode tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    # Avoid slicing list in most common case.
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+                # Update the block IDs.
+                if not resumed_from_preemption:
+                    # Append the new blocks to the existing block IDs.
+                    for block_ids, new_ids in zip(  # type: ignore[call-overload]
+                            req_state.block_ids, new_block_ids):
+                        block_ids.extend(new_ids)
+                else:
+                    # The request is resumed from preemption.
+                    # Replace the existing block IDs with the new ones.
+                    req_state.block_ids = new_block_ids
+
+                req_index = self.input_batch.req_id_to_index.get(req_id)
+                if req_index is None:
+                    # The request is not in the persistent batch.
+                    # The request was either preempted and resumed later, or was not
+                    # scheduled in the previous step and needs to be added again.
+                    req_ids_to_add.append(req_id)
+                    continue
+
+                # Update the persistent batch.
+                self.input_batch.num_computed_tokens_cpu[req_index] = (
+                    num_computed_tokens)

-            # Update the persistent batch.
-            self.input_batch.num_computed_tokens_cpu[req_index] = (
-                num_computed_tokens)
-
-            start_index = (len(req_state.block_ids) -
-                           len(req_data.new_block_ids))
-            self.input_batch.block_table.append_row(req_data.new_block_ids,
-                                                    req_index)
-            # Add new_token_ids to token_ids_cpu.
-            start_token_index = num_computed_tokens
-            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
-            self.input_batch.token_ids_cpu[
-                req_index,
-                start_token_index:end_token_index] = req_data.new_token_ids
-            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
-            # Add spec_token_ids to token_ids_cpu.
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, ())
-            if spec_token_ids:
-                start_index = end_token_index
-                end_token_index += len(spec_token_ids)
+                self.input_batch.block_table.append_row(
+                    new_block_ids, req_index)
+                # Add new_token_ids to token_ids_cpu.
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
                 self.input_batch.token_ids_cpu[
-                    req_index, start_index:end_token_index] = spec_token_ids
-            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
-            self.input_batch.num_tokens[req_index] = end_token_index
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                # Add spec_token_ids to token_ids_cpu.
+                spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+                    req_id, ())
+                if spec_token_ids:
+                    start_index = end_token_index
+                    end_token_index += len(spec_token_ids)
+                    self.input_batch.token_ids_cpu[
+                        req_index,
+                        start_index:end_token_index] = spec_token_ids
+                # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
+                self.input_batch.num_tokens[req_index] = end_token_index

         # Check if the batch has changed. If not, we can skip copying the
         # sampling metadata from CPU to GPU.
         batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0

         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
-        removed_req_indices = sorted(removed_req_indices, reverse=True)
+        removed_req_indices.sort(reverse=True)
         for req_id in req_ids_to_add:
             req_state = self.requests[req_id]
             if removed_req_indices:
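The model-runner change above is the actual CI fix: the old code iterated scheduled_cached_reqs directly, which raises the `TypeError: 'CachedRequestData' object is not iterable` from the commit message once it becomes a single shared object. Condensed here into a hypothetical helper for illustration only; the commit itself inlines both branches in _update_states rather than using such a generator.

from vllm_ascend.utils import vllm_version_is


def iter_cached_reqs(scheduler_output):
    """Yield (req_id, num_computed_tokens, new_token_ids, new_block_ids,
    resumed_from_preemption) uniformly across vLLM versions."""
    cached = scheduler_output.scheduled_cached_reqs
    if vllm_version_is("0.9.1"):
        # Old layout: a list of per-request objects.
        for req_data in cached:
            yield (req_data.req_id, req_data.num_computed_tokens,
                   req_data.new_token_ids, req_data.new_block_ids,
                   req_data.resumed_from_preemption)
    else:
        # New layout: one shared object with parallel per-request lists;
        # iterating it directly is exactly what broke CI.
        for i, req_id in enumerate(cached.req_ids):
            yield (req_id, cached.num_computed_tokens[i],
                   cached.new_token_ids[i], cached.new_block_ids[i],
                   cached.resumed_from_preemption[i])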
