import torch
import torch_npu
+ from torch import Tensor
from vllm.distributed.parallel_state import get_ep_group

from vllm_ascend.distributed.tensor_parallel import (
@@ -279,7 +280,7 @@ def preprocess(self,
                "num_global_tokens_per_local_expert must be set before operations."
            )
            self.device_sync_point = "no_sync"
-           self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
+           self.global_input_tokens_local_experts_indices: Tensor = torch.repeat_interleave(
                self.expert_ids_per_ep_rank,
                self.num_global_tokens_per_local_expert.ravel())
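For context, a minimal sketch (not part of the patch, with made-up counts) of what this `repeat_interleave` call produces: it expands the per-(rank, local expert) token counts into one local-expert id per received token.

    import torch

    # Hypothetical setup: 2 EP ranks x 2 local experts, so the expert-id
    # pattern repeats once per rank, and the count matrix says how many
    # tokens each rank sends to each local expert.
    expert_ids_per_ep_rank = torch.tensor([0, 1, 0, 1])
    num_global_tokens_per_local_expert = torch.tensor([[2, 1], [0, 3]])

    indices = torch.repeat_interleave(
        expert_ids_per_ep_rank,
        num_global_tokens_per_local_expert.ravel())
    # indices == tensor([0, 0, 1, 1, 1, 1]): one local-expert id per token,
    # which is what global_input_tokens_local_experts_indices holds.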
@@ -314,6 +315,7 @@ def token_permutation(

        # Permutation 1: input to AlltoAll input
        def alltoall_token_permutation1(hidden_states, routing_map):
+           assert self.hidden_shape is not None
            hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
            tokens_per_expert = self.preprocess(routing_map)
            if self.tp_ep_size > 1:
@@ -390,6 +392,7 @@ def preprocess_and_permtute1(self,
        self.top_indices = routing_map
        assert probs.dim() == 2, "Expected 2D tensor for probs"
        assert routing_map.dim() == 2, "Expected 2D tensor for routing map"
+       assert self.hidden_shape is not None

        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
        tokens_per_expert = self.preprocess(routing_map, with_sync=False)
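Note (an assumption about intent): these `assert ... is not None` lines mainly narrow Optional attributes for mypy before they are indexed. A stripped-down sketch with a hypothetical class:

    from typing import Optional

    import torch

    class _Sketch:
        hidden_shape: Optional[torch.Size] = None

        def _flatten(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Without the assert, mypy flags `self.hidden_shape[-1]` because
            # Optional[torch.Size] is not indexable; the assert narrows it.
            assert self.hidden_shape is not None
            return hidden_states.view(-1, self.hidden_shape[-1])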
@@ -401,6 +404,7 @@ def preprocess_and_permtute1(self,
        event = torch.npu.current_stream().record_event()
        self.perm1_finish_event = torch.npu.Event()
        with torch.npu.stream(self.overlap_stream):
+           assert self.overlap_stream is not None
            self.overlap_stream.wait_event(event)

            if shared_experts is not None:
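For readers unfamiliar with the stream pattern in this hunk, a minimal sketch (an assumption mirroring the calls above, not code from this PR): an event recorded on the current stream gates the work queued on the secondary stream, so the permuted-token and shared-expert work can overlap safely.

    import torch
    import torch_npu  # noqa: F401  (provides the torch.npu namespace)

    overlap_stream = torch.npu.Stream()
    event = torch.npu.current_stream().record_event()
    with torch.npu.stream(overlap_stream):
        # Kernels launched here run on overlap_stream, but only after the
        # work captured by `event` on the default stream has completed.
        overlap_stream.wait_event(event)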
@@ -418,7 +422,11 @@ def preprocess_and_permtute1(self,
        # repeat_interleave will launch a sync on current_stream.
        if self.num_local_experts > 1:
            self.device_sync_point = "no_sync"
-           self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
+           if self.num_global_tokens_per_local_expert is None:
+               raise ValueError(
+                   "num_global_tokens_per_local_expert must be set before operations."
+               )
+           self.global_input_tokens_local_experts_indices: Tensor = torch.repeat_interleave(
                self.expert_ids_per_ep_rank,
                self.num_global_tokens_per_local_expert.ravel())
@@ -441,6 +449,10 @@ def dispatch_alltoall(self):
            ep_group,
        )
        permute1_ep_all_to_all_handle.wait()
+       if self.cached_permutated_local_input_tokens is None:
+           raise ValueError(
+               "cached_permutated_local_input_tokens must be set before operations."
+           )
        self.cached_permutated_local_input_tokens.untyped_storage().resize_(0)
        self.cached_permutated_local_input_tokens = None
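One more note (an illustration, not from the patch): `untyped_storage().resize_(0)` releases the cached tensor's device memory immediately instead of waiting for the Python reference to be dropped, which is why the None check has to sit right before it.

    import torch

    t = torch.empty(1024)            # stand-in for the cached permuted tokens
    t.untyped_storage().resize_(0)   # frees the backing storage right away
    t = None                         # the Python reference is cleared afterwards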