Commit 2351977

[QuickFix][Rope] Fix rope bug in torchair+chunk-prefill scenario (#1693)
The previous rope optimization, PR #1614, introduced a bug: when torchair is enabled together with chunk-prefill, sin/cos may be None. This PR fixes the problem.

Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent 279fccd commit 2351977

1 file changed: +7 -9 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 7 additions & 9 deletions
@@ -352,13 +352,13 @@ def build_torchair_graph_dummy(
         else:
             attn_state = AscendAttentionState.DecodeOnly
             num_decode_tokens = 1
-        sin = torch.ones(num_reqs,
+        sin = torch.ones(num_tokens,
                          1,
                          1,
                          self.rope_dim,
                          dtype=self.runner.dtype,
                          device=device)
-        cos = torch.ones(num_reqs,
+        cos = torch.ones(num_tokens,
                          1,
                          1,
                          self.rope_dim,
@@ -547,15 +547,13 @@ def build(
                 actual_seq_q_lens = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                   num_reqs_pad_size]
-                cos = self.cos_cache[
-                    input_positions].unsqueeze(  # type: ignore
-                        1).unsqueeze(2)
-                sin = self.sin_cache[
-                    input_positions].unsqueeze(  # type: ignore
-                        1).unsqueeze(2)
             else:
                 seq_lens_list = seq_lens.tolist()
-                cos, sin = None, None
+
+            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
+                1).unsqueeze(2)
+            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
+                1).unsqueeze(2)
             mc2_mask = self.generate_activate_mask(
                 num_actual_tokens, num_reqs + num_reqs_pad_size)
