```diff
@@ -304,25 +304,24 @@ def model_input_split_v1_attn(
         # the attn_mla kernel in torch npu only accept 128*128 attn mask
         attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
         attn_state_pre = attn_state_post = attn_metadata.attn_state
-
     elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
         # should be none in decode only state
         attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
-        attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly  # noqa
+        attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly  # type: ignore
     else:
         # chunked prefill
         assert attn_metadata.attn_mask is not None
         if has_prefill_pre:
-            attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill  # noqa
+            attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill  # type: ignore
             attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
                 seq_lens_pre)].contiguous()
-            attn_state_post = AscendAttentionState.ChunkedPrefill  # noqa
+            attn_state_post = AscendAttentionState.ChunkedPrefill  # type: ignore
             attn_mask_post = attn_metadata.attn_mask[
                 token_index:, :max(seq_lens_post)].contiguous()
         else:
-            attn_state_pre = AscendAttentionState.DecodeOnly  # noqa
+            attn_state_pre = AscendAttentionState.DecodeOnly  # type: ignore
             attn_mask_pre = None
-            attn_state_post = AscendAttentionState.ChunkedPrefill  # noqa
+            attn_state_post = AscendAttentionState.ChunkedPrefill  # type: ignore
             attn_mask_post = attn_metadata.attn_mask[
                 token_index:, :max(seq_lens_post)].contiguous()
```
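For context, here is a minimal, self-contained sketch (not part of the PR) of what the chunked-prefill branch above does: it splits a shared attention mask at `token_index` into a "pre" slice and a "post" slice, each truncated to its own maximum sequence length. The mask shape, `token_index`, `seq_lens_pre`, and `seq_lens_post` values below are illustrative assumptions, not values taken from vllm-ascend.

```python
import torch

# Hypothetical inputs for illustration only.
token_index = 4          # split point along the token dimension
seq_lens_pre = [3, 4]    # per-request sequence lengths before the split
seq_lens_post = [2, 5]   # per-request sequence lengths after the split

# Assume a (num_tokens, max_seq_len) boolean mask, as the slicing in the diff implies.
attn_mask = torch.ones(10, 8, dtype=torch.bool)

# Mirrors `attn_metadata.attn_mask[:token_index, :max(seq_lens_pre)].contiguous()`
attn_mask_pre = attn_mask[:token_index, :max(seq_lens_pre)].contiguous()
# Mirrors `attn_metadata.attn_mask[token_index:, :max(seq_lens_post)].contiguous()`
attn_mask_post = attn_mask[token_index:, :max(seq_lens_post)].contiguous()

print(attn_mask_pre.shape)   # torch.Size([4, 4])
print(attn_mask_post.shape)  # torch.Size([6, 5])
```

`.contiguous()` is used so each slice owns its own densely laid-out storage rather than a strided view of the original mask, which kernels generally expect.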