Commit 71de52d

feat: add kv cache memory cache and skip dynamo guard (#1549)
### What this PR does / why we need it?

1. Loading the torchair cache can sometimes fail because the available NPU memory fluctuates between runs. This PR adds a new cache that persists the old KV cache size in bytes, avoiding a possible crash while loading the torchair graph cache.
2. When caching is enabled but the cache does not yet exist, the first compilation incurs the overhead of Dynamo guards. In that case we now compile twice up front so that later runs skip the guards entirely (this yields roughly 3-4 ms of TPOT improvement).

### Does this PR introduce _any_ user-facing change?

Adds a new environment variable, `VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE`, to control the tolerated fluctuation of the KV cache size.

### How was this patch tested?

- vLLM version: v0.9.1
- vLLM main: vllm-project/vllm@1fd471e

Signed-off-by: boying <897013703@qq.com>
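In outline, the memory-tolerance half of the change reduces to the following decision, shown here as a minimal sketch with hypothetical names (`pick_kv_cache_bytes` does not exist in the PR; the real logic lives in `determine_available_memory` in the worker diff below):

```python
def pick_kv_cache_bytes(measured: int, cached: int, tolerance_mb: int) -> int:
    """Sketch of the PR's policy. Reuse the persisted KV-cache size while
    it still fits into the freshly measured budget; otherwise fall back to
    the new measurement shrunk by the tolerance, so that small NPU-memory
    fluctuations on later runs don't invalidate the torchair cache."""
    if 0 < cached <= measured:
        return cached  # torchair cache remains loadable
    return max(measured - tolerance_mb * 1024 * 1024, 0)
```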

File tree: 5 files changed (+182 −24 lines)

tests/ut/test_utils.py

Lines changed: 21 additions & 0 deletions

```diff
@@ -280,6 +280,27 @@ def test_update_aclgraph_sizes(self):
             3,
             len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
 
+    def test_get_torchair_current_work_dir(self):
+        cache_dir = utils.TORCHAIR_CACHE_DIR
+        work_dir = utils.get_torchair_current_work_dir()
+        self.assertEqual(cache_dir, work_dir)
+        work_dir = utils.get_torchair_current_work_dir("test")
+        self.assertEqual(os.path.join(cache_dir, "test"), work_dir)
+
+    def test_torchair_cache_dir(self):
+        utils.write_kv_cache_bytes_to_file(0, 100)
+        self.assertTrue(utils.check_torchair_cache_exist(),
+                        "Create torchair cache dir failed")
+        self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
+                        "Create kv cache bytes cache dir failed")
+        kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
+        self.assertEqual(100, kv_cache_bytes)
+        utils.delete_torchair_cache_file()
+        self.assertFalse(utils.check_torchair_cache_exist(),
+                         "Delete torchair cache dir failed")
+        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
+                         "Delete kv cache bytes cache dir failed")
+
 
 class TestProfileExecuteDuration(unittest.TestCase):
```

vllm_ascend/envs.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -121,6 +121,12 @@
     # value to False to disable the optimized model.
     "USE_OPTIMIZED_MODEL":
     lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
+    # The tolerance of the kv cache size, if the difference between the
+    # actual kv cache size and the cached kv cache size is less than this value,
+    # then the cached kv cache size will be used.
+    "VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
+    lambda: int(
+        os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
 }
 
 # end-env-vars-definition
```
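The entry follows the module's dict-of-lambdas pattern, so the variable is read lazily at access time rather than at import time. A minimal sketch of that behavior (this is not the actual `envs.py` machinery, just the pattern):

```python
import os

# Each registry entry is a zero-argument lambda, evaluated on access.
env_registry = {
    "VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
    lambda: int(
        os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
}

os.environ["VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE"] = "128"
# The new value is picked up because the lookup happens at call time.
assert env_registry["VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE"]() == 128
```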

vllm_ascend/utils.py

Lines changed: 77 additions & 0 deletions

```diff
@@ -18,7 +18,10 @@
 #
 
 import atexit
+import fcntl
 import math
+import os
+import shutil
 from contextlib import contextmanager, nullcontext
 from enum import Enum
 from threading import Lock
@@ -440,3 +443,77 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool,
         return FusedMoEState.All2All
     else:
         return FusedMoEState.MC2
+
+
+KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
+KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
+TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
+TORCHAIR_CACHE_DIR = os.getenv(
+    'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME))
+
+
+def get_torchair_current_work_dir(file_name=None):
+    if file_name is None:
+        return TORCHAIR_CACHE_DIR
+    return os.path.join(TORCHAIR_CACHE_DIR, file_name)
+
+
+def check_torchair_cache_exist():
+    res = False
+    torch_air_abs_path = get_torchair_current_work_dir()
+    if os.path.exists(torch_air_abs_path):
+        file_list = os.listdir(torch_air_abs_path)
+        if len(file_list) != 0:
+            res = True
+    return res
+
+
+def check_kv_cache_bytes_cache_exist():
+    res = False
+    kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
+        KV_CACHE_BYTES_CACHE_PATH_NAME)
+    if os.path.exists(kv_cache_bytes_cache_abs_path):
+        file_list = os.listdir(kv_cache_bytes_cache_abs_path)
+        if len(file_list) != 0:
+            res = True
+    return res
+
+
+def read_kv_cache_bytes_from_file(rank) -> int:
+    kv_cache_bytes = -1
+    kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
+        KV_CACHE_BYTES_CACHE_PATH_NAME)
+    kv_cache_bytes_file = os.path.join(
+        kv_cache_bytes_cache_abs_path,
+        f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
+    with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
+        with file_lock(f, fcntl.LOCK_SH):
+            kv_cache_bytes = int(f.readline())
+    return kv_cache_bytes
+
+
+@contextmanager
+def file_lock(file_descriptor, lock_type):
+    fcntl.flock(file_descriptor, lock_type)
+    try:
+        yield
+    finally:
+        fcntl.flock(file_descriptor, fcntl.LOCK_UN)
+
+
+def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
+    kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
+        KV_CACHE_BYTES_CACHE_PATH_NAME)
+    os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
+    kv_cache_bytes_file = os.path.join(
+        kv_cache_bytes_cache_abs_path,
+        f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
+    with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
+        with file_lock(f, fcntl.LOCK_EX):
+            f.write(f"{kv_cache_bytes}")
+
+
+def delete_torchair_cache_file():
+    torch_air_abs_path = get_torchair_current_work_dir()
+    if os.path.exists(torch_air_abs_path):
+        shutil.rmtree(torch_air_abs_path)
```
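Taken together, and assuming the module is importable as `vllm_ascend.utils`, these helpers round-trip a per-rank byte count through a small on-disk layout under the working directory (or `TORCHAIR_CACHE_HOME`), with `fcntl` advisory locks guarding concurrent readers and writers:

```python
import os

from vllm_ascend import utils

# Derived from the constants above: rank 0's file lands at
# .torchair_cache/.kv_cache_bytes/0_kv_cache_bytes by default.
path = os.path.join(utils.TORCHAIR_CACHE_DIR,
                    utils.KV_CACHE_BYTES_CACHE_PATH_NAME,
                    f"0_{utils.KV_CACHE_BYTES_CACHE_FILE_NAME}")

utils.write_kv_cache_bytes_to_file(0, 1 << 30)  # persist 1 GiB for rank 0
assert os.path.exists(path)
assert utils.read_kv_cache_bytes_from_file(0) == 1 << 30
utils.delete_torchair_cache_file()  # removes the whole .torchair_cache tree
```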

vllm_ascend/worker/model_runner_v1.py

Lines changed: 44 additions & 20 deletions

```diff
@@ -76,9 +76,10 @@
 from vllm_ascend.pool.metadata import PoolingMetadata
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               ProfileExecuteDuration, is_310p,
+                               ProfileExecuteDuration,
+                               check_torchair_cache_exist, is_310p,
                                maybe_converting_weight_acl_format,
-                               vllm_version_is)
+                               vllm_version_is, write_kv_cache_bytes_to_file)
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@@ -329,6 +330,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
             attn_mask_len, self.dtype)
 
+        self.new_kv_cache_bytes = -1
         self.torchair_compiled_model = None  # type: ignore
         self.torchair_compiled_models = {}  # type: ignore
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -2274,6 +2276,20 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
 
         return kv_cache_spec
 
+    def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
+        # Trigger torchair graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
+            for _ in range(self.vllm_config.compilation_config.
+                           cudagraph_num_of_warmups):
+                self._dummy_run(num_tokens,
+                                is_compile=True,
+                                with_prefill=False)
+            self._dummy_run(num_tokens, is_compile=True, with_prefill=False)
+            logger.info("Batchsize %d is compiled successfully: %d/%d.",
+                        num_tokens, idx + 1, len(torchair_graph_batch_sizes))
+
     def capture_model(self) -> None:
         start_time = time.perf_counter()
         start_free_npu_memory = torch.npu.mem_get_info()[0]
@@ -2283,24 +2299,32 @@ def capture_model(self) -> None:
         if self.torchair_graph_enabled:
             torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
             graph_num = len(torchair_graph_batch_sizes)
-            logger.info(
-                "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
-                0.5 * graph_num, 1.5 * graph_num)
-            # Trigger torchair graph capture for specific shapes.
-            # Capture the large shapes first so that the smaller shapes
-            # can reuse the memory pool allocated for the large shapes.
-            for idx, num_tokens in enumerate(
-                    reversed(torchair_graph_batch_sizes)):
-                for _ in range(self.vllm_config.compilation_config.
-                               cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens,
-                                    is_compile=True,
-                                    with_prefill=False)
-                self._dummy_run(num_tokens,
-                                is_compile=True,
-                                with_prefill=False)
-                logger.info("Batchsize %d is compiled successfully: %d/%d.",
-                            num_tokens, idx + 1, graph_num)
+
+            if self.use_cached_npu_graph and not check_torchair_cache_exist():
+                # If caching is enabled but the cache does not exist, compile
+                # the model twice: the first pass generates the cache, and the
+                # second pass loads it, skipping the overhead of the Dynamo
+                # guard mechanism.
+                logger.info(
+                    "Use cached npu graph but cache doesn't exist! Now we compile graph to generate torchair cache, this usually takes %.1f~%.1f mins.",
+                    0.5 * graph_num, 1.5 * graph_num)
+                self._compile_torchair_graph(torchair_graph_batch_sizes)
+                NPUPlatform.synchronize()
+                torch._dynamo.reset()
+                self.torchair_compiled_models.clear()
+            if self.use_cached_npu_graph:
+                logger.info(
+                    "Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
+                    0.3 * graph_num, 0.5 * graph_num)
+                self._compile_torchair_graph(torchair_graph_batch_sizes)
+            else:
+                logger.info(
+                    "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
+                    0.5 * graph_num, 1.5 * graph_num)
+                self._compile_torchair_graph(torchair_graph_batch_sizes)
+
+            if self.new_kv_cache_bytes > 0:
+                write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
+                                             self.new_kv_cache_bytes)
         elif self.use_aclgraph:
             # Trigger ACL graph capture for specific shapes.
             # Capture the large shapes first so that the smaller shapes
```
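The compile-twice trick can be illustrated in isolation with stock `torch.compile`. This is only a structural analogue: plain `torch.compile` does not write the torchair disk cache, so the second pass here merely mirrors the shape of the real flow (warm every batch size largest-first, reset Dynamo, compile again from cache):

```python
import torch


def compile_all(model: torch.nn.Module, batch_sizes: list[int]) -> None:
    """Analogue of _compile_torchair_graph: warm up each shape,
    largest first, so smaller shapes can reuse allocated memory."""
    compiled = torch.compile(model)
    for n in sorted(batch_sizes, reverse=True):
        compiled(torch.randn(n, 8))


model = torch.nn.Linear(8, 8)

# First pass: in the PR, this is where the torchair cache gets generated.
compile_all(model, [1, 4, 16])

# Drop Dynamo's in-memory guards and compiled entries, then compile
# again; with the torchair cache present, the second pass loads the
# compiled graphs instead of re-tracing, which is what removes the
# guard-evaluation overhead at decode time.
torch._dynamo.reset()
compile_all(model, [1, 4, 16])
```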

vllm_ascend/worker/worker_v1.py

Lines changed: 34 additions & 4 deletions

```diff
@@ -36,11 +36,16 @@
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.worker_base import WorkerBase
 
-from vllm_ascend.ascend_config import init_ascend_config
+import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import sleep_mode_enabled, try_register_lib
+from vllm_ascend.utils import (check_kv_cache_bytes_cache_exist,
+                               check_torchair_cache_exist,
+                               delete_torchair_cache_file,
+                               read_kv_cache_bytes_from_file,
+                               sleep_mode_enabled, try_register_lib)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 
@@ -167,10 +172,35 @@ def determine_available_memory(self) -> int:
         non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
         if non_torch_allocations > 0:
             peak_memory += non_torch_allocations
-        available_kv_cache_memory = (
+        available_kv_cache_memory = int(
             total_npu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory)
-        return int(available_kv_cache_memory)
+        available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
+        logger.info(
+            f"Available memory: {available_kv_cache_memory}, total memory: {total_npu_memory}"
+        )
+        if get_ascend_config().torchair_graph_config.enabled:
+            if check_torchair_cache_exist(
+            ) and check_kv_cache_bytes_cache_exist():
+                old_kv_cache_bytes = read_kv_cache_bytes_from_file(
+                    torch.distributed.get_rank())
+                if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
+                    logger.info(
+                        f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
+                    )
+                    self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
+                    return old_kv_cache_bytes
+                else:
+                    logger.info(
+                        "Cached torchair kv_cache_bytes is too big, invalidating the old torchair cache"
+                    )
+                    delete_torchair_cache_file()
+            bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
+            available_kv_cache_memory -= bytes_floating_tolerance
+            logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
+            self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
+
+        return available_kv_cache_memory
 
     def execute_model(
         self,
```
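To make the tolerance arithmetic concrete: with the default of 64 MB, the freshly measured budget is reduced by 64 × 2^20 bytes before being persisted, so a later run whose measurement drifts down by less than that still satisfies `0 < old_kv_cache_bytes <= available_kv_cache_memory` and reuses the cache. With illustrative numbers:

```python
TOLERANCE_MB = 64  # default VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE

measured_run1 = 30_000_000_000                           # bytes, first run
persisted = measured_run1 - TOLERANCE_MB * 1024 * 1024   # 29_932_891_136

# Second run: available memory drifted down by ~50 MB, within tolerance.
measured_run2 = measured_run1 - 50 * 1024 * 1024
assert 0 < persisted <= measured_run2  # cached size is reused
```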
