Commit 9206b3d

[V1][PP] Run engine busy loop with batch queue (#13064)
1 parent ed0de3e · commit 9206b3d

File tree

6 files changed: +299 -15 lines changed
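
The change replaces the strictly serial schedule -> execute -> update cycle of step() with a busy loop that can keep several scheduled batches in flight at once, which is what pipeline parallelism needs to avoid pipeline bubbles. The sketch below is illustrative only: schedule, execute_async, update, and has_work are hypothetical stand-ins, not vLLM APIs. It shows the queue-of-futures pattern this commit adds to EngineCore.

import queue
from concurrent.futures import Future
from typing import Callable, Optional, Tuple


def run_with_batch_queue(schedule: Callable[[], Optional[object]],
                         execute_async: Callable[[object], "Future[object]"],
                         update: Callable[[object, object], None],
                         has_work: Callable[[], bool],
                         max_concurrent_batches: int = 2) -> None:
    # Keep up to `max_concurrent_batches` scheduled batches in flight.
    batch_queue: "queue.Queue[Tuple[Future, object]]" = queue.Queue(
        max_concurrent_batches)
    while has_work() or not batch_queue.empty():
        batch = None
        if has_work() and not batch_queue.full():
            # Non-blocking: hand a new batch to the executor and move on.
            batch = schedule()
            if batch is not None:
                batch_queue.put_nowait((execute_async(batch), batch))
        if batch is None and not batch_queue.empty():
            # Queue is full or nothing is schedulable: block on the oldest
            # in-flight batch and feed its output back to the scheduler.
            future, finished = batch_queue.get()
            update(finished, future.result())

The real implementation in the vllm/v1/engine/core.py diff below additionally handles empty schedules and a polling timeout.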

tests/v1/core/test_scheduler.py

Lines changed: 51 additions & 0 deletions
@@ -213,3 +213,54 @@ def test_schedule_partial_requests():
     assert output.num_scheduled_tokens[requests[0].request_id] == 1
     assert output.num_scheduled_tokens[requests[1].request_id] == 700
     assert requests[2].request_id not in output.num_scheduled_tokens
+
+
+def test_schedule_concurrent_batches():
+    scheduler = create_scheduler(
+        max_num_batched_tokens=1024,
+        max_num_seqs=2,
+    )
+    requests = create_requests(
+        num_requests=2,
+        num_tokens=512,
+    )
+
+    # Schedule the first request.
+    scheduler.add_request(requests[0])
+    scheduler_output0 = scheduler.schedule()
+    assert len(scheduler_output0.scheduled_new_reqs) == 1
+    assert scheduler_output0.num_scheduled_tokens[
+        requests[0].request_id] == 512
+
+    # The first request is still running, so only schedule the second request.
+    scheduler.add_request(requests[1])
+    scheduler_output1 = scheduler.schedule()
+    assert len(scheduler_output1.scheduled_new_reqs) == 1
+    assert scheduler_output1.num_scheduled_tokens[
+        requests[1].request_id] == 512
+
+    # Model output of the first request.
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[requests[0].request_id],
+        req_id_to_index={requests[0].request_id: 0},
+        sampled_token_ids=[0],
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(scheduler_output0, model_runner_output)
+
+    # Schedule the next step.
+    # The first request can be scheduled again while the second
+    # request is still running.
+    scheduler_output2 = scheduler.schedule()
+    assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1
+
+    # Model output of the second request.
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[requests[1].request_id],
+        req_id_to_index={requests[1].request_id: 0},
+        sampled_token_ids=[0],
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(scheduler_output1, model_runner_output)

tests/v1/engine/test_engine_core.py

Lines changed: 88 additions & 1 deletion
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0

+import copy
+import threading
 import time
 import uuid
+from concurrent.futures import Future

 import pytest
 from transformers import AutoTokenizer
@@ -12,7 +15,9 @@
 from vllm.platforms import current_platform
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
-from vllm.v1.executor.abstract import Executor
+from vllm.v1.executor.abstract import Executor, UniProcExecutor
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import ModelRunnerOutput

 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -191,3 +196,85 @@ def _check_engine_state():
     )
     engine_core.add_request(request2)
     _check_engine_state()
+
+
+@fork_new_process_for_each_test
+def test_engine_core_concurrent_batches(monkeypatch):
+    """
+    Test that the engine can handle multiple concurrent batches.
+    """
+
+    def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:
+        request = make_request()
+        request.sampling_params.max_tokens = max_tokens
+        return request
+
+    class DummyExecutor(UniProcExecutor):
+
+        def initialize(self, kv_cache_config: KVCacheConfig) -> None:
+            super().initialize(kv_cache_config)
+
+            # This executor actually can only run 1 batch at a time
+            self.semaphore = threading.Semaphore(1)
+
+        def execute_model(
+            self,
+            scheduler_output,
+        ) -> Future[ModelRunnerOutput]:
+            """Make execute_model non-blocking."""
+            future: Future[ModelRunnerOutput] = Future()
+
+            def _thread_wrapper(scheduler_output, future):
+                with self.semaphore:
+                    output = self.collective_rpc("execute_model",
+                                                 args=(scheduler_output, ))
+                    # Make a copy because output[0] may be reused
+                    # by the next batch.
+                    output = copy.deepcopy(output[0])
+                    future.set_result(output)
+
+            threading.Thread(target=_thread_wrapper,
+                             args=(scheduler_output, future)).start()
+            return future
+
+        @property
+        def max_concurrent_batches(self) -> int:
+            return 2
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine_args = EngineArgs(
+            model=MODEL_NAME,
+            # To test concurrent batches.
+            max_num_seqs=2,
+            # Avoid all requests being scheduled once.
+            enable_prefix_caching=False,
+            max_num_batched_tokens=10,
+        )
+        vllm_config = engine_args.create_engine_config()
+        engine_core = EngineCore(vllm_config=vllm_config,
+                                 log_stats=False,
+                                 executor_class=DummyExecutor)
+        assert engine_core.batch_queue is not None
+
+        # Add two requests in a row.
+        req = make_request_with_max_tokens(5)
+        engine_core.add_request(req)
+        req = make_request_with_max_tokens(5)
+        engine_core.add_request(req)
+
+        # First saturate the batch queue.
+        assert engine_core.step_with_batch_queue() is None
+        assert engine_core.batch_queue.qsize() == 1
+        assert engine_core.step_with_batch_queue() is None
+        assert engine_core.batch_queue.qsize() == 2
+        assert engine_core.scheduler.get_num_unfinished_requests() == 2
+
+        # Loop through both requests.
+        while engine_core.scheduler.get_num_unfinished_requests() == 2:
+            engine_core.step_with_batch_queue()
+
+        # Reaching here when got the result of the first request.
+        while engine_core.scheduler.get_num_unfinished_requests() == 1:
+            engine_core.step_with_batch_queue()
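
The DummyExecutor above makes execute_model non-blocking by completing a Future from a worker thread, while a semaphore enforces that only one batch actually runs at a time. A generic, standalone version of that wrap-a-blocking-call pattern might look like the sketch below (illustrative only; blocking_call is a placeholder, not a vLLM API):

import threading
from concurrent.futures import Future
from typing import Callable, TypeVar

T = TypeVar("T")


def submit_non_blocking(blocking_call: Callable[[], T]) -> "Future[T]":
    # Run the blocking call in a background thread; callers get a Future
    # they can poll or wait on, which is what a batch queue consumes.
    future: "Future[T]" = Future()

    def _worker() -> None:
        try:
            future.set_result(blocking_call())
        except Exception as exc:
            # Surface failures through the Future instead of losing them.
            future.set_exception(exc)

    threading.Thread(target=_worker, daemon=True).start()
    return future

A real pipeline-parallel executor would likely rely on its workers' own asynchronous RPC rather than a thread per batch, but the contract consumed by step_with_batch_queue is the same: execute_model returns a Future that resolves to a ModelRunnerOutput.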

vllm/v1/core/scheduler.py

Lines changed: 17 additions & 0 deletions
@@ -58,6 +58,9 @@ def __init__(
         # Priority queues for requests.
         self.waiting: Deque[Request] = deque()
         self.running: List[Request] = []
+        # The requests that have been scheduled and are being executed
+        # by the executor.
+        self.scheduled_req_ids: Set[str] = set()

         # The request IDs that are finished in between the previous and the
         # current steps. This is used to notify the workers about the finished
@@ -118,6 +121,11 @@ def schedule(self) -> "SchedulerOutput":
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
+            if request.request_id in self.scheduled_req_ids:
+                # This request has already been scheduled.
+                req_index += 1
+                continue
+
             num_new_tokens = request.num_tokens - request.num_computed_tokens
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
@@ -164,6 +172,7 @@

             # Schedule the request.
             scheduled_running_reqs.append(request)
+            self.scheduled_req_ids.add(request.request_id)
             req_to_new_block_ids[request.request_id] = [
                 b.block_id for b in new_blocks
             ]
@@ -251,6 +260,7 @@

                 self.waiting.popleft()
                 self.running.append(request)
+                self.scheduled_req_ids.add(request.request_id)
                 if request.status == RequestStatus.WAITING:
                     scheduled_new_reqs.append(request)
                     self.request_scheduled(request, scheduled_timestamp)
@@ -519,6 +529,7 @@ def update_from_output(
                     stop_reason=request.stop_reason,
                     events=request.take_events()))

+            self.scheduled_req_ids.remove(request.request_id)
             if not stopped:
                 new_running.append(request)

@@ -575,6 +586,8 @@ def finish_requests(

             if request.status == RequestStatus.RUNNING:
                 self.running.remove(request)
+                if request.request_id in self.scheduled_req_ids:
+                    self.scheduled_req_ids.remove(request.request_id)
             else:
                 self.waiting.remove(request)
             request.status = finished_status
@@ -595,6 +608,10 @@ def get_num_unfinished_requests(self) -> int:
     def has_unfinished_requests(self) -> bool:
         return self.get_num_unfinished_requests() > 0

+    def get_num_unscheduled_requests(self) -> int:
+        """Number of requests that are not being processed by the executor."""
+        return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
+
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()
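
The scheduler changes amount to one invariant: a request sits in scheduled_req_ids from the moment a batch containing it is handed to the executor until update_from_output processes that batch, so it is never packed into two in-flight batches, and get_num_unscheduled_requests() reports how many unfinished requests are still eligible for scheduling. The toy model below (not the vLLM classes) loosely mirrors that bookkeeping and the assertions of test_schedule_concurrent_batches:

from typing import Set


class InFlightTracker:
    """Toy model of the scheduled_req_ids bookkeeping added above."""

    def __init__(self) -> None:
        self.unfinished: Set[str] = set()
        self.scheduled: Set[str] = set()  # in flight on the executor

    def add_request(self, req_id: str) -> None:
        self.unfinished.add(req_id)

    def schedule(self, req_id: str) -> None:
        # A request already in flight must not enter a second batch.
        assert req_id not in self.scheduled
        self.scheduled.add(req_id)

    def update_from_output(self, req_id: str, finished: bool) -> None:
        self.scheduled.remove(req_id)  # eligible for scheduling again
        if finished:
            self.unfinished.discard(req_id)

    def num_unscheduled(self) -> int:
        return len(self.unfinished) - len(self.scheduled)


# Two requests, each in flight once; finishing a batch frees its request.
tracker = InFlightTracker()
tracker.add_request("req-0")
tracker.add_request("req-1")
tracker.schedule("req-0")
tracker.schedule("req-1")
assert tracker.num_unscheduled() == 0      # nothing left to schedule
tracker.update_from_output("req-0", finished=False)
assert tracker.num_unscheduled() == 1      # req-0 can be scheduled again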

vllm/v1/engine/core.py

Lines changed: 73 additions & 6 deletions
@@ -4,8 +4,9 @@
 import signal
 import threading
 import time
+from concurrent.futures import Future
 from multiprocessing.connection import Connection
-from typing import Any, List, Tuple, Type
+from typing import Any, List, Optional, Tuple, Type

 import psutil
 import zmq
@@ -18,11 +19,12 @@
     maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
-from vllm.v1.core.scheduler import Scheduler
+from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType)
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.version import __version__ as VLLM_VERSION
@@ -66,9 +68,22 @@ def __init__(
             log_stats=self.log_stats,
         )

+        # Setup MM Input Mapper.
         self.mm_input_cache_server = MMInputCacheServer(
             vllm_config.model_config)

+        # Setup batch queue for pipeline parallelism.
+        # Batch queue for scheduled batches. This enables us to asynchronously
+        # schedule and execute batches, and is required by pipeline parallelism
+        # to eliminate pipeline bubbles.
+        self.batch_queue_size = self.model_executor.max_concurrent_batches
+        self.batch_queue: Optional[queue.Queue[Tuple[Future[ModelRunnerOutput],
+                                                     SchedulerOutput]]] = None
+        if self.batch_queue_size > 1:
+            logger.info("Batch queue is enabled with size %d",
+                        self.batch_queue_size)
+            self.batch_queue = queue.Queue(self.batch_queue_size)
+
     def _initialize_kv_caches(self,
                               vllm_config: VllmConfig) -> Tuple[int, int]:
         start = time.time()
@@ -135,7 +150,55 @@ def step(self) -> EngineCoreOutputs:
         scheduler_output = self.scheduler.schedule()
         output = self.model_executor.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, output)
+            scheduler_output, output)  # type: ignore
+        return engine_core_outputs
+
+    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
+        """Schedule and execute batches with the batch queue.
+        Note that if nothing to output in this step, None is returned.
+
+        The execution flow is as follows:
+        1. Try to schedule a new batch if there are unscheduled requests
+        and the job queue is not full. If a new batch is scheduled, directly
+        return an empty engine core output. In other words, we won't check
+        and return model outputs before the batch queue is full.
+        2. If there is no new scheduled batch, meaning that the batch queue
+        is full or no other requests can be scheduled, we block until the first
+        batch in the job queue is finished.
+        3. Update the scheduler from the output.
+        """
+        assert self.batch_queue is not None
+
+        engine_core_outputs = None
+        scheduler_output = None
+        # If there are unscheduled requests and the job queue
+        # is not full, schedule a new batch. Note that this is not blocking.
+        if (self.scheduler.get_num_unscheduled_requests() > 0
+                and not self.batch_queue.full()):
+            scheduler_output = self.scheduler.schedule()
+            if scheduler_output.total_num_scheduled_tokens > 0:
+                future = self.model_executor.execute_model(scheduler_output)
+                self.batch_queue.put_nowait(
+                    (future, scheduler_output))  # type: ignore
+
+        # If all requests are scheduled or the job queue is full,
+        # block until the first batch in the job queue is finished.
+        if (scheduler_output is None
+                or scheduler_output.total_num_scheduled_tokens == 0):
+            try:
+                future, scheduler_output = self.batch_queue.get(
+                    timeout=POLLING_TIMEOUT_S)
+                # Blocking until the first result is available.
+                model_output = future.result()
+                self.batch_queue.task_done()
+                engine_core_outputs = self.scheduler.update_from_output(
+                    scheduler_output, model_output)
+            except queue.Empty:
+                # If the queue is empty (timeout at .get), return
+                # an empty EngineCoreOutputs for logging.
+                engine_core_outputs = EngineCoreOutputs(
+                    outputs=[], scheduler_stats=self.scheduler.make_stats())
+
         return engine_core_outputs

     def shutdown(self):
@@ -226,6 +289,9 @@ def signal_handler(signum, frame):
     def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

+        step_fn = (self.step
+                   if self.batch_queue is None else self.step_with_batch_queue)
+
        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
@@ -249,10 +315,11 @@ def run_busy_loop(self):
                    self._handle_client_request(*req)

            # 3) Step the engine core.
-            outputs = self.step()
+            outputs = step_fn()

-            # 5) Put EngineCoreOutputs into the output queue.
-            self.output_queue.put_nowait(outputs)
+            # 4) Put EngineCoreOutputs into the output queue.
+            if outputs is not None:
+                self.output_queue.put_nowait(outputs)

     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
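
Putting the pieces together, run_busy_loop picks step_with_batch_queue whenever batch_queue exists and simply skips publishing when the step returns None (a batch was only enqueued). A hedged sketch of an equivalent driver loop follows, assuming an engine_core built with an executor whose max_concurrent_batches is greater than 1 and a caller-supplied handle_outputs callback:

def drain_engine(engine_core, handle_outputs) -> None:
    # Pick the same step function run_busy_loop would use.
    step_fn = (engine_core.step if engine_core.batch_queue is None
               else engine_core.step_with_batch_queue)
    while engine_core.scheduler.get_num_unfinished_requests() > 0:
        outputs = step_fn()
        # step_with_batch_queue returns None when it only enqueued a batch.
        if outputs is not None:
            handle_outputs(outputs)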
