
Commit eaed83d

Author: weijinqian

handle clean code
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
1 parent: e4f1050

2 files changed: 10 additions and 7 deletions

vllm_ascend/multistream/ms_split.py

Lines changed: 5 additions & 5 deletions
@@ -308,21 +308,21 @@ def model_input_split_v1_attn(
     elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
         # should be none in decode only state
         attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
-        attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
+        attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly  # noqa
     else:
         # chunked prefill
         assert attn_metadata.attn_mask is not None
         if has_prefill_pre:
-            attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
+            attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill  # noqa
             attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
                 seq_lens_pre)].contiguous()
-            attn_state_post = AscendAttentionState.ChunkedPrefill
+            attn_state_post = AscendAttentionState.ChunkedPrefill  # noqa
             attn_mask_post = attn_metadata.attn_mask[
                 token_index:, :max(seq_lens_post)].contiguous()
         else:
-            attn_state_pre = AscendAttentionState.DecodeOnly
+            attn_state_pre = AscendAttentionState.DecodeOnly  # noqa
             attn_mask_pre = None
-            attn_state_post = AscendAttentionState.ChunkedPrefill
+            attn_state_post = AscendAttentionState.ChunkedPrefill  # noqa
             attn_mask_post = attn_metadata.attn_mask[
                 token_index:, :max(seq_lens_post)].contiguous()

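The ms_split.py change only appends `# noqa` markers, presumably to silence a lint warning (such as line length) on those assignments; the surrounding logic that splits the attention mask at token_index is unchanged. A minimal sketch of that split, with made-up shapes and lengths (attn_mask, token_index, seq_lens_pre and seq_lens_post are assumed values, not taken from the commit):

import torch

attn_mask = torch.ones(8, 16, dtype=torch.bool)    # assumed full mask: 8 tokens x 16 positions
token_index = 3                                     # assumed split point between micro-batches
seq_lens_pre, seq_lens_post = [4, 5, 6], [10, 12]   # assumed per-request sequence lengths

# The "pre" micro-batch keeps rows before token_index, "post" keeps the rest;
# each side is cropped to its own longest sequence and made contiguous.
attn_mask_pre = attn_mask[:token_index, :max(seq_lens_pre)].contiguous()
attn_mask_post = attn_mask[token_index:, :max(seq_lens_post)].contiguous()
print(attn_mask_pre.shape, attn_mask_post.shape)    # torch.Size([3, 6]) torch.Size([5, 12])
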
vllm_ascend/ops/moe_dispatcher/token_dispatcher.py

Lines changed: 5 additions & 2 deletions
@@ -201,6 +201,8 @@ def __init__(self, config: MoEDispatcherConfig):
         self.cached_global_input_tokens = None
         self.cached_shared_expert_output = None
         self.tokens_per_expert = None
+        self.perm1_finish_event = None
+        self.global_input_tokens_local_experts_indices = None
 
         if MoEAlltoAllSeqOverLapDispatcher.overlap_stream is None:
             MoEAlltoAllSeqOverLapDispatcher.overlap_stream = torch.npu.Stream()
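A hedged note on the two added lines: initialising perm1_finish_event and global_input_tokens_local_experts_indices to None in __init__ guarantees the attributes exist before the preprocess and permute steps run, and gives type checkers a single declaration site. A minimal sketch of the pattern, with invented class and method names:

from typing import Optional

import torch


class DispatcherSketch:
    def __init__(self) -> None:
        # Declared once up front: the field always exists, even if preprocess()
        # has not run yet, and its type is a single Optional[Tensor].
        self.global_input_tokens_local_experts_indices: Optional[torch.Tensor] = None
        self.perm1_finish_event = None

    def preprocess(self) -> None:
        # Later methods just assign; no repeated annotation at the use site.
        self.global_input_tokens_local_experts_indices = torch.arange(4)
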
@@ -280,7 +282,7 @@ def preprocess(self,
                 "num_global_tokens_per_local_expert must be set before operations."
             )
         self.device_sync_point = "no_sync"
-        self.global_input_tokens_local_experts_indices: Tensor = torch.repeat_interleave(
+        self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
             self.expert_ids_per_ep_rank,
             self.num_global_tokens_per_local_expert.ravel())

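For context, torch.repeat_interleave with a tensor of repeat counts is what builds the per-token local-expert index here: each expert id is repeated once for every token routed to it. A small sketch with made-up values (the real tensors come from the dispatcher's routing metadata):

import torch

expert_ids_per_ep_rank = torch.tensor([0, 1, 2])                 # assumed local expert ids
num_global_tokens_per_local_expert = torch.tensor([[2, 0, 3]])   # assumed token counts per expert

indices = torch.repeat_interleave(
    expert_ids_per_ep_rank,
    num_global_tokens_per_local_expert.ravel())
print(indices)  # tensor([0, 0, 2, 2, 2]) -> local expert id for every routed token
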
@@ -426,7 +428,7 @@ def preprocess_and_permtute1(self,
             raise ValueError(
                 "num_global_tokens_per_local_expert must be set before operations."
             )
-        self.global_input_tokens_local_experts_indices: Tensor = torch.repeat_interleave(
+        self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
             self.expert_ids_per_ep_rank,
             self.num_global_tokens_per_local_expert.ravel())

@@ -462,6 +464,7 @@ def permute2(self):
         global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
             self.cached_global_input_tokens,
             self.global_input_tokens_local_experts_indices)
+        assert self.cached_global_input_tokens is not None
         self.cached_global_input_tokens.untyped_storage().resize_(0)
         self.cached_global_input_tokens = None

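The added assert in permute2 documents that cached_global_input_tokens must still be populated at this point, and it lets a static checker narrow the Optional attribute before its storage is released. A minimal sketch of the pattern with a stand-in tensor (names and shapes are assumed, not taken from the commit):

from typing import Optional

import torch

cached: Optional[torch.Tensor] = torch.randn(4, 8)   # stand-in for cached_global_input_tokens

assert cached is not None                  # narrows Optional[Tensor] to Tensor for type checkers
cached.untyped_storage().resize_(0)        # free the underlying buffer eagerly (torch >= 2.0 API)
cached = None                              # then drop the Python reference
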