Commit 5c2b440
feat: resolve npu_reshape_and_cache error
Ensure correct input for npu_reshape_and_cache function

The 'slot_indices' parameter of npu_reshape_and_cache must be:
- A torch.int32 tensor
- Located on the NPU device

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent 7fb466d commit 5c2b440
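For reference, a minimal sketch of the two requirements stated in the commit message (values are illustrative; assumes torch_npu is installed so the "npu" device is registered):

import torch
import torch_npu  # noqa: F401  # registers the "npu" device

# slot_indices must be a torch.int32 tensor located on the NPU device.
slot_mapping = torch.tensor([8, 9, 10, 11], dtype=torch.int32)
slot_mapping = slot_mapping.to(device="npu")
assert slot_mapping.dtype == torch.int32
assert slot_mapping.device.type == "npu"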

1 file changed: +18 -36 lines changed
vllm_ascend/distributed/llmdatadist_connector_v1.py

@@ -334,9 +334,9 @@ def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
               block_size: int, is_store: bool) -> "ReqMeta":
     token_ids_tensor = torch.tensor(token_ids)
     valid_num_tokens = len(token_ids)
-    block_ids_tensor = torch.tensor(block_ids)
+    block_ids_tensor = torch.tensor(block_ids, dtype=torch.int32)
     num_blocks = block_ids_tensor.shape[0]
-    block_offsets = torch.arange(0, block_size)
+    block_offsets = torch.arange(0, block_size, dtype=torch.int32)
     slot_mapping = block_offsets.reshape(
         (1, block_size)) + block_ids_tensor.reshape(
             (num_blocks, 1)) * block_size
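For reference, the broadcast above expands each block id into its per-token slot indices; the dtype change makes the result int32 as the fused ops require. A small worked sketch with hypothetical values (block_size=4, block_ids=[2, 5]):

import torch

block_size = 4
block_ids_tensor = torch.tensor([2, 5], dtype=torch.int32)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size, dtype=torch.int32)
# (1, 4) + (2, 1) broadcasts to (2, 4): block 2 -> slots 8..11, block 5 -> slots 20..23
slot_mapping = block_offsets.reshape(
    (1, block_size)) + block_ids_tensor.reshape(
        (num_blocks, 1)) * block_size
print(slot_mapping)
# tensor([[ 8,  9, 10, 11],
#         [20, 21, 22, 23]], dtype=torch.int32)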
@@ -495,7 +495,7 @@ def start_load_kv(self, forward_context: "ForwardContext",
             # NOTE: slen is the len of kv cache need to load for this request
             # in decode, request_len = prefill_prompt_len + 1
             slen = request.token_ids.shape[0] - 1
-            cur_slot_mapping = request.slot_mapping[:slen]
+            req_slot_mapping = request.slot_mapping[:slen].to(device="npu")

             # For the datadist tensor, the first dimension is 1, the reason can
             # be found in wait_for_save function
@@ -545,7 +545,7 @@ def start_load_kv(self, forward_context: "ForwardContext",
             for layer_id, kv_cache_layer in enumerate(kv_cache_layers):
                 pulled_kv_cache = pulled_kv_caches[layer_id]
                 self._inject_kv_into_layer(kv_cache_layer, pulled_kv_cache,
-                                           cur_slot_mapping, is_mla)
+                                           req_slot_mapping, is_mla)

             # Release the reference count
             self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer)
@@ -686,42 +686,24 @@ def _inject_kv_into_layer(
             slot_mapping (torch.Tensor): the slot mapping. In shape
                 [num_tokens].
         """
-        # NOTE: The performance of this function is suboptimal. Using
-        # `torch_npu._npu_reshape_and_cache` or
-        # `torch_npu._npu_reshape_and_cache_siso` could improve performance
-        # significantly. However, attempts to use these methods have failed, and
-        # the root cause remains unclear. The only available information is an
-        # error log from the ATB log file, which states:
-        # "ReshapeAndCacheOperation_1 invalid param, setup check fail, error
-        # code: 13."
-
-        # The pulled KV cache resides in the mbuf memory space and cannot be
-        # directly copied to the kv_cache_layer. Therefore, it must first be
-        # copied to a standard torch tensor using `scatter_update_`.
-        kv_cache = torch.empty_like(pulled_kv_cache)
-        indices = torch.tensor([0], dtype=torch.int64, device="npu")
-        torch_npu.scatter_update_(kv_cache, indices, pulled_kv_cache, axis=-2)
         # The `wait_for_save` function explains why the first dimension is
         # necessary.
-        kv_cache = kv_cache.squeeze(0)
-
-        dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
+        kv_cache = pulled_kv_cache.squeeze(0)
         if is_mla:
-            block_size = dst_kv_cache_layer_shape[1]
-            num_heads = dst_kv_cache_layer_shape[2]
-            head_dim = dst_kv_cache_layer_shape[3]
-            idx_for_copy = slot_mapping // block_size * block_size + slot_mapping % block_size
-            dst_kv_cache_layer = dst_kv_cache_layer.view(
-                -1, num_heads, head_dim)
-            dst_kv_cache_layer[idx_for_copy, ...] = kv_cache
+            torch_npu._npu_reshape_and_cache_siso(
+                key=kv_cache,
+                key_cache=dst_kv_cache_layer,
+                slot_indices=slot_mapping,
+            )
+
         else:
-            block_size = dst_kv_cache_layer_shape[2]
-            num_heads = dst_kv_cache_layer_shape[3]
-            head_dim = dst_kv_cache_layer_shape[4]
-            idx_for_copy = slot_mapping // block_size * block_size + slot_mapping % block_size
-            dst_kv_cache_layer = dst_kv_cache_layer.view(
-                2, -1, num_heads, head_dim)
-            dst_kv_cache_layer[:, idx_for_copy, ...] = kv_cache
+            torch_npu._npu_reshape_and_cache(
+                key=kv_cache[0],
+                value=kv_cache[1],
+                key_cache=dst_kv_cache_layer[0],
+                value_cache=dst_kv_cache_layer[1],
+                slot_indices=slot_mapping,
+            )

     def _extract_kv_from_layer(
         self,
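Taken together: make_meta now builds slot_mapping as int32, start_load_kv moves it to the NPU, and _inject_kv_into_layer hands it to the fused cache-write ops. A condensed sketch of the resulting flow (the standalone function wrapper and asserts are illustrative, not from the source; the calls and the squeeze are taken from the diff):

import torch
import torch_npu

def demo_inject(pulled_kv_cache, dst_kv_cache_layer, slot_mapping, is_mla):
    # Preconditions established by this commit:
    assert slot_mapping.dtype == torch.int32      # set in make_meta
    assert slot_mapping.device.type == "npu"      # set in start_load_kv
    kv_cache = pulled_kv_cache.squeeze(0)  # drop the leading datadist dim
    if is_mla:
        # Single-tensor variant used for MLA layers, per the diff.
        torch_npu._npu_reshape_and_cache_siso(
            key=kv_cache,
            key_cache=dst_kv_cache_layer,
            slot_indices=slot_mapping,
        )
    else:
        # Separate K and V caches stacked along dim 0, per the diff.
        torch_npu._npu_reshape_and_cache(
            key=kv_cache[0],
            value=kv_cache[1],
            key_cache=dst_kv_cache_layer[0],
            value_cache=dst_kv_cache_layer[1],
            slot_indices=slot_mapping,
        )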
