Commit 0d90c24

fix: manage KV cache buffer lifecycle to prevent premature deallocation
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent 5ff5a06 commit 0d90c24

File tree

1 file changed: +20 -3 lines changed

vllm_ascend/distributed/llmdatadist_connector_v1.py

Lines changed: 20 additions & 3 deletions
@@ -454,6 +454,22 @@ def __init__(self, vllm_config: "VllmConfig",
         else:
             logger.warning(f"Still {len(clusters)} clusters to link")
 
+        # LLMDataDist will deallocate the cache buffer either when the cache
+        # buffer's Python object goes out of scope or when deallocate_cache() is
+        # explicitly called. This can lead to accuracy issues if the cache
+        # buffer is deallocated while still being used in the NPU stream. To
+        # prevent this, we maintain a reference to the cache buffer until the
+        # next round, ensuring it is not prematurely deallocated.
+        self.kv_buffers: List = []
+
+    def _detach_kv_buffers(self):
+        for kv_buffer in self.kv_buffers:
+            self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer)
+        self.kv_buffers.clear()
+
+    def _attach_kv_buffer(self, kv_buffer: torch.Tensor):
+        self.kv_buffers.append(kv_buffer)
+
     def start_load_kv(self, forward_context: "ForwardContext",
                       **kwargs) -> None:
         """
@@ -477,6 +493,9 @@ def start_load_kv(self, forward_context: "ForwardContext",
             # In the prefilling node, do not need to load KV cache.
             return
 
+        # Release the KV cache buffer from the previous round
+        self._detach_kv_buffers()
+
         # Get the metadata
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, LLMDataDistConnectorV1Metadata)
@@ -558,6 +577,7 @@ def start_load_kv(self, forward_context: "ForwardContext",
             kv_hidden_dtype = kv_cache_layers[0].dtype
             kv_buffer, pulled_kv_caches = self._create_cache_tensors(
                 self.num_layers, kv_cache_shape, kv_hidden_dtype)
+            self._attach_kv_buffer(kv_buffer)
 
             target_tp_rank = self.tp_rank % min(
                 self.cluster_info.prefill_tp_size,
@@ -590,9 +610,6 @@ def start_load_kv(self, forward_context: "ForwardContext",
                 self._inject_kv_into_layer(kv_cache_layer, pulled_kv_cache,
                                            req_slot_mapping, is_mla)
 
-            # Release the reference count
-            self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer)
-
     def wait_for_layer_load(self, layer_name: str) -> None:
         """
         Block until the KV for a specific layer is loaded into vLLM's
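
To make the lifecycle change easier to follow outside the diff, here is a minimal, self-contained sketch of the same deferred-deallocation pattern in plain Python. Buffer, FakeKVTransfer, allocate_cache, and round_id are hypothetical stand-ins invented for this sketch; only the attach/detach bookkeeping mirrors the connector, which now defers llm_datadist_engine.kv_transfer.deallocate_cache() to the start of the next start_load_kv() round instead of calling it right after injecting the pulled KV.

from typing import List


class Buffer:
    """Hypothetical cache buffer whose memory must stay valid while in use."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.freed = False


class FakeKVTransfer:
    """Hypothetical stand-in for llm_datadist_engine.kv_transfer."""

    def allocate_cache(self, name: str) -> Buffer:
        return Buffer(name)

    def deallocate_cache(self, buffer: Buffer) -> None:
        buffer.freed = True


class Connector:
    def __init__(self) -> None:
        self.kv_transfer = FakeKVTransfer()
        # Buffers pulled in the current round, kept alive until the next
        # round starts so asynchronous consumers can still read them.
        self.kv_buffers: List[Buffer] = []

    def _detach_kv_buffers(self) -> None:
        # Free everything held over from the previous round.
        for kv_buffer in self.kv_buffers:
            self.kv_transfer.deallocate_cache(kv_buffer)
        self.kv_buffers.clear()

    def _attach_kv_buffer(self, kv_buffer: Buffer) -> None:
        self.kv_buffers.append(kv_buffer)

    def start_load_kv(self, round_id: int) -> Buffer:
        # Only now is it safe to release the previous round's buffers.
        self._detach_kv_buffers()
        kv_buffer = self.kv_transfer.allocate_cache(f"round-{round_id}")
        # Hold a reference so the buffer outlives this call.
        self._attach_kv_buffer(kv_buffer)
        return kv_buffer


if __name__ == "__main__":
    connector = Connector()
    b0 = connector.start_load_kv(0)
    assert not b0.freed          # still alive after round 0 returns
    connector.start_load_kv(1)   # round 1 releases round 0's buffer
    assert b0.freed

The cost of this scheme is that one round's worth of pull buffers stays resident until the next load begins; the benefit is that a buffer can no longer be freed while an in-flight NPU operation may still be reading from it.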
