|
@@ -10,6 +10,7 @@
 import torch.nn as nn
 
 from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
+from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.distributed.communication_op import (broadcast_tensor_dict,
                                                get_tp_group,
                                                tensor_model_parallel_gather)
@@ -27,6 +28,7 @@
                            HiddenStates, SequenceGroupMetadata,
                            get_all_seq_ids_and_request_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+from vllm.utils import GiB_bytes
 
 if current_platform.is_cuda_alike():
     from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
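For context: `CuMemAllocator` is the singleton allocator behind vLLM's sleep mode; memory allocated under a tag can later be offloaded to CPU or simply discarded. A minimal sketch of the tag flow this diff relies on, assuming the `get_instance` / `use_memory_pool` / `sleep` / `wake_up` API as used here (the tensor and its size are illustrative):

```python
import torch

from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool(tag="weights"):
    # Allocations made inside the pool are tracked under the "weights" tag.
    weights = torch.empty(1 << 20, device="cuda")

# Level-1-style sleep: offload "weights"-tagged memory to CPU; anything
# without a listed tag is discarded instead of offloaded.
allocator.sleep(offload_tags=("weights",))
allocator.wake_up()  # remap memory; offloaded contents are copied back
```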
|
@@ -388,6 +390,8 @@ def init_device(self) -> None:
             vocab_size=self._vocab_size)
 
         self._configure_model_sampler_for_spec_decode()
+        self._scorer_sleep_saved_buffer = {}
+        self._proposer_sleep_saved_buffer = {}
 
     def load_model(self, *args, **kwargs):
         pass
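These two dicts back the level-2 path added further down: module buffers (rotary caches, normalization statistics, etc.) are not part of the weights, so before a level-2 sleep they are snapshotted to CPU and copied back on wake-up. A self-contained sketch of that save/restore pattern on a toy module (the module and names here are illustrative, not from the diff):

```python
import torch.nn as nn

model = nn.BatchNorm1d(4)  # owns buffers: running_mean, running_var, ...
saved = {name: buf.cpu().clone() for name, buf in model.named_buffers()}

# ... buffer storage gets discarded or clobbered, as in a level-2 sleep ...

for name, buf in model.named_buffers():
    if name in saved:
        buf.data.copy_(saved[name].data)  # in-place restore of the snapshot
```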
|
@@ -1290,6 +1294,52 @@ def stop_profile(self):
         if isinstance(self.scorer_worker, WorkerBase):
             self.scorer_worker.stop_profile()
 
+    def sleep(self, level: int = 1) -> None:
+        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
+
+        # Save the buffers before level 2 sleep; wake_up() restores them.
+        if level == 2:
+            scorer_model = self.scorer_worker.get_model().model
+            self._scorer_sleep_saved_buffer = {
+                name: buffer.cpu().clone()
+                for name, buffer in scorer_model.named_buffers()
+            }
+            proposer_model = self.proposer_worker.get_model().model
+            self._proposer_sleep_saved_buffer = {
+                name: buffer.cpu().clone()
+                for name, buffer in proposer_model.named_buffers()
+            }
+        allocator = CuMemAllocator.get_instance()
+        allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
+        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
+        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
+        used_bytes = total - free_bytes_after_sleep
+        assert freed_bytes >= 0, "Memory usage increased after sleeping."
+        logger.info(
+            "Sleep mode freed %.2f GiB memory, "
+            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
+            used_bytes / GiB_bytes)
+
+    def wake_up(self, tags: Optional[list[str]] = None) -> None:
+        allocator = CuMemAllocator.get_instance()
+        allocator.wake_up(tags=tags)
+
+        # Restore the buffers after level 2 sleep
+        if len(self._scorer_sleep_saved_buffer):
+            model = self.scorer_worker.get_model().model
+            for name, buffer in model.named_buffers():
+                if name in self._scorer_sleep_saved_buffer:
+                    buffer.data.copy_(
+                        self._scorer_sleep_saved_buffer[name].data)
+            self._scorer_sleep_saved_buffer = {}
+        if len(self._proposer_sleep_saved_buffer):
+            model = self.proposer_worker.get_model().model
+            for name, buffer in model.named_buffers():
+                if name in self._proposer_sleep_saved_buffer:
+                    buffer.data.copy_(
+                        self._proposer_sleep_saved_buffer[name].data)
+            self._proposer_sleep_saved_buffer = {}
+
 
 def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                   proposer_cache_block_size_bytes: int,
|