Skip to content

Commit 59389c9

Browse files
authored
[BugFix][CPU] Fix CPU worker dependency on cumem_allocator (#20696)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 8f2720d commit 59389c9

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

vllm/v1/worker/gpu_worker.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111

1212
import vllm.envs as envs
1313
from vllm.config import VllmConfig
14-
from vllm.device_allocator.cumem import CuMemAllocator
1514
from vllm.distributed import (ensure_model_parallel_initialized,
1615
init_distributed_environment,
1716
set_custom_all_reduce)
@@ -79,6 +78,8 @@ def __init__(
7978
self.profiler = None
8079

8180
def sleep(self, level: int = 1) -> None:
81+
from vllm.device_allocator.cumem import CuMemAllocator
82+
8283
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
8384

8485
# Save the buffers before level 2 sleep
@@ -101,6 +102,8 @@ def sleep(self, level: int = 1) -> None:
101102
used_bytes / GiB_bytes)
102103

103104
def wake_up(self, tags: Optional[list[str]] = None) -> None:
105+
from vllm.device_allocator.cumem import CuMemAllocator
106+
104107
allocator = CuMemAllocator.get_instance()
105108
allocator.wake_up(tags)
106109

@@ -174,6 +177,8 @@ def init_device(self):
174177
# to hijack tensor allocation.
175178
def load_model(self) -> None:
176179
if self.vllm_config.model_config.enable_sleep_mode:
180+
from vllm.device_allocator.cumem import CuMemAllocator
181+
177182
allocator = CuMemAllocator.get_instance()
178183
assert allocator.get_current_usage() == 0, (
179184
"Sleep mode can only be "
@@ -241,7 +246,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
241246

242247
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
243248
"""Allocate GPU KV cache with the specified kv_cache_config."""
249+
244250
if self.vllm_config.model_config.enable_sleep_mode:
251+
from vllm.device_allocator.cumem import CuMemAllocator
252+
245253
allocator = CuMemAllocator.get_instance()
246254
context = allocator.use_memory_pool(tag="kv_cache")
247255
else:

0 commit comments

Comments (0)