Skip to content

Commit 5627cae

Browse files
committed
[Core] Implement sleep/wake_up for SpecDecodeWorker
Signed-off-by: Kebe <mail@kebe7jun.com>
1 parent 359200f commit 5627cae

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

vllm/spec_decode/spec_decode_worker.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch.nn as nn
1111

1212
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
13+
from vllm.device_allocator.cumem import CuMemAllocator
1314
from vllm.distributed.communication_op import (broadcast_tensor_dict,
1415
get_tp_group,
1516
tensor_model_parallel_gather)
@@ -27,6 +28,7 @@
2728
HiddenStates, SequenceGroupMetadata,
2829
get_all_seq_ids_and_request_ids)
2930
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
31+
from vllm.utils import GiB_bytes
3032

3133
if current_platform.is_cuda_alike():
3234
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
@@ -388,6 +390,8 @@ def init_device(self) -> None:
388390
vocab_size=self._vocab_size)
389391

390392
self._configure_model_sampler_for_spec_decode()
393+
self._scorer_sleep_saved_buffer = {}
394+
self._proposer_sleep_saved_buffer = {}
391395

392396
def load_model(self, *args, **kwargs):
    """Deliberate no-op: this worker does not load a model itself
    (presumably the scorer/proposer sub-workers handle their own model
    loading during initialization -- TODO confirm against init_device)."""
@@ -1290,6 +1294,52 @@ def stop_profile(self):
12901294
if isinstance(self.scorer_worker, WorkerBase):
12911295
self.scorer_worker.stop_profile()
12921296

1297+
def sleep(self, level: int = 1) -> None:
    """Release GPU memory held by this worker via the CuMemAllocator.

    Args:
        level: Sleep level. At level 1 the model weights are offloaded
            to CPU (tag ``"weights"``); at level 2 everything tracked by
            the allocator is discarded, so the model buffers -- which the
            allocator does not manage -- are snapshotted to CPU first and
            restored later by ``wake_up()``.
    """
    free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

    # Level 2 drops the weights entirely, so save the (untracked)
    # module buffers of both models before the allocator releases memory.
    if level == 2:
        self._scorer_sleep_saved_buffer = self._cpu_buffer_snapshot(
            self.scorer_worker.get_model().model)
        self._proposer_sleep_saved_buffer = self._cpu_buffer_snapshot(
            self.proposer_worker.get_model().model)

    allocator = CuMemAllocator.get_instance()
    # Level 1: offload weight allocations to CPU; level 2: free everything.
    allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
    free_bytes_after_sleep, total = torch.cuda.mem_get_info()
    freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
    used_bytes = total - free_bytes_after_sleep
    assert freed_bytes >= 0, "Memory usage increased after sleeping."
    logger.info(
        "Sleep mode freed %.2f GiB memory, "
        "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
        used_bytes / GiB_bytes)

@staticmethod
def _cpu_buffer_snapshot(model) -> dict:
    """Clone every named buffer of ``model`` into a CPU-resident dict."""
    return {
        name: buffer.cpu().clone()
        for name, buffer in model.named_buffers()
    }
1322+
1323+
def wake_up(self, tags: Optional[list[str]] = None) -> None:
    """Re-allocate GPU memory released by ``sleep()``.

    Args:
        tags: Optional allocator tags to wake up selectively; ``None``
            wakes everything.

    After a level-2 sleep, the CPU snapshots of the scorer/proposer
    module buffers taken in ``sleep()`` are copied back into the models
    and the snapshots are dropped. After a level-1 sleep both snapshot
    dicts are empty, so this is a no-op beyond the allocator call.
    """
    allocator = CuMemAllocator.get_instance()
    allocator.wake_up(tags=tags)

    # Restore buffers saved for a level-2 sleep (empty dict => nothing
    # to do). NOTE(review): restoration runs even on a partial (tagged)
    # wake-up -- confirm that is intended for multi-step wake-up flows.
    if self._scorer_sleep_saved_buffer:
        self._restore_buffers(self.scorer_worker.get_model().model,
                              self._scorer_sleep_saved_buffer)
        self._scorer_sleep_saved_buffer = {}
    if self._proposer_sleep_saved_buffer:
        self._restore_buffers(self.proposer_worker.get_model().model,
                              self._proposer_sleep_saved_buffer)
        self._proposer_sleep_saved_buffer = {}

@staticmethod
def _restore_buffers(model, saved: dict) -> None:
    """Copy saved CPU buffers back into ``model``'s matching named buffers."""
    for name, buffer in model.named_buffers():
        if name in saved:
            buffer.data.copy_(saved[name].data)
1342+
12931343

12941344
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
12951345
proposer_cache_block_size_bytes: int,

0 commit comments

Comments
 (0)