@@ -184,7 +184,6 @@ def flatten_tp_across_dp(dp_rank: int):
 # Adapted from pplx-kernels tests/all_to_all_utils.py
 @dataclass
 class MoEConfig:
-    max_num_tokens: int
     num_experts: int
     experts_per_token: int
     hidden_dim: int
@@ -347,33 +346,18 @@ def select_gemm_impl(
         all2all_manager = get_ep_group().device_communicator.all2all_manager
         assert all2all_manager is not None
 
-        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
-
         if isinstance(prepare_finalize,
                       (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
             logger.debug("BatchedTritonExperts %s", self.moe)
-            experts = BatchedTritonExperts(
+            return BatchedTritonExperts(
                 max_num_tokens=MOE_DP_CHUNK_SIZE,
                 world_size=all2all_manager.world_size,
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
-                use_fp8_w8a8=False,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                block_shape=None,
             )
         else:
             logger.debug("TritonExperts %s", self.moe)
-            experts = TritonExperts(
-                use_fp8_w8a8=False,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                block_shape=None,
-                per_channel_quant=False,
-            )
-        return experts
+            return TritonExperts()
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -472,35 +456,6 @@ def apply(
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input)
 
-    def set_prepare_finalize(
-        self,
-        dp_size: int,
-        world_size: int,
-        prepare_finalize: FusedMoEPrepareAndFinalize,
-    ) -> bool:
-        assert self.fused_experts == fused_experts
-
-        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
-
-        if isinstance(prepare_finalize,
-                      (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
-            logger.debug("BatchedTritonExperts %s", self.moe)
-            experts = BatchedTritonExperts(
-                max_num_tokens=MOE_DP_CHUNK_SIZE,
-                world_size=world_size,
-                dp_size=dp_size,
-            )
-        else:
-            logger.debug("TritonExperts %s", self.moe)
-            experts = TritonExperts()
-
-        self.fused_experts = FusedMoEModularKernel(
-            prepare_finalize,
-            experts,
-        )
-
-        return True
-
     def forward_cuda(
         self,
         layer: torch.nn.Module,
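
Note: the two hunks above make `select_gemm_impl` return the experts implementation directly and delete `set_prepare_finalize`, whose only job was to wrap the experts object and the prepare/finalize object into a `FusedMoEModularKernel`. Below is a minimal sketch of how a caller might now perform that composition; the call site, the `select_gemm_impl` argument list, and the import path are assumptions for illustration, not taken from this diff.

```python
# Illustrative sketch only: mirrors the FusedMoEModularKernel(prepare_finalize,
# experts) wiring from the removed set_prepare_finalize(). Where this happens
# after the change, and the exact select_gemm_impl() signature, are assumed.
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)


def build_fused_experts(method, prepare_finalize):
    # select_gemm_impl() now returns BatchedTritonExperts or TritonExperts
    # directly instead of assigning to a local and returning it at the end.
    experts = method.select_gemm_impl(prepare_finalize)
    # Compose the prepare/finalize object with the experts kernel, as the
    # removed set_prepare_finalize() used to do.
    return FusedMoEModularKernel(prepare_finalize, experts)
```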
@@ -815,16 +770,14 @@ def __init__(
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
 
-        logger.debug("Model dtype = %s", vllm_config.model_config.dtype)
-
         moe = MoEConfig(
-            max_num_tokens=MOE_DP_CHUNK_SIZE,
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
-            in_dtype=moe.in_dtype,
+            in_dtype=vllm_config.model_config.dtype,
+            max_num_tokens=MOE_DP_CHUNK_SIZE,
         )
         self.moe_config = moe
         self.quant_config = quant_config
@@ -1281,7 +1234,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
 
             assert (self.batched_hidden_states.size(0)  # type: ignore
                     >= chunk_size)
-            assert (self.batched_router_logits.size(0) # type: ignore
+            assert (self.batched_router_logits.size(0)  # type: ignore
                     >= chunk_size)
             staged_hidden_states = self.batched_hidden_states[:
                                                               chunk_size, :]  # type: ignore
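
For context on the hunk above: `process_chunk` stages each chunk of tokens into preallocated `batched_hidden_states` / `batched_router_logits` buffers, so those buffers must be at least one chunk wide. Below is a minimal, self-contained sketch of that pattern; the buffer shapes and the chunk size value are illustrative, not taken from this diff.

```python
# Minimal sketch of the chunking pattern guarded by the assertions above.
# MOE_DP_CHUNK_SIZE, hidden_dim and num_tokens are illustrative values.
import torch

MOE_DP_CHUNK_SIZE = 256
hidden_dim = 1024
num_tokens = 1000

# Preallocated staging buffer sized for one chunk of tokens.
batched_hidden_states = torch.empty(MOE_DP_CHUNK_SIZE, hidden_dim)


def process_chunk(hidden_states: torch.Tensor, chunk_start: int,
                  chunk_end: int) -> None:
    chunk_size = chunk_end - chunk_start
    # The staging buffer must hold at least one chunk (the asserts above).
    assert batched_hidden_states.size(0) >= chunk_size
    staged_hidden_states = batched_hidden_states[:chunk_size, :]
    staged_hidden_states.copy_(hidden_states[chunk_start:chunk_end])
    # ... run the fused MoE kernel on staged_hidden_states ...


hidden_states = torch.randn(num_tokens, hidden_dim)
for start in range(0, num_tokens, MOE_DP_CHUNK_SIZE):
    process_chunk(hidden_states, start,
                  min(start + MOE_DP_CHUNK_SIZE, num_tokens))
```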