@@ -362,8 +362,10 @@ def apply(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         return self.forward(
             x=x,
@@ -380,7 +382,12 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
             activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            enable_eplb=enable_eplb,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
+        )
 
     def forward_cuda(
         self,
@@ -399,6 +406,10 @@ def forward_cuda(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -412,7 +423,11 @@ def forward_cuda(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype)
+            indices_type=self.topk_indices_dtype,
+            enable_eplb=enable_eplb,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count)
 
         if self.rocm_aiter_moe_enabled:
             return self.rocm_aiter_fused_experts(
@@ -753,7 +768,8 @@ def __init__(
         if self.enable_eplb:
             from vllm.model_executor.layers.quantization.fp8 import (
                 Fp8MoEMethod)
-            if not isinstance(quant_method, Fp8MoEMethod):
+            if not isinstance(quant_method, Fp8MoEMethod) and not isinstance(
+                    quant_method, UnquantizedFusedMoEMethod):
                 # TODO: Add support for additional quantization methods.
                 # The implementation for other quantization methods does not
                 # contain essential differences, but the current quant API
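For context, a minimal, self-contained sketch of the argument plumbing this change enables. It is not vLLM code: select_experts_stub, the shapes, and the load-accounting line are made up for illustration; the real FusedMoE.select_experts also remaps logical expert ids to physical replicas. The point it shows is that when enable_eplb is set, the caller must supply expert_load_view, logical_to_physical_map, and logical_replica_count, which the unquantized method now threads through to expert selection.

from typing import Optional

import torch


def select_experts_stub(
        router_logits: torch.Tensor,
        top_k: int,
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
):
    # Plain top-k routing; a stand-in for FusedMoE.select_experts.
    topk_weights, topk_ids = torch.topk(
        router_logits.softmax(dim=-1), top_k, dim=-1)
    if enable_eplb:
        # The EPLB path requires the load-tracking buffer and the
        # logical -> physical replica tables; mirror the asserts in the diff.
        assert expert_load_view is not None
        assert logical_to_physical_map is not None
        assert logical_replica_count is not None
        # Record how many tokens were routed to each expert
        # (illustrative only; the real path also does replica remapping).
        expert_load_view.scatter_add_(
            0, topk_ids.flatten(),
            torch.ones_like(topk_ids.flatten(), dtype=expert_load_view.dtype))
    return topk_weights, topk_ids


# Example usage with hypothetical shapes: 4 tokens, 8 experts, top-2 routing.
logits = torch.randn(4, 8)
load = torch.zeros(8, dtype=torch.int64)
l2p = torch.arange(8).unsqueeze(-1)        # one physical replica per expert
replicas = torch.ones(8, dtype=torch.int64)
weights, ids = select_experts_stub(
    logits, top_k=2, enable_eplb=True,
    expert_load_view=load,
    logical_to_physical_map=l2p,
    logical_replica_count=replicas)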