From 377e8d2c2d0fbe177e8fb64429d5d620c5c670d8 Mon Sep 17 00:00:00 2001
From: Kebe
Date: Thu, 3 Jul 2025 15:45:20 +0800
Subject: [PATCH] [Core] Implement sleep/wake_up for SpecDecodeWorker

Signed-off-by: Kebe
---
 vllm/spec_decode/spec_decode_worker.py | 50 ++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 7dda1cbfe23..3cd42cff541 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -10,6 +10,7 @@
 import torch.nn as nn
 
 from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
+from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.distributed.communication_op import (broadcast_tensor_dict,
                                                get_tp_group,
                                                tensor_model_parallel_gather)
@@ -27,6 +28,7 @@
                            HiddenStates, SequenceGroupMetadata,
                            get_all_seq_ids_and_request_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+from vllm.utils import GiB_bytes
 
 if current_platform.is_cuda_alike():
     from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
@@ -388,6 +390,8 @@ def init_device(self) -> None:
                                  vocab_size=self._vocab_size)
 
         self._configure_model_sampler_for_spec_decode()
+        self._scorer_sleep_saved_buffer: Dict[str, torch.Tensor] = {}
+        self._proposer_sleep_saved_buffer: Dict[str, torch.Tensor] = {}
 
     def load_model(self, *args, **kwargs):
         pass
@@ -1290,6 +1294,52 @@ def stop_profile(self):
         if isinstance(self.scorer_worker, WorkerBase):
             self.scorer_worker.stop_profile()
 
+    def sleep(self, level=1):
+        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
+
+        # Save the buffers before level 2 sleep
+        if level == 2:
+            scorer_model = self.scorer_worker.get_model().model
+            self._scorer_sleep_saved_buffer = {
+                name: buffer.cpu().clone()
+                for name, buffer in scorer_model.named_buffers()
+            }
+            proposer_model = self.proposer_worker.get_model().model
+            self._proposer_sleep_saved_buffer = {
+                name: buffer.cpu().clone()
+                for name, buffer in proposer_model.named_buffers()
+            }
+        allocator = CuMemAllocator.get_instance()
+        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
+        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
+        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
+        used_bytes = total - free_bytes_after_sleep
+        assert freed_bytes >= 0, "Memory usage increased after sleeping."
+        logger.info(
+            "Sleep mode freed %.2f GiB memory, "
+            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
+            used_bytes / GiB_bytes)
+
+    def wake_up(self, tags: Optional[list[str]] = None) -> None:
+        allocator = CuMemAllocator.get_instance()
+        allocator.wake_up(tags=tags)
+
+        # Restore the buffers after level 2 sleep
+        if len(self._scorer_sleep_saved_buffer):
+            model = self.scorer_worker.get_model().model
+            for name, buffer in model.named_buffers():
+                if name in self._scorer_sleep_saved_buffer:
+                    buffer.data.copy_(
+                        self._scorer_sleep_saved_buffer[name].data)
+            self._scorer_sleep_saved_buffer = {}
+        if len(self._proposer_sleep_saved_buffer):
+            model = self.proposer_worker.get_model().model
+            for name, buffer in model.named_buffers():
+                if name in self._proposer_sleep_saved_buffer:
+                    buffer.data.copy_(
+                        self._proposer_sleep_saved_buffer[name].data)
+            self._proposer_sleep_saved_buffer = {}
+
 
 def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                   proposer_cache_block_size_bytes: int,
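-- 
Reviewer note (not part of the patch): a minimal usage sketch of how these
worker hooks are reached through vLLM's public sleep-mode API.
`enable_sleep_mode`, `LLM.sleep()`, and `LLM.wake_up()` already exist in
vLLM; the model names and speculative-decoding kwargs below are
illustrative placeholders and may differ across vLLM versions.

    from vllm import LLM, SamplingParams

    # enable_sleep_mode routes sleep()/wake_up() down to the worker; with
    # speculative decoding configured, that worker is SpecDecodeWorker.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder target model
        speculative_config={
            "model": "meta-llama/Llama-3.2-1B-Instruct",  # placeholder draft
            "num_speculative_tokens": 5,
        },
        enable_sleep_mode=True,
    )
    llm.generate(["Hello"], SamplingParams(max_tokens=32))

    # Level 1 offloads tensors tagged "weights" to CPU and discards the
    # rest; level 2 discards the weights too, which is why the patch
    # snapshots the named buffers of both models before sleeping.
    llm.sleep(level=2)
    # ... GPU memory is available to another workload here ...
    llm.wake_up()  # buffers are restored from the saved CPU copies

    llm.generate(["Hello again"], SamplingParams(max_tokens=32))

The buffer snapshot mirrors what the plain GPU worker does for level 2:
CuMemAllocator only re-allocates the pools on wake_up, so non-parameter
state (e.g. rotary-embedding caches) must be copied back by hand, here for
both the scorer and the proposer model.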