[Core] Implement sleep/wake_up for SpecDecodeWorker #20422

Open · wants to merge 1 commit into base: main
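
As a usage illustration (not part of this diff): with this change, vLLM's sleep mode should also work when speculative decoding is enabled. Below is a minimal sketch assuming the standard `enable_sleep_mode` entry point; the model names and speculative settings are illustrative placeholders, not values from this PR.

```python
# Hypothetical usage sketch; model names and draft settings are placeholders.
from vllm import LLM

llm = LLM(
    model="facebook/opt-6.7b",
    speculative_config={
        "model": "facebook/opt-125m",
        "num_speculative_tokens": 5,
    },
    enable_sleep_mode=True,
)

llm.sleep(level=1)   # level 1: offload weights to CPU, drop the KV cache
# ... GPU memory is available to other workloads here ...
llm.wake_up()        # restore weights and re-allocate the KV cache
```
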
50 changes: 50 additions & 0 deletions vllm/spec_decode/spec_decode_worker.py
@@ -10,6 +10,7 @@
import torch.nn as nn

from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed.communication_op import (broadcast_tensor_dict,
get_tp_group,
tensor_model_parallel_gather)
@@ -27,6 +28,7 @@
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.utils import GiB_bytes

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
@@ -388,6 +390,8 @@ def init_device(self) -> None:
vocab_size=self._vocab_size)

self._configure_model_sampler_for_spec_decode()
self._scorer_sleep_saved_buffer: Dict[str, torch.Tensor] = {}
self._proposer_sleep_saved_buffer: Dict[str, torch.Tensor] = {}

def load_model(self, *args, **kwargs):
pass
@@ -1290,6 +1294,52 @@ def stop_profile(self):
if isinstance(self.scorer_worker, WorkerBase):
self.scorer_worker.stop_profile()

def sleep(self, level=1):
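        # Level 1 offloads tensors tagged "weights" to CPU and drops the
        # KV cache; level 2 additionally discards the weights themselves,
        # so module buffers must be snapshotted here and restored in
        # wake_up().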
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

# Save the buffers before level 2 sleep
if level == 2:
scorer_model = self.scorer_worker.get_model().model
self._scorer_sleep_saved_buffer = {
name: buffer.cpu().clone()
for name, buffer in scorer_model.named_buffers()
}
proposer_model = self.proposer_worker.get_model().model
self._proposer_sleep_saved_buffer = {
name: buffer.cpu().clone()
for name, buffer in proposer_model.named_buffers()
}
allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping."
logger.info(
"Sleep mode freed %.2f GiB memory, "
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes)

def wake_up(self, tags: Optional[list[str]] = None) -> None:
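        # Re-materialize allocator-managed memory (optionally only the given
        # tags), then refill any buffers saved before a level-2 sleep.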
allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags=tags)

# Restore the buffers after level 2 sleep
if len(self._scorer_sleep_saved_buffer):
model = self.scorer_worker.get_model().model
for name, buffer in model.named_buffers():
if name in self._scorer_sleep_saved_buffer:
buffer.data.copy_(
self._scorer_sleep_saved_buffer[name].data)
self._scorer_sleep_saved_buffer = {}
if len(self._proposer_sleep_saved_buffer):
model = self.proposer_worker.get_model().model
for name, buffer in model.named_buffers():
if name in self._proposer_sleep_saved_buffer:
buffer.data.copy_(
self._proposer_sleep_saved_buffer[name].data)
self._proposer_sleep_saved_buffer = {}


def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
proposer_cache_block_size_bytes: int,
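
The manual buffer handling in the level-2 path exists because `CuMemAllocator` only manages memory it allocated under its tags: at level 2 the weights are discarded rather than offloaded, so registered module buffers (e.g. rotary-embedding caches computed at load time) would come back uninitialized on wake-up. A minimal sketch of the save/restore pattern applied to both the scorer and the proposer, using only standard `torch.nn.Module` APIs:

```python
import torch
import torch.nn as nn

def save_buffers_to_cpu(model: nn.Module) -> dict[str, torch.Tensor]:
    # Snapshot every registered buffer to host memory before the
    # allocator releases the GPU pages backing it.
    return {name: buf.cpu().clone() for name, buf in model.named_buffers()}

def restore_buffers(model: nn.Module, saved: dict[str, torch.Tensor]) -> None:
    # After wake_up re-materializes the GPU tensors, copy the saved
    # host snapshots back in place.
    for name, buf in model.named_buffers():
        if name in saved:
            buf.data.copy_(saved[name])
```
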