# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
#

- import gc
- from typing import Dict, List, Optional
+ from typing import Optional

import torch
import torch.nn as nn

from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
- from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
from vllm.v1.core.sched.output import SchedulerOutput
- from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                         KVCacheSpec)
+ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
- from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.worker_base import WorkerBase

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import init_ascend_config
+ from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import try_register_lib
@@ -95,10 +93,22 @@ def __init__(
        self.profiler = self._init_profiler()

    def sleep(self, level: int = 1) -> None:
-         logger.error("Sleep mode is only supported on v0")
+         NPUPlatform.set_device(self.device)
+         free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
+         allocator = CaMemAllocator.get_instance()
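+         # Level 1 copies the "weights" pool to host memory before releasing
+         # device memory; the KV cache is discarded. Any other level passes no
+         # offload tags, so weights and KV cache are both discarded.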
+         allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
+         free_bytes_after_sleep, total = NPUPlatform.mem_get_info()
+         freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
+         used_bytes = total - free_bytes_after_sleep
+         assert freed_bytes >= 0, "Memory usage increased after sleeping."
+         logger.info(
+             "Sleep mode freed %.2f GiB memory, "
+             "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
+             used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
-         logger.error("Sleep mode is only supported on v0")
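+         # A None value wakes up every memory pool ("weights" and "kv_cache");
+         # passing a subset of tags wakes up only those pools.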
+         allocator = CaMemAllocator.get_instance()
+         allocator.wake_up(tags=tags)

    def init_device(self):
        if self.device_config.device.type == "npu":
@@ -119,58 +129,42 @@ def init_device(self):
        self.model_runner = NPUModelRunner(self.vllm_config, self.device)

    def determine_available_memory(self) -> int:
-         kv_caches: Dict[str, torch.Tensor] = {}
-         kv_cache_spec = self.model_runner.get_kv_cache_spec()
-         for layer_name, layer_spec in kv_cache_spec.items():
-             if isinstance(layer_spec, FullAttentionSpec):
-                 # Use an empty tensor instead of `None` to force Dynamo to pass
-                 # it by reference, rather than by specializing on the value `None`.
-                 npu_k_cache = torch.tensor([],
-                                            dtype=layer_spec.dtype,
-                                            device=self.device)
-                 npu_v_cache = torch.tensor([],
-                                            dtype=layer_spec.dtype,
-                                            device=self.device)
-                 kv_caches[layer_name] = (npu_k_cache, npu_v_cache)
-             else:
-                 raise NotImplementedError
-
-         runner_kv_caches: List[torch.Tensor] = []
-         bind_kv_cache(
-             kv_caches,
-             self.vllm_config.compilation_config.static_forward_context,
-             runner_kv_caches)
-
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
-         NPUPlatform.empty_cache()
+         NPUPlatform.clear_npu_memory()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
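+         # Snapshot the total device memory before the dummy run; the
+         # gpu_memory_utilization budget at the end is computed against it.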
+         _, total_npu_memory = NPUPlatform.mem_get_info()
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
-         free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info()
+         free_npu_memory, _ = NPUPlatform.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
-         peak_memory = self.init_npu_memory - free_npu_memory
-         assert peak_memory > 0, (
+         assert self.init_npu_memory > free_npu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_npu_memory}, current free memory"
            f" {free_npu_memory}. This happens when the NPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

-         gc.collect()
+         # Get the peak memory allocation recorded by torch
+         peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
        # TODO: this function no longer needs its own implementation once
        # empty_cache in Worker.determine_num_available_blocks() is unified.
        NPUPlatform.empty_cache()
-         usable_memory_size = total_npu_memory * self.cache_config.gpu_memory_utilization - peak_memory
-         npu_kv_cache_bytes = max(usable_memory_size, 0)
-         logger.info(
-             f"Available memory: {usable_memory_size}, total memory: {total_npu_memory}"
-         )
-         return int(npu_kv_cache_bytes)
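+         # Memory allocated outside the torch caching allocator (e.g. workspace
+         # buffers from communication or operator libraries) is invisible to
+         # memory_stats(), so add any such usage on top of the torch peak
+         # before budgeting the KV cache.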
+         torch_allocated_bytes = torch_npu.npu.memory_stats(
+         )["allocated_bytes.all.current"]
+         total_allocated_bytes = torch_npu.npu.mem_get_info(
+         )[1] - torch_npu.npu.mem_get_info()[0]
+         non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
+         if non_torch_allocations > 0:
+             peak_memory += non_torch_allocations
+         available_kv_cache_memory = (
+             total_npu_memory * self.cache_config.gpu_memory_utilization -
+             peak_memory)
+         return int(available_kv_cache_memory)

    def execute_model(
        self,
@@ -180,7 +174,17 @@ def execute_model(
        return output if self.is_driver_worker else None

    def load_model(self) -> None:
-         self.model_runner.load_model()
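+         # When sleep mode is enabled, weights are allocated from a
+         # CaMemAllocator pool tagged "weights" so that sleep(level=1) can
+         # offload them to host memory and wake_up() can bring them back.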
+         if self.vllm_config.model_config.enable_sleep_mode:
+             allocator = CaMemAllocator.get_instance()
+             assert allocator.get_current_usage() == 0, (
+                 "Sleep mode can only be "
+                 "used for one instance per process.")
+             context = allocator.use_memory_pool(tag="weights")
+         else:
+             from contextlib import nullcontext
+             context = nullcontext()  # type: ignore
+         with context:
+             self.model_runner.load_model()

    def compile_or_warm_up_model(self) -> None:
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
@@ -206,12 +210,14 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate NPU KV cache with the specified kv_cache_config."""
-         self.model_runner.initialize_kv_cache(kv_cache_config)
-
-     def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
-         """Allocate GPU KV cache with the specified kv_cache_config."""
-         kv_cache_config = kv_cache_configs[self.rank]
-         self.model_runner.initialize_kv_cache(kv_cache_config)
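+         # Mirrors load_model(): with sleep mode enabled, the KV cache is
+         # allocated from a pool tagged "kv_cache", which sleep() discards
+         # (never offloads) and wake_up() re-allocates.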
+         if self.vllm_config.model_config.enable_sleep_mode:
+             allocator = CaMemAllocator.get_instance()
+             context = allocator.use_memory_pool(tag="kv_cache")
+         else:
+             from contextlib import nullcontext
+             context = nullcontext()  # type: ignore
+         with context:
+             self.model_runner.initialize_kv_cache(kv_cache_config)

    def profile(self, is_start: bool = True):
        if self.profiler is None:
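Usage sketch (not shown in this diff): assuming vllm-ascend wires sleep mode
through the standard vLLM entrypoint the same way upstream does, the path added
above would be exercised roughly as follows; the enable_sleep_mode flag and the
level argument mirror upstream vLLM's API and are an assumption here.

    from vllm import LLM

    llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)
    llm.generate(["Hello, world"])  # weights live in the "weights" pool
    llm.sleep(level=1)              # offload weights to host, discard KV cache
    llm.wake_up()                   # re-allocate pools and restore the weights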