
Commit a2552e1

[Worker][V1] Support sleep mode for v1 (#1084)
### What this PR does / why we need it?
Support sleep mode for v1.

Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent 0395ab3 commit a2552e1
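
For reference, a minimal sketch of how the new sleep mode is exercised from the user side, mirroring the pattern in tests/singlecard/test_camem.py. The model name and prompt below are illustrative assumptions, not part of this commit:

# Hypothetical end-to-end usage sketch (not part of the diff): assumes a
# vLLM build with vllm-ascend installed and an NPU device available.
from vllm import LLM, SamplingParams

llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)  # example model
sampling_params = SamplingParams(temperature=0, max_tokens=10)
output = llm.generate("How are you?", sampling_params)

# Level 1 offloads the weights to host memory and discards the KV cache,
# freeing most of the NPU memory while the engine is idle.
llm.sleep(level=1)

# wake_up() restores the weights and re-creates the KV-cache memory pool.
llm.wake_up()
output2 = llm.generate("How are you?", sampling_params)

# Sleeping and waking must not change the model's outputs.
assert output[0].outputs[0].text == output2[0].outputs[0].text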

File tree

5 files changed: +65 -60 lines


.github/workflows/vllm_ascend_test.yaml

Lines changed: 3 additions & 1 deletion
@@ -114,11 +114,13 @@ jobs:
           # pytest -sv tests/singlecard/test_guided_decoding.py.py
           # test_ascend_config.py should be ran separately because it will regenerate the global config many times.
           pytest -sv tests/singlecard/test_ascend_config.py
+          pytest -sv tests/singlecard/test_camem.py
           pytest -sv tests/singlecard/ \
             --ignore=tests/singlecard/test_offline_inference.py \
             --ignore=tests/singlecard/test_scheduler.py \
             --ignore=tests/singlecard/test_guided_decoding.py \
-            --ignore=tests/singlecard/test_ascend_config.py
+            --ignore=tests/singlecard/test_ascend_config.py \
+            --ignore=tests/singlecard/test_camem.py
         else
           pytest -sv tests/multicard/test_ilama_lora_tp2.py
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py

tests/singlecard/test_camem.py

Lines changed: 0 additions & 5 deletions
@@ -16,19 +16,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import os

-import pytest
 import torch
 from vllm import LLM, SamplingParams
 from vllm.utils import GiB_bytes

 from tests.utils import fork_new_process_for_each_test
 from vllm_ascend.device_allocator.camem import CaMemAllocator

-if os.getenv("VLLM_USE_V1") == "1":
-    pytest.skip("Skip in vllm v1", allow_module_level=True)
-

 @fork_new_process_for_each_test
 def test_basic_camem():

vllm_ascend/platform.py

Lines changed: 7 additions & 0 deletions
@@ -15,6 +15,7 @@
 # This file is a part of the vllm-ascend project.
 #

+import gc
 import logging
 import os
 from typing import TYPE_CHECKING, Optional, Tuple
@@ -118,6 +119,12 @@ def synchronize(cls):
     def mem_get_info(cls) -> Tuple[int, int]:
         return torch.npu.mem_get_info()

+    @classmethod
+    def clear_npu_memory(cls):
+        gc.collect()
+        torch.npu.empty_cache()
+        torch.npu.reset_peak_memory_stats()
+
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # initialize ascend config from vllm additional_config
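
Why reset the peak counters here: determine_available_memory() in worker_v1.py (further down) sizes the KV cache from allocated_bytes.all.peak, so the counter has to start from zero right before the profiling run. A minimal sketch of that measurement pattern, assuming torch_npu is installed and an NPU is visible; the workload is purely illustrative:

import gc

import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend


def measure_peak_npu_bytes(workload) -> int:
    """Run `workload` and return the peak bytes torch allocated during it."""
    # Same steps as NPUPlatform.clear_npu_memory(): drop Python garbage,
    # return cached blocks to the device, and zero the peak counters.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()

    workload()

    # Peak allocation recorded by the caching allocator since the reset.
    return torch.npu.memory_stats()["allocated_bytes.all.peak"]


# Illustrative workload: a single large matmul on the NPU.
peak = measure_peak_npu_bytes(
    lambda: torch.randn(4096, 4096, device="npu")
    @ torch.randn(4096, 4096, device="npu"))
print(f"peak torch allocation: {peak / (1 << 30):.2f} GiB")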

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 6 deletions
@@ -1235,11 +1235,6 @@ def profile_run(self) -> None:
         # assert self.lora_manager is not None, "LoRA is not enabled"
         # TODO: call maybe_profile_with_lora()

-        dummy_kv_caches = [
-            torch.tensor((), dtype=torch.float32, device=self.device)
-            for _ in range(self.num_attn_layers)
-        ]
-
         # Trigger compilation for general shape.
         hidden_states = self._dummy_run(self.max_num_tokens)

@@ -1250,7 +1245,7 @@ def profile_run(self) -> None:
         logits = None

         NPUPlatform.synchronize()
-        del hidden_states, logits, dummy_kv_caches
+        del hidden_states, logits
         self.encoder_cache.clear()
         gc.collect()

vllm_ascend/worker/worker_v1.py

Lines changed: 54 additions & 48 deletions
@@ -17,8 +17,7 @@
 # Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
 #

-import gc
-from typing import Dict, List, Optional
+from typing import Optional

 import torch
 import torch.nn as nn
@@ -33,16 +32,15 @@
 from vllm.logger import logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheSpec)
+from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.worker_base import WorkerBase

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import init_ascend_config
+from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import try_register_lib
@@ -95,10 +93,22 @@ def __init__(
         self.profiler = self._init_profiler()

     def sleep(self, level: int = 1) -> None:
-        logger.error("Sleep mode is only supported on v0")
+        NPUPlatform.set_device(self.device)
+        free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
+        allocator = CaMemAllocator.get_instance()
+        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
+        free_bytes_after_sleep, total = NPUPlatform.mem_get_info()
+        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
+        used_bytes = total - free_bytes_after_sleep
+        assert freed_bytes >= 0, "Memory usage increased after sleeping."
+        logger.info(
+            "Sleep mode freed %.2f GiB memory, "
+            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
+            used_bytes / GiB_bytes)

     def wake_up(self, tags: Optional[list[str]] = None) -> None:
-        logger.error("Sleep mode is only supported on v0")
+        allocator = CaMemAllocator.get_instance()
+        allocator.wake_up(tags=tags)

     def init_device(self):
         if self.device_config.device.type == "npu":
@@ -119,58 +129,42 @@ def init_device(self):
         self.model_runner = NPUModelRunner(self.vllm_config, self.device)

     def determine_available_memory(self) -> int:
-        kv_caches: Dict[str, torch.Tensor] = {}
-        kv_cache_spec = self.model_runner.get_kv_cache_spec()
-        for layer_name, layer_spec in kv_cache_spec.items():
-            if isinstance(layer_spec, FullAttentionSpec):
-                # Use an empty tensor instead of `None`` to force Dynamo to pass
-                # it by reference, rather by specializing on the value ``None``.
-                npu_k_cache = torch.tensor([],
-                                           dtype=layer_spec.dtype,
-                                           device=self.device)
-                npu_v_cache = torch.tensor([],
-                                           dtype=layer_spec.dtype,
-                                           device=self.device)
-                kv_caches[layer_name] = (npu_k_cache, npu_v_cache)
-            else:
-                raise NotImplementedError
-
-        runner_kv_caches: List[torch.Tensor] = []
-        bind_kv_cache(
-            kv_caches,
-            self.vllm_config.compilation_config.static_forward_context,
-            runner_kv_caches)
-
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
-        NPUPlatform.empty_cache()
+        NPUPlatform.clear_npu_memory()

         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
+        _, total_npu_memory = NPUPlatform.mem_get_info()
         self.model_runner.profile_run()

         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
-        free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info()
+        free_npu_memory, _ = NPUPlatform.mem_get_info()
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
-        peak_memory = self.init_npu_memory - free_npu_memory
-        assert peak_memory > 0, (
+        assert self.init_npu_memory > free_npu_memory, (
             "Error in memory profiling. "
             f"Initial free memory {self.init_npu_memory}, current free memory"
             f" {free_npu_memory}. This happens when the NPU memory was "
             "not properly cleaned up before initializing the vLLM instance.")

-        gc.collect()
+        # Get the peak memory allocation recorded by torch
+        peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
         # TODO: don`t need impl this func after empty_cache in
         # Worker.determine_num_available_blocks() unified`
         NPUPlatform.empty_cache()
-        usable_memory_size = total_npu_memory * self.cache_config.gpu_memory_utilization - peak_memory
-        npu_kv_cache_bytes = max(usable_memory_size, 0)
-        logger.info(
-            f"Available memory: {usable_memory_size}, total memory: {total_npu_memory}"
-        )
-        return int(npu_kv_cache_bytes)
+        torch_allocated_bytes = torch_npu.npu.memory_stats(
+        )["allocated_bytes.all.current"]
+        total_allocated_bytes = torch_npu.npu.mem_get_info(
+        )[1] - torch_npu.npu.mem_get_info()[0]
+        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
+        if non_torch_allocations > 0:
+            peak_memory += non_torch_allocations
+        available_kv_cache_memory = (
+            total_npu_memory * self.cache_config.gpu_memory_utilization -
+            peak_memory)
+        return int(available_kv_cache_memory)

     def execute_model(
         self,
@@ -180,7 +174,17 @@ def execute_model(
         return output if self.is_driver_worker else None

     def load_model(self) -> None:
-        self.model_runner.load_model()
+        if self.vllm_config.model_config.enable_sleep_mode:
+            allocator = CaMemAllocator.get_instance()
+            assert allocator.get_current_usage() == 0, (
+                "Sleep mode can only be "
+                "used for one instance per process.")
+            context = allocator.use_memory_pool(tag="weights")
+        else:
+            from contextlib import nullcontext
+            context = nullcontext()  # type: ignore
+        with context:
+            self.model_runner.load_model()

     def compile_or_warm_up_model(self) -> None:
         warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
@@ -206,12 +210,14 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

     def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate NPU KV cache with the specified kv_cache_config."""
-        self.model_runner.initialize_kv_cache(kv_cache_config)
-
-    def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
-        """Allocate GPU KV cache with the specified kv_cache_config."""
-        kv_cache_config = kv_cache_configs[self.rank]
-        self.model_runner.initialize_kv_cache(kv_cache_config)
+        if self.vllm_config.model_config.enable_sleep_mode:
+            allocator = CaMemAllocator.get_instance()
+            context = allocator.use_memory_pool(tag="kv_cache")
+        else:
+            from contextlib import nullcontext
+            context = nullcontext()  # type: ignore
+        with context:
+            self.model_runner.initialize_kv_cache(kv_cache_config)

     def profile(self, is_start: bool = True):
         if self.profiler is None:
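
Putting the new accounting together: the worker now derives the KV-cache budget from the torch peak allocation plus any memory held outside the torch allocator, instead of diffing free memory before and after profiling. A worked example with illustrative numbers for a hypothetical 64 GiB device (the figures are assumptions, not measurements):

GiB = 1 << 30

# Illustrative figures for a hypothetical 64 GiB NPU.
total_npu_memory = 64 * GiB       # mem_get_info()[1]
gpu_memory_utilization = 0.9      # cache_config.gpu_memory_utilization
torch_peak = 18 * GiB             # memory_stats()["allocated_bytes.all.peak"]
torch_current = 16 * GiB          # memory_stats()["allocated_bytes.all.current"]
device_used_now = 20 * GiB        # mem_get_info()[1] - mem_get_info()[0]

# Memory held outside the torch caching allocator (runtime, HCCL buffers, ...).
non_torch = device_used_now - torch_current       # 4 GiB
peak_memory = torch_peak + max(non_torch, 0)      # 22 GiB

available_kv_cache_memory = (
    total_npu_memory * gpu_memory_utilization - peak_memory)
print(f"{available_kv_cache_memory / GiB:.1f} GiB for the KV cache")  # 35.6 GiB

Because non-torch allocations are added to the peak, the budget stays conservative and avoids over-committing the device once the KV cache is allocated.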
