@@ -334,9 +334,9 @@ def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
334
334
block_size : int , is_store : bool ) -> "ReqMeta" :
335
335
token_ids_tensor = torch .tensor (token_ids )
336
336
valid_num_tokens = len (token_ids )
337
- block_ids_tensor = torch .tensor (block_ids )
337
+ block_ids_tensor = torch .tensor (block_ids , dtype = torch . int32 )
338
338
num_blocks = block_ids_tensor .shape [0 ]
339
- block_offsets = torch .arange (0 , block_size )
339
+ block_offsets = torch .arange (0 , block_size , dtype = torch . int32 )
340
340
slot_mapping = block_offsets .reshape (
341
341
(1 , block_size )) + block_ids_tensor .reshape (
342
342
(num_blocks , 1 )) * block_size
@@ -495,7 +495,7 @@ def start_load_kv(self, forward_context: "ForwardContext",
495
495
# NOTE: slen is the length of the KV cache that needs to be loaded for this request
496
496
# in decode, request_len = prefill_prompt_len + 1
497
497
slen = request .token_ids .shape [0 ] - 1
498
- cur_slot_mapping = request .slot_mapping [:slen ]
498
+ req_slot_mapping = request .slot_mapping [:slen ]. to ( device = "npu" )
499
499
500
500
# For the datadist tensor, the first dimension is 1; the reason can
501
501
# be found in the wait_for_save function
@@ -545,7 +545,7 @@ def start_load_kv(self, forward_context: "ForwardContext",
545
545
for layer_id , kv_cache_layer in enumerate (kv_cache_layers ):
546
546
pulled_kv_cache = pulled_kv_caches [layer_id ]
547
547
self ._inject_kv_into_layer (kv_cache_layer , pulled_kv_cache ,
548
- cur_slot_mapping , is_mla )
548
+ req_slot_mapping , is_mla )
549
549
550
550
# Release the reference count
551
551
self .llm_datadist_engine .kv_transfer .deallocate_cache (kv_buffer )
@@ -686,42 +686,24 @@ def _inject_kv_into_layer(
686
686
slot_mapping (torch.Tensor): the slot mapping. In shape
687
687
[num_tokens].
688
688
"""
689
- # NOTE: The performance of this function is suboptimal. Using
690
- # `torch_npu._npu_reshape_and_cache` or
691
- # `torch_npu._npu_reshape_and_cache_siso` could improve performance
692
- # significantly. However, attempts to use these methods have failed, and
693
- # the root cause remains unclear. The only available information is an
694
- # error log from the ATB log file, which states:
695
- # "ReshapeAndCacheOperation_1 invalid param, setup check fail, error
696
- # code: 13."
697
-
698
- # The pulled KV cache resides in the mbuf memory space and cannot be
699
- # directly copied to the kv_cache_layer. Therefore, it must first be
700
- # copied to a standard torch tensor using `scatter_update_`.
701
- kv_cache = torch .empty_like (pulled_kv_cache )
702
- indices = torch .tensor ([0 ], dtype = torch .int64 , device = "npu" )
703
- torch_npu .scatter_update_ (kv_cache , indices , pulled_kv_cache , axis = - 2 )
704
689
# The `wait_for_save` function explains why the first dimension is
705
690
# necessary.
706
- kv_cache = kv_cache .squeeze (0 )
707
-
708
- dst_kv_cache_layer_shape = dst_kv_cache_layer .shape
691
+ kv_cache = pulled_kv_cache .squeeze (0 )
709
692
if is_mla :
710
- block_size = dst_kv_cache_layer_shape [1 ]
711
- num_heads = dst_kv_cache_layer_shape [2 ]
712
- head_dim = dst_kv_cache_layer_shape [3 ]
713
- idx_for_copy = slot_mapping // block_size * block_size + slot_mapping % block_size
714
- dst_kv_cache_layer = dst_kv_cache_layer .view (
715
- - 1 , num_heads , head_dim )
716
- dst_kv_cache_layer [idx_for_copy , ...] = kv_cache
693
+ torch_npu ._npu_reshape_and_cache_siso (
694
+ key = kv_cache ,
695
+ key_cache = dst_kv_cache_layer ,
696
+ slot_indices = slot_mapping ,
697
+ )
698
+
717
699
else :
718
- block_size = dst_kv_cache_layer_shape [ 2 ]
719
- num_heads = dst_kv_cache_layer_shape [ 3 ]
720
- head_dim = dst_kv_cache_layer_shape [ 4 ]
721
- idx_for_copy = slot_mapping // block_size * block_size + slot_mapping % block_size
722
- dst_kv_cache_layer = dst_kv_cache_layer . view (
723
- 2 , - 1 , num_heads , head_dim )
724
- dst_kv_cache_layer [:, idx_for_copy , ...] = kv_cache
700
+ torch_npu . _npu_reshape_and_cache (
701
+ key = kv_cache [ 0 ],
702
+ value = kv_cache [ 1 ],
703
+ key_cache = dst_kv_cache_layer [ 0 ],
704
+ value_cache = dst_kv_cache_layer [ 1 ],
705
+ slot_indices = slot_mapping ,
706
+ )
725
707
726
708
def _extract_kv_from_layer (
727
709
self ,
0 commit comments