[WIP][EPLB] Enable Llama4 EPLB #20901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: wants to merge 6 commits into base: main
Changes from all commits
46 changes: 46 additions & 0 deletions test_llama4_eplb.py
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)


def main():
    # Create an LLM with EPLB parameters.
    llm = LLM(
        model="/fp8-llama/llama4scout-fp8/",
        tensor_parallel_size=8,
        max_model_len=2048,
        enable_expert_parallel=True,
        enable_eplb=True,
        num_redundant_experts=16,
        eplb_window_size=1000,
        eplb_step_interval=3000,
        trust_remote_code=True,
        enforce_eager=True,
    )
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}")
        print(f"Output: {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()
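
For context on the EPLB-specific arguments exercised above, here is a sketch of the same LLM(...) construction with descriptive comments. The parameter descriptions are a best-effort summary inferred from the argument names and the existing EPLB work in vLLM, not text taken from this PR.

# Sketch only, not part of the PR: the EPLB-related knobs with comments.
from vllm import LLM

llm = LLM(
    model="/fp8-llama/llama4scout-fp8/",  # local FP8 Llama 4 Scout checkpoint
    tensor_parallel_size=8,
    max_model_len=2048,
    enable_expert_parallel=True,  # expert parallelism is a prerequisite for EPLB
    enable_eplb=True,             # turn on the Expert Parallelism Load Balancer
    num_redundant_experts=16,     # extra physical expert slots used for replication
    eplb_window_size=1000,        # window (in steps) of expert-load statistics
    eplb_step_interval=3000,      # how often (in steps) experts are rearranged
    trust_remote_code=True,
    enforce_eager=True,
)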
37 changes: 19 additions & 18 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -362,8 +362,10 @@ def apply(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
Comment on lines +365 to +367
Contributor review comment (severity: medium):

Using assert for validating conditionally required arguments can be risky, as assertions can be disabled with the -O flag in Python. This could lead to silent failures in production if the necessary arguments aren't passed when enable_eplb is true. It's more robust to raise a ValueError with a clear message.

Suggested change:
-            assert expert_load_view is not None
-            assert logical_to_physical_map is not None
-            assert logical_replica_count is not None
+            if any(p is None for p in (expert_load_view,
+                                       logical_to_physical_map,
+                                       logical_replica_count)):
+                raise ValueError(
+                    "EPLB is enabled, but required arguments are missing.")

+            assert isinstance(layer, FusedMoE)
 
         return self.forward(
             x=x,
@@ -380,7 +382,12 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
             activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            enable_eplb=enable_eplb,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
+        )
 
     def forward_cuda(
         self,
@@ -399,6 +406,10 @@ def forward_cuda(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -412,7 +423,11 @@ def forward_cuda(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype)
+            indices_type=self.topk_indices_dtype,
+            enable_eplb=enable_eplb,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count)
 
         if self.rocm_aiter_moe_enabled:
             return self.rocm_aiter_fused_experts(
@@ -750,20 +765,6 @@ def __init__(
         assert isinstance(quant_method, FusedMoEMethodBase)
         self.quant_method = quant_method
 
-        if self.enable_eplb:
-            from vllm.model_executor.layers.quantization.fp8 import (
-                Fp8MoEMethod)
-            if not isinstance(quant_method, Fp8MoEMethod):
-                # TODO: Add support for additional quantization methods.
-                # The implementation for other quantization methods does not
-                # contain essential differences, but the current quant API
-                # design causes duplicated work when extending to new
-                # quantization methods, so I'm leaving it for now.
-                # If you plan to add support for more quantization methods,
-                # please refer to the implementation in `Fp8MoEMethod`.
-                raise NotImplementedError("EPLB is only supported for FP8 "
-                                          "quantization for now.")
-
         moe_quant_params = {
             "num_experts": self.local_num_experts,
             "hidden_size": hidden_size,
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -633,9 +633,10 @@ def apply(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for "
-                "`CompressedTensorsW8A8Fp8MoEMethod` yet.")
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -649,6 +650,11 @@
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
         )
 
         if self.rocm_aiter_moe_enabled:
@@ -913,9 +919,10 @@ def apply(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for "
-                "`CompressedTensorsW8A8Fp8MoECutlassMethod` yet.")
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -927,7 +934,12 @@
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count)
 
         a1_scale = layer.w13_input_scale
         a2_scale = layer.w2_input_scale