Commit ba11765

Author: weijinqian

handle clean code

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>

1 parent 57081d2

File tree

3 files changed: +16 -3 lines changed

vllm_ascend/models/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -8,7 +8,6 @@ def register_model():
     from .deepseek_mtp import CustomDeepSeekMTP  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV3ForCausalLM  # noqa: F401
-    from .moe_block import AscendSparseMoeBlock  # noqa: F401
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401

vllm_ascend/multistream/ms_split.py

Lines changed: 2 additions & 0 deletions

@@ -304,12 +304,14 @@ def model_input_split_v1_attn(
         # the attn_mla kernel in torch npu only accept 128*128 attn mask
         attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
         attn_state_pre = attn_state_post = attn_metadata.attn_state
+
     elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
         # should be none in decode only state
         attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
         attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
     else:
         # chunked prefill
+        assert attn_metadata.attn_mask is not None
         if has_prefill_pre:
             attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
             attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
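
The added assert narrows the Optional attention mask before the chunked-prefill branch slices it. As a rough standalone sketch of that slicing pattern (the shapes and split point here are invented for illustration, not taken from vllm-ascend):

import torch

# Hypothetical 8-token causal mask, split at token_index into a "pre"
# micro-batch and a "post" remainder, mirroring the slicing above.
attn_mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))
token_index = 5

# The guard added in this commit: the mask must exist before slicing.
assert attn_mask is not None

attn_mask_pre = attn_mask[:token_index, :token_index]  # first micro-batch
attn_mask_post = attn_mask[token_index:, :]            # remaining rows
print(attn_mask_pre.shape, attn_mask_post.shape)
# torch.Size([5, 5]) torch.Size([3, 8])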

vllm_ascend/ops/moe_dispatcher/token_dispatcher.py

Lines changed: 14 additions & 2 deletions

@@ -24,6 +24,7 @@

 import torch
 import torch_npu
+from torch import Tensor
 from vllm.distributed.parallel_state import get_ep_group

 from vllm_ascend.distributed.tensor_parallel import (
@@ -279,7 +280,7 @@ def preprocess(self,
                 "num_global_tokens_per_local_expert must be set before operations."
             )
         self.device_sync_point = "no_sync"
-        self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
+        self.global_input_tokens_local_experts_indices: Tensor = torch.repeat_interleave(
             self.expert_ids_per_ep_rank,
             self.num_global_tokens_per_local_expert.ravel())

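The annotated line computes, for every token received over AlltoAll, the index of the local expert that will process it. A self-contained sketch of the repeat_interleave pattern (the counts and sizes are invented for illustration):

import torch

# Hypothetical setup: 2 EP ranks, 2 local experts per rank.
ep_size, num_local_experts = 2, 2
# Expert id of each (rank, local-expert) slot: tensor([0, 1, 0, 1])
expert_ids_per_ep_rank = torch.remainder(
    torch.arange(ep_size * num_local_experts), num_local_experts)
# Tokens each EP rank sends to each local expert on this rank.
num_global_tokens_per_local_expert = torch.tensor([[2, 1],
                                                   [0, 3]])

# Expand the counts into one expert index per received token.
indices = torch.repeat_interleave(
    expert_ids_per_ep_rank, num_global_tokens_per_local_expert.ravel())
print(indices)  # tensor([0, 0, 1, 1, 1, 1])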
@@ -314,6 +315,7 @@ def token_permutation(

         # Permutation 1: input to AlltoAll input
         def alltoall_token_permutation1(hidden_states, routing_map):
+            assert self.hidden_shape is not None
             hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
             tokens_per_expert = self.preprocess(routing_map)
             if self.tp_ep_size > 1:
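
The new assert is mainly for the type checker: self.hidden_shape is presumably Optional until token_permutation records it, so mypy cannot prove the view is safe without narrowing. A minimal sketch of the pattern, with hypothetical class and method names:

from typing import Optional

import torch

class Dispatcher:
    def __init__(self) -> None:
        self.hidden_shape: Optional[torch.Size] = None

    def permute(self, hidden_states: torch.Tensor) -> torch.Tensor:
        self.hidden_shape = hidden_states.shape
        # Narrows Optional[torch.Size] to torch.Size for mypy and
        # fails fast at runtime if the invariant is ever broken.
        assert self.hidden_shape is not None
        return hidden_states.view(-1, self.hidden_shape[-1])

print(Dispatcher().permute(torch.zeros(2, 3, 4)).shape)  # torch.Size([6, 4])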
@@ -390,6 +392,7 @@ def preprocess_and_permtute1(self,
         self.top_indices = routing_map
         assert probs.dim() == 2, "Expected 2D tensor for probs"
         assert routing_map.dim() == 2, "Expected 2D tensor for routing map"
+        assert self.hidden_shape is not None

         hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
         tokens_per_expert = self.preprocess(routing_map, with_sync=False)
@@ -401,6 +404,7 @@ def preprocess_and_permtute1(self,
         event = torch.npu.current_stream().record_event()
         self.perm1_finish_event = torch.npu.Event()
         with torch.npu.stream(self.overlap_stream):
+            assert self.overlap_stream is not None
             self.overlap_stream.wait_event(event)

         if shared_experts is not None:
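
For readers unfamiliar with the overlap stream: the event recorded on the current stream orders the side stream's work after permutation 1. A CUDA sketch of the same handoff (torch.npu mirrors torch.cuda's stream/event API, so this is an analogy, not the Ascend code):

import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    x = torch.randn(1024, device="cuda")
    # Mark the point the side stream must not run ahead of.
    event = torch.cuda.current_stream().record_event()
    with torch.cuda.stream(side_stream):
        side_stream.wait_event(event)  # wait for work recorded above
        y = x * 2                      # now safe to consume x here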
@@ -418,7 +422,11 @@ def preprocess_and_permtute1(self,
         # repeat interleve will launch a sync on current_stream.
         if self.num_local_experts > 1:
             self.device_sync_point = "no_sync"
-            self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
+            if self.num_global_tokens_per_local_expert is None:
+                raise ValueError(
+                    "num_global_tokens_per_local_expert must be set before operations."
+                )
+            self.global_input_tokens_local_experts_indices: Tensor = torch.repeat_interleave(
                 self.expert_ids_per_ep_rank,
                 self.num_global_tokens_per_local_expert.ravel())

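Unlike the asserts elsewhere in this commit, this guard raises ValueError, so it still runs under python -O (which strips asserts) while also narrowing the Optional for mypy. A sketch of the trade-off, with a hypothetical helper:

from typing import Optional

import torch

def ravel_counts(counts: Optional[torch.Tensor]) -> torch.Tensor:
    # An explicit raise survives `python -O`, unlike an assert, and it
    # narrows Optional[Tensor] to Tensor for the type checker as well.
    if counts is None:
        raise ValueError(
            "num_global_tokens_per_local_expert must be set before operations.")
    return counts.ravel()

print(ravel_counts(torch.tensor([[2, 1], [0, 3]])))  # tensor([2, 1, 0, 3])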
@@ -441,6 +449,10 @@ def dispatch_alltoall(self):
             ep_group,
         )
         permute1_ep_all_to_all_handle.wait()
+        if self.cached_permutated_local_input_tokens is None:
+            raise ValueError(
+                "cached_permutated_local_input_tokens must be set before operations."
+            )
         self.cached_permutated_local_input_tokens.untyped_storage().resize_(0)
         self.cached_permutated_local_input_tokens = None
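The guard protects the storage-release idiom on the next line: resizing a tensor's untyped storage to zero frees the backing memory eagerly instead of waiting for the tensor to be garbage collected. A standalone illustration:

import torch

buf = torch.empty(1024)
# Frees the backing memory immediately; if buf were None, this line
# would raise AttributeError, hence the ValueError guard above.
buf.untyped_storage().resize_(0)
buf = None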
