@@ -201,6 +201,8 @@ def __init__(self, config: MoEDispatcherConfig):
        self.cached_global_input_tokens = None
        self.cached_shared_expert_output = None
        self.tokens_per_expert = None
+       self.perm1_finish_event = None
+       self.global_input_tokens_local_experts_indices = None

        if MoEAlltoAllSeqOverLapDispatcher.overlap_stream is None:
            MoEAlltoAllSeqOverLapDispatcher.overlap_stream = torch.npu.Stream()
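
Context note (not part of the diff): overlap_stream is a class-level attribute created lazily on first construction, so every dispatcher instance shares one side stream, and the two new attributes are simply pre-declared as None in __init__. A minimal sketch of that pattern, using torch.cuda.Stream as a stand-in for torch.npu.Stream (the real code needs the torch_npu Ascend plugin, and this sketch needs a CUDA device):

    import torch

    class OverlapDispatcherSketch:
        # Shared side stream for all instances; created once, on first __init__.
        overlap_stream = None

        def __init__(self):
            # Mirrors the lazily initialized attributes added in the diff.
            self.perm1_finish_event = None
            self.global_input_tokens_local_experts_indices = None
            if OverlapDispatcherSketch.overlap_stream is None:
                OverlapDispatcherSketch.overlap_stream = torch.cuda.Stream()
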
@@ -280,7 +282,7 @@ def preprocess(self,
                    "num_global_tokens_per_local_expert must be set before operations."
                )
            self.device_sync_point = "no_sync"
-           self.global_input_tokens_local_experts_indices: Tensor = torch.repeat_interleave(
+           self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
                self.expert_ids_per_ep_rank,
                self.num_global_tokens_per_local_expert.ravel())

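For reference (an illustrative sketch with made-up shapes and counts, not the dispatcher's real tensors): torch.repeat_interleave expands the per-(rank, local-expert) expert ids by the corresponding token counts, producing one expert index per received token.

    import torch

    # Hypothetical setup: two EP ranks x two local experts, flattened row-major as .ravel() does.
    expert_ids_per_ep_rank = torch.tensor([0, 1, 0, 1])
    num_global_tokens_per_local_expert = torch.tensor([[2, 1],   # rank 0: 2 tokens for expert 0, 1 for expert 1
                                                       [0, 3]])  # rank 1: 0 tokens for expert 0, 3 for expert 1

    indices = torch.repeat_interleave(expert_ids_per_ep_rank,
                                      num_global_tokens_per_local_expert.ravel())
    print(indices)  # tensor([0, 0, 1, 1, 1, 1]) -> one expert id per incoming token
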
@@ -426,7 +428,7 @@ def preprocess_and_permtute1(self,
            raise ValueError(
                "num_global_tokens_per_local_expert must be set before operations."
            )
-       self.global_input_tokens_local_experts_indices: Tensor = torch.repeat_interleave(
+       self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
            self.expert_ids_per_ep_rank,
            self.num_global_tokens_per_local_expert.ravel())

@@ -462,6 +464,7 @@ def permute2(self):
        global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
            self.cached_global_input_tokens,
            self.global_input_tokens_local_experts_indices)
+       assert self.cached_global_input_tokens is not None
        self.cached_global_input_tokens.untyped_storage().resize_(0)
        self.cached_global_input_tokens = None

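The added assert narrows cached_global_input_tokens to a non-None Tensor before its storage is released; the release itself is the usual early-free pattern of shrinking the backing storage to zero once the cached activation has been consumed. A rough illustration with a hypothetical tensor (not the dispatcher's state):

    import torch

    cached = torch.randn(8, 16)           # stands in for cached_global_input_tokens
    consumed = cached * 2                 # downstream op that consumes the cache into a new tensor
    assert cached is not None             # same narrowing as in the diff
    cached.untyped_storage().resize_(0)   # free the backing memory immediately
    cached = None                         # drop the remaining Python reference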