 # This file is a part of the vllm-ascend project.
 #
 from collections import deque
-from typing import Iterable, Optional, Union
+from typing import Iterable, Union

-from vllm.config import VllmConfig
 from vllm.logger import logger
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.utils import cdiv
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
-from vllm.v1.core.sched.utils import check_stop
-from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
-from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.engine import EngineCoreOutputs
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
-from vllm.v1.spec_decode.metrics import SpecDecodingStats
-from vllm.v1.structured_output import StructuredOutputManager

 class AscendScheduler(Scheduler):
     """This Scheduler extends vllm's original v1 scheduler
     with a prefill-first scheduling strategy."""

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        kv_cache_config: KVCacheConfig,
-        structured_output_manager: StructuredOutputManager,
-        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
-        include_finished_set: bool = False,
-        log_stats: bool = False,
-    ) -> None:
-        super().__init__(vllm_config, kv_cache_config,
-                         structured_output_manager, mm_registry,
-                         include_finished_set, log_stats)
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
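+        # Ascend-specific bookkeeping on top of the base scheduler: the set
+        # of req_ids scheduled in the current step, and our own running list.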
         self.scheduled_req_ids: set[str] = set()
         self.running: list[Request] = []

@@ -365,41 +349,22 @@ def finish_requests(
         For example, the API server can abort a request when the client
         disconnects.
         """
-        assert RequestStatus.is_finished(finished_status)
-        if isinstance(request_ids, str):
-            request_ids = (request_ids, )
-        else:
-            request_ids = set(request_ids)
-
         for req_id in request_ids:
             request = self.requests.get(req_id)
             if request is None:
                 # Invalid request ID.
                 continue
-
             if request.status == RequestStatus.RUNNING:
-                self.running.remove(request)
                 self.scheduled_req_ids.discard(request.request_id)
-            else:
-                self.waiting.remove(request)
-            request.status = finished_status
-            self._free_request(request)
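+        # Delegate the actual teardown (removal from the running/waiting
+        # queues, status update, and freeing) to the base Scheduler.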
+        super().finish_requests(request_ids, finished_status)

     def update_from_output(
         self,
         scheduler_output: SchedulerOutput,
         model_runner_output: ModelRunnerOutput,
     ) -> EngineCoreOutputs:
-        sampled_token_ids = model_runner_output.sampled_token_ids
-        spec_token_ids = model_runner_output.spec_token_ids
-        logprobs = model_runner_output.logprobs
-        prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens

-        new_running: list[Request] = []
-        outputs: list[EngineCoreOutput] = []
-        spec_decoding_stats: Optional[SpecDecodingStats] = None
-
         # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
         # loop can be a performance bottleneck. We should do our best to avoid
         # expensive operations inside the loop.
@@ -408,121 +373,8 @@ def update_from_output(
             num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
             if num_tokens_scheduled == 0:
                 # The request was not scheduled in this step.
-                new_running.append(request)
                 continue
-
-            req_index = model_runner_output.req_id_to_index[req_id]
-            generated_token_ids = sampled_token_ids[req_index]
-
-            scheduled_spec_token_ids = (
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id))
-            if scheduled_spec_token_ids:
-                # num_computed_tokens represents the number of tokens
-                # processed in the current step, considering scheduled
-                # tokens and rejections. If some tokens are rejected,
-                # num_computed_tokens is decreased by the number of rejected
-                # tokens, which is given by:
-                # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
-                num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
-                                       len(generated_token_ids))
-                request.num_computed_tokens -= num_tokens_rejected
-                spec_decoding_stats = self.make_spec_decoding_stats(
-                    spec_decoding_stats,
-                    num_draft_tokens=len(scheduled_spec_token_ids),
-                    num_accepted_tokens=len(generated_token_ids) - 1)
-
-            cached_encoder_input_ids = (
-                self.encoder_cache_manager.get_cached_input_ids(request))
-            # OPTIMIZATION: Avoid list(set) if the set is empty.
-            if cached_encoder_input_ids:
-                for input_id in list(cached_encoder_input_ids):
-                    mm_positions = request.mm_positions[input_id]
-                    start_pos = mm_positions.offset
-                    num_tokens = mm_positions.length
-                    if start_pos + num_tokens <= request.num_computed_tokens:
-                        # The encoder output is already processed and stored
-                        # in the decoder's KV cache.
-                        self.encoder_cache_manager.free_encoder_input(
-                            request, input_id)
-
-            stopped = False
-            new_logprobs = None
-            new_token_ids = generated_token_ids
-
-            # Append generated tokens and check for stop. Note that if
-            # a request is still being prefilled, we expect the model runner
-            # to return empty token ids for the request.
-            for num_new, output_token_id in enumerate(new_token_ids, 1):
-                request.append_output_token_ids(output_token_id)
-
-                # Check for stop and update request state.
-                # This must be called before we make the EngineCoreOutput.
-                stopped = check_stop(request, self.max_model_len)
-                if stopped:
-                    self._free_request(request)
-                    del new_token_ids[num_new:]  # Trim new tokens if needed.
-                    break
-
-            # Extract sample logprobs if needed.
-            if request.sampling_params.logprobs is not None and logprobs:
-                # NOTE: once we support N tokens per step (spec decode),
-                # the outer lists can be of length > 1.
-                new_logprobs = logprobs.slice(req_index, req_index + 1)
-
-            if new_token_ids and request.use_structured_output:
-                # NOTE: structured_output_request should not be None if
-                # use_structured_output, which we have checked above, so it
-                # is safe to ignore the type warning.
-                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
-                    req_id, new_token_ids)
-
-            # Add newly generated spec token ids to the request.
-            if spec_token_ids is not None:
-                if request.use_structured_output:
-                    metadata = request.structured_output_request
-                    assert metadata is not None and metadata.grammar is not None
-                    # Needs to happen after new_token_ids are accepted.
-                    request.spec_token_ids = metadata.grammar.validate_tokens(
-                        spec_token_ids[req_index])
-                else:
-                    request.spec_token_ids = spec_token_ids[req_index]
-
-            # Get prompt logprobs for this request.
-            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-            if new_token_ids:
-                # Add EngineCoreOutput for this Request.
-                outputs.append(
-                    EngineCoreOutput(
-                        request_id=req_id,
-                        new_token_ids=new_token_ids,
-                        finish_reason=request.get_finished_reason(),
-                        new_logprobs=new_logprobs,
-                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
-                        stop_reason=request.stop_reason,
-                        events=request.take_events()))
-            else:
-                # Invariant: EngineCore returns no partial prefill outputs.
-                assert not prompt_logprobs_tensors
-
             self.scheduled_req_ids.remove(req_id)
-            if not stopped:
-                new_running.append(request)
-
-        # Return the cached request data to the queue so they can be reused.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            # NOTE(rob): since we free stopped reqs above, adding stopped reqs
-            # to _cached_reqs_data will cause a memory leak.
-            if req_data.req_id not in self.finished_req_ids:
-                self._cached_reqs_data[req_data.req_id].append(req_data)
-
-        self.running = new_running
-        engine_core_outputs = EngineCoreOutputs(
-            outputs=outputs,
-            scheduler_stats=self.make_stats(spec_decoding_stats),
-        )
-        if self.include_finished_set:
-            # TODO currently sending duplicates here, improve this
-            engine_core_outputs.finished_requests = (
-                scheduler_output.finished_req_ids | self.finished_req_ids)

-        return engine_core_outputs
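+        # The base implementation now covers everything removed above:
+        # token appending, stop checks, spec-decode accounting, logprob
+        # extraction, and EngineCoreOutputs construction.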
+        return super().update_from_output(scheduler_output,
+                                          model_runner_output)