@@ -454,6 +454,22 @@ def __init__(self, vllm_config: "VllmConfig",
         else:
             logger.warning(f"Still {len(clusters)} clusters to link")
 
+        # LLMDataDist will deallocate the cache buffer either when the cache
+        # buffer's Python object goes out of scope or when deallocate_cache() is
+        # explicitly called. This can lead to accuracy issues if the cache
+        # buffer is deallocated while still being used in the NPU stream. To
+        # prevent this, we maintain a reference to the cache buffer until the
+        # next round, ensuring it is not prematurely deallocated.
+        self.kv_buffers: List = []
+
+    def _detach_kv_buffers(self):
+        for kv_buffer in self.kv_buffers:
+            self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer)
+        self.kv_buffers.clear()
+
+    def _attach_kv_buffer(self, kv_buffer: torch.Tensor):
+        self.kv_buffers.append(kv_buffer)
+
     def start_load_kv(self, forward_context: "ForwardContext",
                       **kwargs) -> None:
         """
@@ -477,6 +493,9 @@ def start_load_kv(self, forward_context: "ForwardContext",
             # In the prefilling node, do not need to load KV cache.
             return
 
+        # Release the KV cache buffer from the previous round
+        self._detach_kv_buffers()
+
         # Get the metadata
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, LLMDataDistConnectorV1Metadata)
@@ -558,6 +577,7 @@ def start_load_kv(self, forward_context: "ForwardContext",
         kv_hidden_dtype = kv_cache_layers[0].dtype
         kv_buffer, pulled_kv_caches = self._create_cache_tensors(
             self.num_layers, kv_cache_shape, kv_hidden_dtype)
+        self._attach_kv_buffer(kv_buffer)
 
         target_tp_rank = self.tp_rank % min(
             self.cluster_info.prefill_tp_size,
@@ -590,9 +610,6 @@ def start_load_kv(self, forward_context: "ForwardContext",
             self._inject_kv_into_layer(kv_cache_layer, pulled_kv_cache,
                                        req_slot_mapping, is_mla)
 
-        # Release the reference count
-        self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer)
-
     def wait_for_layer_load(self, layer_name: str) -> None:
         """
         Block until the KV for a specific layer is loaded into vLLM's
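For context, here is a minimal, self-contained sketch of the buffer-lifetime pattern this diff introduces. The Engine and Buffer classes are hypothetical stand-ins for the llm_datadist kv_transfer API (allocate_cache below is an assumed name; only deallocate_cache appears in the diff). What the sketch mirrors is the lifecycle: buffers allocated in round N are retained until the start of round N+1, and only then passed to deallocate_cache(), so host-side deallocation cannot race with in-flight NPU-stream reads.

    # Sketch of the deferred-deallocation pattern, under the assumptions above.
    from typing import List


    class Buffer:
        """Hypothetical stand-in for an LLMDataDist cache buffer."""

        def __init__(self, name: str):
            self.name = name


    class Engine:
        """Hypothetical stand-in for llm_datadist_engine.kv_transfer."""

        def allocate_cache(self, name: str) -> Buffer:
            return Buffer(name)

        def deallocate_cache(self, buffer: Buffer) -> None:
            print(f"freed {buffer.name}")


    class Connector:
        def __init__(self, engine: Engine):
            self.engine = engine
            # Keep buffers alive until the next round so the device stream
            # can finish reading them before they are freed.
            self.kv_buffers: List[Buffer] = []

        def _detach_kv_buffers(self) -> None:
            for kv_buffer in self.kv_buffers:
                self.engine.deallocate_cache(kv_buffer)
            self.kv_buffers.clear()

        def _attach_kv_buffer(self, kv_buffer: Buffer) -> None:
            self.kv_buffers.append(kv_buffer)

        def start_load_kv(self, round_id: int) -> None:
            # Free the buffers from the previous round first ...
            self._detach_kv_buffers()
            # ... then allocate and retain the buffer for this round.
            buf = self.engine.allocate_cache(f"kv-round-{round_id}")
            self._attach_kv_buffer(buf)


    connector = Connector(Engine())
    connector.start_load_kv(0)  # allocates kv-round-0, frees nothing
    connector.start_load_kv(1)  # frees kv-round-0, allocates kv-round-1

The trade-off is deliberate: holding each buffer for one extra round slightly raises peak buffer memory, but it removes the accuracy hazard described in the diff's comment, where deallocating a buffer that the NPU stream is still consuming could corrupt the pulled KV cache.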