From ca9e6a9c29b680b152869ca770e20972e5eacf53 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Sun, 13 Jul 2025 19:48:19 -0400 Subject: [PATCH 1/6] misc: add test script for easier testing Signed-off-by: Brayden Zhong --- test_llama4_eplb.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 test_llama4_eplb.py diff --git a/test_llama4_eplb.py b/test_llama4_eplb.py new file mode 100644 index 00000000000..7e48a0a7b60 --- /dev/null +++ b/test_llama4_eplb.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + +def main(): + # Create an LLM with EPLB parameters. + llm = LLM( + model="/fp8-llama/llama4scout-fp8/", + tensor_parallel_size=8, + enable_expert_parallel=True, + enable_eplb=True, + num_redundant_experts=16, + eplb_window_size=1000, + eplb_step_interval=3000, + trust_remote_code=True, + enforce_eager=True, + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() From 3efeaee596931ad7f8675156a06ad64a695b88c7 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Sun, 13 Jul 2025 21:31:16 -0400 Subject: [PATCH 2/6] feat: add the 1 shared expert in consideration Signed-off-by: Brayden Zhong --- vllm/model_executor/models/llama4.py | 190 +++++++++++++++++++++++---- 1 file changed, 163 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index fab1c163ac2..1712b8eb671 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -18,7 +18,7 @@ # limitations under the License. 
"""Inference-only LLaMA model compatible with HuggingFace weights.""" from collections.abc import Iterable -from typing import Any, Optional +from typing import Any, Optional, Union, cast import torch from torch import nn @@ -26,8 +26,8 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -38,6 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) +from .interfaces import MixtureOfExperts from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk, is_pp_missing_parameter) @@ -57,14 +58,22 @@ def custom_routing_function( router_scores = torch.sigmoid(router_scores.float()) return (router_scores, router_indices.to(torch.int32)) - def __init__(self, - config: Llama4TextConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: Llama4TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.top_k = config.num_experts_per_tok + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_local_experts + intermediate_size_moe = config.intermediate_size self.router = ReplicatedLinear(config.hidden_size, config.num_local_experts, @@ -72,6 +81,23 @@ def __init__(self, quant_config=None, prefix=f"{prefix}.router") + # Load balancing + + vllm_config = get_current_vllm_config() + parallel_config = vllm_config.parallel_config + self.enable_eplb = enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + self.experts = FusedMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, @@ -82,7 +108,10 @@ def __init__(self, reduce_results=False, renormalize=False, quant_config=quant_config, - prefix=f"{prefix}.experts") + prefix=f"{prefix}.experts", + enable_eplb=enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) self.shared_expert = LlamaMLP( hidden_size=config.hidden_size, @@ -229,7 +258,8 @@ def forward( k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype) # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) - # to NoPE layers, where the inference-time temperature tuning function + # to NoPE layers, where the inference-time temperature tuning + # function # is customized to not affect short context # while working at very long context # https://arxiv.org/abs/2501.19399 @@ -252,6 +282,7 @@ def __init__( cache_config: 
Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + enable_eplb: bool = False, ) -> None: super().__init__() @@ -278,10 +309,11 @@ def __init__( is_moe_layer = config.interleave_moe_layer_step > 0 and ( self.layer_idx + 1) % config.interleave_moe_layer_step == 0 if is_moe_layer: - self.feed_forward = Llama4MoE( + self.feed_forward: Union[Llama4MoE, LlamaMLP] = Llama4MoE( config=config, quant_config=quant_config, prefix=f"{prefix}.feed_forward", + enable_eplb=enable_eplb, ) else: self.feed_forward = LlamaMLP( @@ -329,9 +361,26 @@ def __init__(self, prefix: str = "", layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer): self.num_experts = vllm_config.model_config.hf_config.num_local_experts + self.num_redundant_experts = ( + vllm_config.parallel_config.num_redundant_experts) + self.enable_eplb = vllm_config.parallel_config.enable_eplb + + # We need to create layers with enable_eplb parameter + # Store the original layer_type and override it with a lambda + original_layer_type = layer_type + + def create_layer(prefix): + config = cast(Llama4TextConfig, vllm_config.model_config.hf_config) + return original_layer_type(config=config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + prefix=prefix, + enable_eplb=self.enable_eplb) + + # Call parent init with our custom layer factory super().__init__(vllm_config=vllm_config, prefix=prefix, - layer_type=layer_type) + layer_type=cast(type[nn.Module], create_layer)) def load_moe_expert_weights( self, @@ -370,8 +419,11 @@ def load_moe_expert_weights( new_loaded_weight = new_loaded_weight.transpose(-1, -2) layer_idx = extract_layer_index(name) # EP mapping - expert_map = self.layers[ - layer_idx].feed_forward.experts.expert_map + feed_forward = self.layers[layer_idx].feed_forward + if hasattr(feed_forward, 'experts'): + expert_map = feed_forward.experts.expert_map + else: + expert_map = None if expert_map is not None: local_expert_indices = (expert_map != -1) \ .nonzero() \ @@ -390,6 +442,7 @@ def load_moe_expert_weights( loaded_params.add(full_param_name) expert_param_loaded = True + is_expert = True return expert_param_loaded def load_weights(self, weights: Iterable[tuple[str, @@ -407,7 +460,9 @@ def load_weights(self, weights: Iterable[tuple[str, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.num_experts) + num_experts=self.num_experts, + num_redundant_experts=self.num_redundant_experts, + ) expert_params_mapping_fused = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_up_proj", ckpt_down_proj_name="down_proj", @@ -451,18 +506,54 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader(param, loaded_weight) else: weight_loader(param, loaded_weight, shard_id) + is_expert = False loaded_params.add(name) break else: - moe_loaded = self.load_moe_expert_weights( - name, - loaded_weight, - params_dict, - loaded_params, - expert_params_mapping, - fused=fused_experts_params) - - if not moe_loaded: + # First try to handle as expert weight + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if 
is_pp_missing_parameter(name_mapped, self): + continue + + # Skip loading extra parameters for GPTQ/modelopt models. + if ((name_mapped.endswith(".bias") + or name_mapped.endswith("_bias")) + and name_mapped not in params_dict): + continue + + param = params_dict[name_mapped] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id) + loaded_params.add(name_mapped) + is_expert = True + break + else: + # If we've identified this as an expert weight but couldn't + # load it + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Not an expert weight, continue with regular loading if is_pp_missing_parameter(name, self): continue @@ -500,6 +591,7 @@ def load_weights(self, weights: Iterable[tuple[str, # Regular weight loader (handles both # param.weight_loader and default_weight_loader) weight_loader(param, loaded_weight) + is_expert = True loaded_params.add(name) continue @@ -507,11 +599,12 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + is_expert = False loaded_params.add(name) return loaded_params -class Llama4ForCausalLM(LlamaForCausalLM): +class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], @@ -525,14 +618,57 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # enable temperature tuning by default when max_model_len > 32K default_attn_temperature_tuning = \ vllm_config.model_config.max_model_len > 32768 - vllm_config.model_config.hf_config.attn_temperature_tuning \ - = gen_config.get( - "attn_temperature_tuning", default_attn_temperature_tuning) + vllm_config.model_config.hf_config.attn_temperature_tuning = \ + gen_config.get("attn_temperature_tuning", + default_attn_temperature_tuning) super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer) + self.expert_weights = [] + + # Set MoE hyperparameters + self.moe_layers: list[FusedMoE] = [] + for layer in self.model.layers: + assert isinstance(layer, Llama4DecoderLayer) + if isinstance(layer.feed_forward, Llama4MoE): + self.moe_layers.append(layer.feed_forward.experts) + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + + example_moe = None + for layer_idx in range(self.config.num_hidden_layers): + layer = self.model.layers[layer_idx] + if isinstance(layer.feed_forward, Llama4MoE): + example_moe = layer.feed_forward + break + assert example_moe is not None + + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_redundant_experts = example_moe.n_redundant_experts + self.num_shared_experts = 1 + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. 
+ self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + def _init_model(self, vllm_config: VllmConfig, prefix: str = "", From eda01afc26f59641cd245020a583607d4c1e7a88 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Sun, 13 Jul 2025 21:46:22 -0400 Subject: [PATCH 3/6] use the correct configuration Signed-off-by: Brayden Zhong --- test_llama4_eplb.py | 1 + vllm/model_executor/models/llama4.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test_llama4_eplb.py b/test_llama4_eplb.py index 7e48a0a7b60..2bf3a6ddecb 100644 --- a/test_llama4_eplb.py +++ b/test_llama4_eplb.py @@ -19,6 +19,7 @@ def main(): llm = LLM( model="/fp8-llama/llama4scout-fp8/", tensor_parallel_size=8, + max_model_len=2048, enable_expert_parallel=True, enable_eplb=True, num_redundant_experts=16, diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 1712b8eb671..c9d9dbdc788 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -369,11 +369,13 @@ def __init__(self, # Store the original layer_type and override it with a lambda original_layer_type = layer_type - def create_layer(prefix): + def create_layer(config, cache_config, quant_config, prefix): + # We use the config from vllm_config instead of the passed one + # to ensure we get the Llama4TextConfig type config = cast(Llama4TextConfig, vllm_config.model_config.hf_config) return original_layer_type(config=config, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, + cache_config=cache_config, + quant_config=quant_config, prefix=prefix, enable_eplb=self.enable_eplb) From c0b437f079e182cd3a94f9fc2e3659adfbfc21e8 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Sun, 13 Jul 2025 22:09:37 -0400 Subject: [PATCH 4/6] Add support for unquantizedFusedMoe Signed-off-by: Brayden Zhong --- vllm/model_executor/layers/fused_moe/layer.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index da772c11155..1d7863117a8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -362,8 +362,10 @@ def apply( logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: if enable_eplb: - raise NotImplementedError( - "EPLB not supported for `UnquantizedFusedMoEMethod` yet.") + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) return self.forward( x=x, @@ -380,7 +382,12 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + enable_eplb=enable_eplb, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) def forward_cuda( self, @@ -399,6 +406,10 @@ def forward_cuda( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: 
Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( @@ -412,7 +423,11 @@ def forward_cuda( custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count) if self.rocm_aiter_moe_enabled: return self.rocm_aiter_fused_experts( @@ -753,7 +768,8 @@ def __init__( if self.enable_eplb: from vllm.model_executor.layers.quantization.fp8 import ( Fp8MoEMethod) - if not isinstance(quant_method, Fp8MoEMethod): + if not isinstance(quant_method, Fp8MoEMethod) and not isinstance( + quant_method, UnquantizedFusedMoEMethod): # TODO: Add support for additional quantization methods. # The implementation for other quantization methods does not # contain essential differences, but the current quant API From 14cfca2c052433964cbc7ff6285b6a61dd271bc6 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Tue, 15 Jul 2025 09:02:07 -0400 Subject: [PATCH 5/6] feat: support CompressedTensorsW8A8Fp8MoECutlassMethod and CompressedTensorsW8A8Fp8MoECutlassMethod Signed-off-by: Brayden Zhong --- .../compressed_tensors_moe.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index baf4fec3cc6..6215f4f947f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -633,9 +633,10 @@ def apply( logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: if enable_eplb: - raise NotImplementedError( - "EPLB not supported for " - "`CompressedTensorsW8A8Fp8MoEMethod` yet.") + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -649,6 +650,11 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, ) if self.rocm_aiter_moe_enabled: @@ -913,9 +919,10 @@ def apply( logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: if enable_eplb: - raise NotImplementedError( - "EPLB not supported for " - "`CompressedTensorsW8A8Fp8MoECutlassMethod` yet.") + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -927,7 +934,12 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + 
logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count) a1_scale = layer.w13_input_scale a2_scale = layer.w2_input_scale From 2a4e3ef1ef34c854986dcd00d546250d124bb315 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Tue, 15 Jul 2025 09:29:59 -0400 Subject: [PATCH 6/6] feat: remove blocker Signed-off-by: Brayden Zhong Co-authored-by: ztang2370 --- vllm/model_executor/layers/fused_moe/layer.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 1d7863117a8..730c0690866 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -765,21 +765,6 @@ def __init__( assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method - if self.enable_eplb: - from vllm.model_executor.layers.quantization.fp8 import ( - Fp8MoEMethod) - if not isinstance(quant_method, Fp8MoEMethod) and not isinstance( - quant_method, UnquantizedFusedMoEMethod): - # TODO: Add support for additional quantization methods. - # The implementation for other quantization methods does not - # contain essential differences, but the current quant API - # design causes duplicated work when extending to new - # quantization methods, so I'm leaving it for now. - # If you plan to add support for more quantization methods, - # please refer to the implementation in `Fp8MoEMethod`. - raise NotImplementedError("EPLB is only supported for FP8 " - "quantization for now.") - moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size,
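
A note for reviewers, separate from the patches themselves: the physical-expert bookkeeping that patch 2 adds to Llama4MoE.__init__ (logical plus redundant experts split evenly across EP ranks) can be sanity-checked in isolation. The sketch below is only an illustration, not code from the series; the helper name physical_expert_range is hypothetical, and the concrete numbers (EP size 8 and 16 redundant experts taken from test_llama4_eplb.py, 16 routed experts assumed for Llama 4 Scout) are assumptions.

def physical_expert_range(num_logical_experts: int,
                          num_redundant_experts: int,
                          ep_size: int,
                          ep_rank: int) -> tuple[int, int]:
    """Return the [start, end) physical-expert slice owned by one EP rank."""
    num_physical_experts = num_logical_experts + num_redundant_experts
    # Patch 2 relies on the physical experts dividing evenly across EP ranks.
    assert num_physical_experts % ep_size == 0
    num_local_physical_experts = num_physical_experts // ep_size
    start = ep_rank * num_local_physical_experts
    return start, start + num_local_physical_experts


if __name__ == "__main__":
    # 16 routed + 16 redundant experts over 8 EP ranks -> 4 physical slots
    # per rank that EPLB can remap between logical experts.
    for rank in range(8):
        start, end = physical_expert_range(num_logical_experts=16,
                                           num_redundant_experts=16,
                                           ep_size=8,
                                           ep_rank=rank)
        print(f"rank {rank}: physical experts [{start}, {end})")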