Commit 129a472

[0.9.1][bugfix] fix deepseek memory bug (#1551)
### What this PR does / why we need it?
Fixes an OOM error that occurs when `chunked_prefill_for_mla` is enabled and the input is long.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: zzzzwwjj <1183291235@qq.com>
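For context, a minimal way to exercise the affected path, assuming `chunked_prefill_for_mla` is surfaced through vLLM's `additional_config` like other ascend options; the model name and lengths below are placeholders, not taken from this PR:

```python
# Hypothetical reproduction sketch; not part of this commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",                 # any MLA-based DeepSeek checkpoint
    additional_config={"chunked_prefill_for_mla": True},  # flag assumed to be read from additional_config
    max_model_len=32768,                                  # long-input scene that previously hit OOM
)

outputs = llm.generate(["<a very long prompt>"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```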
1 parent 03bb288 commit 129a472

File tree (3 files changed, +18 −25 lines):

- vllm_ascend/attention/mla_v1.py
- vllm_ascend/ops/rotary_embedding.py
- vllm_ascend/worker/model_runner_v1.py
vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 6 deletions
@@ -1116,9 +1116,7 @@ def forward(
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
-                    decode_q_pe.contiguous(),
-                    decode_k_pe,
-                    max_seq_len=attn_metadata.decode.max_seq_lens)
+                    decode_q_pe.contiguous(), decode_k_pe)
         if has_prefill:
             assert attn_metadata.prefill is not None
             prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
@@ -1150,9 +1148,7 @@ def forward(
             else:
                 prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                     attn_metadata.prefill.input_positions,
-                    prefill_q_pe.contiguous(),
-                    prefill_k_pe,
-                    max_seq_len=attn_metadata.prefill.max_seq_lens)
+                    prefill_q_pe.contiguous(), prefill_k_pe)
 
             assert len(
                 kv_cache
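Reading of the change above (a toy sketch, not the project's code): the call sites now pass only the positions and the q/k position-embedding slices; `max_seq_len` is gone because the sin/cos tables are sized once at init (see `rotary_embedding.py` below) rather than regrown per batch. The toy module below mirrors that call shape; all names and dimensions are illustrative:

```python
# Toy illustration of the updated call pattern; not vllm-ascend code.
import torch

class ToyRope(torch.nn.Module):
    def __init__(self, rotary_dim: int, max_seq_len: int):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
        t = torch.arange(max_seq_len, dtype=torch.float32)  # cache built once, bounded at init
        freqs = torch.outer(t, inv_freq)
        self.register_buffer("cos", torch.cat([freqs, freqs], dim=-1).cos())
        self.register_buffer("sin", torch.cat([freqs, freqs], dim=-1).sin())

    def forward(self, positions, q, k):
        cos, sin = self.cos[positions], self.sin[positions]
        rotate = lambda x: torch.cat([-x[..., x.shape[-1] // 2:], x[..., :x.shape[-1] // 2]], dim=-1)
        return q * cos + rotate(q) * sin, k * cos + rotate(k) * sin

rope = ToyRope(rotary_dim=64, max_seq_len=4096)
decode_q_pe, decode_k_pe = torch.randn(8, 64), torch.randn(8, 64)
positions = torch.arange(8)
# Same call shape as the diff: positions plus tensors, no max_seq_len argument.
decode_q_pe[...], decode_k_pe[...] = rope(positions, decode_q_pe.contiguous(), decode_k_pe)
```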

vllm_ascend/ops/rotary_embedding.py

Lines changed: 7 additions & 14 deletions
@@ -80,10 +80,7 @@ def native_rope_deepseek_forward(self,
                                  positions: torch.Tensor,
                                  query: torch.Tensor,
                                  key: torch.Tensor,
-                                 offsets: Optional[torch.Tensor] = None,
-                                 max_seq_len: Optional[int] = None):
-    if max_seq_len is not None and max_seq_len > self.max_seq_len:
-        _set_cos_sin_cache(self, max_seq_len, query.device, query.dtype)
+                                 offsets: Optional[torch.Tensor] = None):
     if len(key.shape) == 2:
         key = key[:, None, :]
     # Note: we implement the non neox_style method with shuffle the last dim and neox style
@@ -198,8 +195,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-def _set_cos_sin_cache(self, seq_len, device, dtype):
-    self.max_seq_len_cached = seq_len
+def _set_cos_sin_cache(self, max_seq_len, device, dtype):
     dim = self.rotary_dim
 
     freq_extra = 1.0 / (self.base**(
@@ -219,9 +215,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
     inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
     self.register_buffer("inv_freq", inv_freq, persistent=False)
 
-    t = torch.arange(seq_len * self.scaling_factor,
-                     device=device,
-                     dtype=torch.float32)
+    t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
 
     freqs = torch.outer(t, inv_freq)
     cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
@@ -266,11 +260,10 @@ def deepseek_rope_init_func(
     super(DeepseekScalingRotaryEmbedding,
           self).__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
-    self.max_seq_len = max_position_embeddings
-    _set_cos_sin_cache(self,
-                       max_position_embeddings,
-                       dtype=dtype,
-                       device="npu")
+
+    # NOTE: For ascend friendly computing, reorder sin and cos cache
+    self.max_seq_len = max_position_embeddings * scaling_factor
+    _set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu")
 
 
 RotaryEmbedding.forward_oot = rope_forward_oot
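Why this removes the OOM (back-of-the-envelope, with assumed dimensions): the deleted branch in `native_rope_deepseek_forward` rebuilt the cache at request time, and the old `_set_cos_sin_cache` allocated `t = arange(seq_len * self.scaling_factor)`, so a long chunked prefill triggered an allocation proportional to the request length times the scaling factor inside the forward pass. The new init sizes the cache once at `max_position_embeddings * scaling_factor` and never regrows it. The numbers below (rotary_dim = 64, max_position_embeddings = 4096, scaling_factor = 40) are illustrative, not taken from the commit:

```python
# Illustrative arithmetic only; dimensions are assumptions, not from this PR.
rotary_dim = 64                 # width of each cos/sin row
max_position_embeddings = 4096  # base context of the hypothetical model
scaling_factor = 40             # YaRN-style extension factor
bytes_per_elem = 4              # float32 cache entries

def cache_bytes(rows: int) -> int:
    # cos and sin caches, each of shape [rows, rotary_dim]
    return 2 * rows * rotary_dim * bytes_per_elem

# Old runtime regrowth: a 128k-token prefill rebuilt the cache with
# t = arange(max_seq_len * scaling_factor) inside the forward pass.
old_rows = 131072 * scaling_factor
# New behaviour: one fixed cache, bounded by max_position_embeddings * scaling_factor.
new_rows = max_position_embeddings * scaling_factor

print(f"regrown cache: {cache_bytes(old_rows) / 2**30:.1f} GiB")  # ~2.5 GiB per rebuild
print(f"fixed cache:   {cache_bytes(new_rows) / 2**20:.1f} MiB")  # ~80 MiB once at init
```

The intermediate tensors built inside `_set_cos_sin_cache` (`freqs` and the concatenated cos/sin buffers) add to the peak on top of these figures, which is why the per-request rebuild was the part that blew up on long inputs.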

vllm_ascend/worker/model_runner_v1.py

Lines changed: 9 additions & 5 deletions
@@ -350,6 +350,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
         self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
         self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
+        self.use_ring_mla = ascend_config.chunked_prefill_for_mla
 
         if ascend_config.torchair_graph_config.graph_batch_sizes_init:
             self.init_torchair_graph_batch_sizes()
@@ -913,11 +914,14 @@ def _process_reqs(
         else:
             attn_state = AscendAttentionState.PrefillCacheHit
 
-        attn_mask = self._make_attention_mask(seq_lens=seq_lens,
-                                              query_lens=num_scheduled_tokens,
-                                              position=positions,
-                                              attn_state=attn_state)
-        self.attn_mask = attn_mask
+        # NOTE: when use ring_mla, attn_mask don't need to generate here.
+        if not self.use_ring_mla or attn_state == AscendAttentionState.PrefillNoCache:
+            attn_mask = self._make_attention_mask(
+                seq_lens=seq_lens,
+                query_lens=num_scheduled_tokens,
+                position=positions,
+                attn_state=attn_state)
+            self.attn_mask = attn_mask
         self.attn_state = attn_state  # type: ignore
 
         extra_builder_kwargs = {}
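The runner-side change gates mask construction: with ring MLA (i.e. `chunked_prefill_for_mla`) enabled, only the `PrefillNoCache` state still builds a mask here. A standalone sketch of that condition follows; enum members other than the two named in the diff are assumptions for illustration:

```python
# Standalone sketch of the new gating; only PrefillNoCache and PrefillCacheHit
# appear in the diff, the remaining enum members are assumptions.
from enum import Enum, auto

class AscendAttentionState(Enum):
    PrefillNoCache = auto()
    PrefillCacheHit = auto()
    DecodeOnly = auto()       # assumed
    ChunkedPrefill = auto()   # assumed

def should_build_mask(use_ring_mla: bool, attn_state: AscendAttentionState) -> bool:
    # Mirrors: `if not self.use_ring_mla or attn_state == AscendAttentionState.PrefillNoCache`
    return not use_ring_mla or attn_state is AscendAttentionState.PrefillNoCache

for state in AscendAttentionState:
    print(f"ring_mla=True  {state.name:15s} -> build mask: {should_build_mask(True, state)}")
    print(f"ring_mla=False {state.name:15s} -> build mask: {should_build_mask(False, state)}")
```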
