[WIP][EPLB] Enable Llama4 EPLB #20901

Draft · wants to merge 6 commits into main from llama4-eplb-clean

Conversation

@b8zhong (Contributor) commented on Jul 14, 2025

Purpose

As part of #20468, add EPLB support for Llama4.

Follow-up to #20792; I did something weird in the Git history there.

Test Plan

WIP.

Test Result

vllm serve /models/Llama4-Scout-17B --tensor-parallel-size 8 --enable-eplb
vllm serve /fp8-llama/llama4scout-fp8/ \
  --tensor-parallel-size 8 \
  --max-model-len 2048 \
  --enable-expert-parallel \
  --enable-eplb \
  --num-redundant-experts 16 \
  --eplb-window-size 1000 \
  --eplb-step-interval 3000 \
  --trust-remote-code \
  --enforce-eager
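
For reference, a minimal offline-inference sketch equivalent to the serve command above. It assumes the CLI flags map one-to-one onto EngineArgs keyword arguments accepted by the LLM constructor; the model path is the placeholder from the command, and the EPLB-specific keyword names are not verified against this branch.

# Hypothetical offline equivalent of the serve command above; the EPLB keyword
# names assume a one-to-one mapping from the CLI flags and may differ here.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/fp8-llama/llama4scout-fp8/",  # placeholder path from the command above
    tensor_parallel_size=8,
    max_model_len=2048,
    enable_expert_parallel=True,
    enable_eplb=True,
    num_redundant_experts=16,
    eplb_window_size=1000,
    eplb_step_interval=3000,
    trust_remote_code=True,
    enforce_eager=True,
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=16)))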

WIP

(Optional) Documentation Update

WIP

@mergify bot added the llama (Related to Llama models) label on Jul 14, 2025
@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @b8zhong, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive support for Expert Parallel Load Balancing (EPLB) within the vLLM framework specifically for the Llama4 Mixture-of-Experts (MoE) model. It integrates EPLB capabilities into the core MoE layers, updates the model's architecture to manage expert distribution and load, and refines the weight loading process to accommodate redundant experts. This work is a significant step towards optimizing the performance and efficiency of Llama4 inference with expert parallelism.

Highlights

  • Expert Parallel Load Balancing (EPLB) for Llama4: Implemented core support for EPLB within the Llama4 model, enabling dynamic load balancing across experts in a Mixture-of-Experts (MoE) architecture.
  • Expanded FusedMoE Compatibility: The FusedMoE layer now supports EPLB for UnquantizedFusedMoEMethod, removing a previous NotImplementedError and allowing unquantized MoE layers to leverage EPLB.
  • Dynamic EPLB State Management: Introduced mechanisms within the Llama4 model to initialize and update EPLB-specific parameters (e.g., expert_load_view, logical_to_physical_map) at runtime, crucial for efficient expert routing (a toy illustration of these tensors follows this list).
  • Refined MoE Weight Loading: The weight loading process for Llama4 MoE layers has been refactored to correctly handle and distribute weights for both logical and redundant experts, ensuring proper initialization in an EPLB setup.
  • New EPLB Test Case: A dedicated test file (test_llama4_eplb.py) was added to validate the functionality of Llama4 with EPLB enabled, including specific configuration parameters.
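
To make the EPLB state tensors named above concrete, here is a small, self-contained toy example of how a logical-to-physical expert map, replica counts, and a per-slot load view fit together. Shapes, values, and the routing helper are illustrative only, not vLLM's actual implementation.

import torch

# Toy EPLB routing state for one MoE layer: 4 logical experts mapped onto
# 6 physical slots (2 redundant replicas). All values are made up.
logical_to_physical_map = torch.tensor([
    [0, 4],    # logical expert 0 has replicas on physical slots 0 and 4
    [1, 5],    # logical expert 1 has replicas on physical slots 1 and 5
    [2, -1],   # logical experts 2 and 3 have a single replica each
    [3, -1],
])
logical_replica_count = torch.tensor([2, 2, 1, 1])
expert_load_view = torch.zeros(6)  # running token count per physical slot

def route(logical_expert: int, token_idx: int) -> int:
    """Pick one physical replica for a routed token and record its load."""
    replica = token_idx % int(logical_replica_count[logical_expert])
    physical = int(logical_to_physical_map[logical_expert, replica])
    expert_load_view[physical] += 1
    return physical

print([route(0, t) for t in range(4)])  # tokens alternate between slots 0 and 4
print(expert_load_view)                 # accumulated load per physical slot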

@b8zhong force-pushed the llama4-eplb-clean branch from 979b2bc to e657c3e on July 14, 2025 03:22
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request enables Expert Parallelism Load Balancing (EPLB) for the Llama4 model. The changes are mostly in vllm/model_executor/models/llama4.py to integrate the EPLB logic, and in vllm/model_executor/layers/fused_moe/layer.py to plumb through the necessary parameters. I've identified a potential memory leak in the state update logic and some areas for improving code clarity and robustness. Overall, the changes look good and are on the right track.

Comment on lines +658 to +672
def set_eplb_state(
    self,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> None:
    for layer_idx, layer in enumerate(self.moe_layers):
        # Register the expert weights.
        self.expert_weights.append(layer.get_expert_weights())
        layer.set_eplb_state(
            moe_layer_idx=layer_idx,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

Severity: high

This method appends to self.expert_weights every time it's called. If set_eplb_state is invoked multiple times (e.g., during dynamic rebalancing), the self.expert_weights list will grow indefinitely, leading to a memory leak.

To prevent this, the list should be cleared at the beginning of the method.

Suggested change:

def set_eplb_state(
    self,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> None:
    self.expert_weights.clear()
    for layer_idx, layer in enumerate(self.moe_layers):
        # Register the expert weights.
        self.expert_weights.append(layer.get_expert_weights())
        layer.set_eplb_state(
            moe_layer_idx=layer_idx,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

Comment on lines +365 to +367
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None

Severity: medium

Using assert for validating conditionally required arguments can be risky, as assertions can be disabled with the -O flag in Python. This could lead to silent failures in production if the necessary arguments aren't passed when enable_eplb is true. It's more robust to raise a ValueError with a clear message.

Suggested change:

if any(p is None for p in (expert_load_view,
                           logical_to_physical_map,
                           logical_replica_count)):
    raise ValueError(
        "EPLB is enabled, but required arguments are missing.")

Comment on lines +368 to +369
# We need to create layers with enable_eplb parameter
# Store the original layer_type and override it with a lambda

Severity: medium

This pattern of using a nested function factory to inject parameters is clever but can be hard to follow for future maintainers. Adding a more descriptive comment explaining why this approach is necessary (i.e., to pass enable_eplb to each layer without modifying the generic make_layers utility) would greatly improve code clarity.

Suggested change:

# We need to create layers with the enable_eplb parameter.
# To do this without modifying the generic `make_layers` utility,
# we store the original layer_type and override it with a lambda
# that injects the `enable_eplb` parameter into the constructor.
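
For readers unfamiliar with the pattern, a standalone sketch of the factory/lambda approach described above is shown here. The names (DummyLayer, build_layers) are illustrative stand-ins, not vLLM's make_layers utility or the actual Llama4 decoder layer.

# Hypothetical sketch of the pattern: a closure injects `enable_eplb` into each
# layer's constructor so the generic layer-building helper stays unchanged.
from typing import Callable

class DummyLayer:
    def __init__(self, prefix: str, enable_eplb: bool = False):
        self.prefix = prefix
        self.enable_eplb = enable_eplb

def build_layers(n: int, layer_type: Callable[..., DummyLayer]) -> list[DummyLayer]:
    # Stand-in for a generic make_layers-style utility that only knows `prefix`.
    return [layer_type(prefix=f"layers.{i}") for i in range(n)]

enable_eplb = True
layer_type = lambda prefix: DummyLayer(prefix=prefix, enable_eplb=enable_eplb)
layers = build_layers(4, layer_type)
assert all(layer.enable_eplb for layer in layers)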

Comment on lines +424 to +428
feed_forward = self.layers[layer_idx].feed_forward
if hasattr(feed_forward, 'experts'):
    expert_map = feed_forward.experts.expert_map
else:
    expert_map = None

Severity: medium

This method, load_moe_expert_weights, appears to be dead code after the refactoring in the load_weights method. The new implementation in load_weights now handles expert weight loading directly. If this method is no longer used, it should be removed to improve maintainability and avoid confusion.

is_expert_weight = False
for mapping in expert_params_mapping:
    param_name, weight_name, expert_id, shard_id = mapping
    if weight_name not in name:

Severity: medium

The in operator for substring matching can be brittle. For example, it would incorrectly match a weight named w1 in w10. A more robust approach would be to ensure you're matching a whole component of the name, for instance by checking for delimiters like dots around the weight_name.

Suggested change:

if f".{weight_name}." not in f".{name}.":


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

b8zhong added 5 commits July 15, 2025 10:35
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
…TensorsW8A8Fp8MoECutlassMethod

Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
@b8zhong force-pushed the llama4-eplb-clean branch from b6d9a87 to 76929b2 on July 15, 2025 14:35
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Co-authored-by: ztang2370 <ztang2370@gmail.com>
@b8zhong force-pushed the llama4-eplb-clean branch from 76929b2 to 2a4e3ef on July 15, 2025 14:39