
Commit e657c3e

Add EPLB support for UnquantizedFusedMoEMethod
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
1 parent: 99a3c18

File tree

  • vllm/model_executor/layers/fused_moe

1 file changed: 21 additions, 5 deletions


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 21 additions & 5 deletions
@@ -362,8 +362,10 @@ def apply(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         return self.forward(
             x=x,
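
Note: with EPLB enabled, the method no longer raises and instead requires the load-balancing state that the FusedMoE layer threads through. A minimal sketch of what that state could look like, with illustrative shapes and a made-up replication pattern (an assumption for exposition, not vLLM's exact layout):

import torch

# Toy dimensions: 4 logical experts, at most 2 physical replicas each,
# 6 physical expert slots in total. Purely illustrative.
num_logical, max_replicas, num_physical = 4, 2, 6

# Per-physical-expert token counter that EPLB updates during routing.
expert_load_view = torch.zeros(num_physical, dtype=torch.int64)

# Row i lists the physical slots serving logical expert i; -1 pads unused slots.
logical_to_physical_map = torch.full((num_logical, max_replicas), -1,
                                     dtype=torch.int64)
logical_to_physical_map[:, 0] = torch.arange(num_logical)
logical_to_physical_map[:2, 1] = torch.tensor([4, 5])  # experts 0 and 1 replicated

# Number of live replicas per logical expert.
logical_replica_count = torch.tensor([2, 2, 1, 1], dtype=torch.int64)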
@@ -380,7 +382,12 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
             activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            enable_eplb=enable_eplb,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
+        )
 
     def forward_cuda(
         self,
@@ -399,6 +406,10 @@ def forward_cuda(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -412,7 +423,11 @@ def forward_cuda(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype)
+            indices_type=self.topk_indices_dtype,
+            enable_eplb=enable_eplb,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count)
 
         if self.rocm_aiter_moe_enabled:
             return self.rocm_aiter_fused_experts(
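
The same EPLB state is now forwarded to FusedMoE.select_experts, which is where logical expert ids get mapped onto physical replicas and per-expert load is recorded. A conceptual sketch of that bookkeeping, continuing the toy tensors defined above (redirect_and_record is a hypothetical helper written for illustration, not a vLLM function):

import torch

def redirect_and_record(topk_ids, logical_to_physical_map,
                        logical_replica_count, expert_load_view):
    # Choose one replica per routed token (uniformly at random here;
    # the real selection policy may differ).
    replica_idx = (torch.randint_like(topk_ids, high=2**31 - 1)
                   % logical_replica_count[topk_ids])
    # Translate logical expert ids into physical expert slots.
    physical_ids = logical_to_physical_map[topk_ids, replica_idx]
    # Tally how many tokens each physical expert received.
    expert_load_view.scatter_add_(0, physical_ids.flatten(),
                                  torch.ones_like(physical_ids.flatten()))
    return physical_ids

# Toy routing: 2 tokens, top-2 experts each.
topk_ids = torch.tensor([[0, 2], [1, 0]])
physical_ids = redirect_and_record(topk_ids, logical_to_physical_map,
                                   logical_replica_count, expert_load_view)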
@@ -753,7 +768,8 @@ def __init__(
         if self.enable_eplb:
             from vllm.model_executor.layers.quantization.fp8 import (
                 Fp8MoEMethod)
-            if not isinstance(quant_method, Fp8MoEMethod):
+            if not isinstance(quant_method, Fp8MoEMethod) and not isinstance(
+                    quant_method, UnquantizedFusedMoEMethod):
                 # TODO: Add support for additional quantization methods.
                 # The implementation for other quantization methods does not
                 # contain essential differences, but the current quant API
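
Design note: the widened guard could equivalently use isinstance's tuple form, isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)), which avoids the wrapped two-clause condition; the commit keeps the explicit and-chain, and both behave the same.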
