Commit e19cced

feat: add the 1 shared expert in consideration
1 parent 9bb7599 commit e19cced
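
In this file's terms, the "1 shared expert" is the always-on LlamaMLP (self.shared_expert) that every Llama 4 MoE layer runs alongside its routed experts; the commit reports it to the new MixtureOfExperts interface as num_shared_experts = 1 alongside the routed-expert counts. A schematic, hypothetical sketch of that layer structure in plain PyTorch (not vLLM's FusedMoE, and simplified to top-1 routing):

import torch
from torch import nn

class ToyLlama4MoE(nn.Module):
    """Schematic only: top-1 routed experts plus one always-on shared expert."""

    def __init__(self, hidden: int, ffn: int, n_experts: int):
        super().__init__()
        self.router = nn.Linear(hidden, n_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden, ffn), nn.SiLU(),
                          nn.Linear(ffn, hidden)) for _ in range(n_experts))
        self.shared_expert = nn.Sequential(nn.Linear(hidden, ffn), nn.SiLU(),
                                           nn.Linear(ffn, hidden))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Sigmoid router scores, top-1 routing (simplified).
        scores = torch.sigmoid(self.router(x))
        top_score, top_id = scores.max(dim=-1)
        routed = torch.stack([
            self.experts[int(top_id[t])](x[t]) * top_score[t]
            for t in range(x.shape[0])
        ])
        # The single shared expert runs for every token and its output is
        # added to the routed output; this is the "1 shared expert".
        return routed + self.shared_expert(x)

moe = ToyLlama4MoE(hidden=8, ffn=16, n_experts=4)
print(moe(torch.randn(3, 8)).shape)  # torch.Size([3, 8])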

File tree

1 file changed: +163 −27 lines changed

vllm/model_executor/models/llama4.py

Lines changed: 163 additions & 27 deletions
@@ -18,16 +18,16 @@
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Any, Optional, Union, cast

 import torch
 from torch import nn
 from transformers import Llama4TextConfig

 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -38,6 +38,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)

+from .interfaces import MixtureOfExperts
 from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
                     is_pp_missing_parameter)
@@ -57,21 +58,46 @@ def custom_routing_function(
         router_scores = torch.sigmoid(router_scores.float())
         return (router_scores, router_indices.to(torch.int32))

-    def __init__(self,
-                 config: Llama4TextConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        enable_eplb: bool = False,
+    ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok

+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts = config.num_local_experts
+
         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(config.hidden_size,
                                        config.num_local_experts,
                                        bias=False,
                                        quant_config=None,
                                        prefix=f"{prefix}.router")

+        # Load balancing
+
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_logical_experts = self.n_routed_experts
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
         self.experts = FusedMoE(
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
@@ -82,7 +108,10 @@ def __init__(self,
             reduce_results=False,
             renormalize=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.experts")
+            prefix=f"{prefix}.experts",
+            enable_eplb=enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
+        )

         self.shared_expert = LlamaMLP(
             hidden_size=config.hidden_size,
@@ -229,7 +258,8 @@ def forward(
             k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)

         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
-        # to NoPE layers, where the inference-time temperature tuning function
+        # to NoPE layers, where the inference-time temperature tuning
+        # function
         # is customized to not affect short context
         # while working at very long context
         # https://arxiv.org/abs/2501.19399
@@ -252,6 +282,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()

@@ -278,10 +309,11 @@ def __init__(
         is_moe_layer = config.interleave_moe_layer_step > 0 and (
             self.layer_idx + 1) % config.interleave_moe_layer_step == 0
         if is_moe_layer:
-            self.feed_forward = Llama4MoE(
+            self.feed_forward: Union[Llama4MoE, LlamaMLP] = Llama4MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.feed_forward",
+                enable_eplb=enable_eplb,
             )
         else:
             self.feed_forward = LlamaMLP(
@@ -329,9 +361,26 @@ def __init__(self,
                  prefix: str = "",
                  layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
         self.num_experts = vllm_config.model_config.hf_config.num_local_experts
+        self.num_redundant_experts = (
+            vllm_config.parallel_config.num_redundant_experts)
+        self.enable_eplb = vllm_config.parallel_config.enable_eplb
+
+        # We need to create layers with enable_eplb parameter
+        # Store the original layer_type and override it with a lambda
+        original_layer_type = layer_type
+
+        def create_layer(prefix):
+            config = cast(Llama4TextConfig, vllm_config.model_config.hf_config)
+            return original_layer_type(config=config,
+                                       cache_config=vllm_config.cache_config,
+                                       quant_config=vllm_config.quant_config,
+                                       prefix=prefix,
+                                       enable_eplb=self.enable_eplb)
+
+        # Call parent init with our custom layer factory
         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
-                         layer_type=layer_type)
+                         layer_type=cast(type[nn.Module], create_layer))

     def load_moe_expert_weights(
         self,
@@ -370,8 +419,11 @@ def load_moe_expert_weights(
                 new_loaded_weight = new_loaded_weight.transpose(-1, -2)
             layer_idx = extract_layer_index(name)
             # EP mapping
-            expert_map = self.layers[
-                layer_idx].feed_forward.experts.expert_map
+            feed_forward = self.layers[layer_idx].feed_forward
+            if hasattr(feed_forward, 'experts'):
+                expert_map = feed_forward.experts.expert_map
+            else:
+                expert_map = None
             if expert_map is not None:
                 local_expert_indices = (expert_map != -1) \
                     .nonzero() \
@@ -390,6 +442,7 @@ def load_moe_expert_weights(

                 loaded_params.add(full_param_name)
                 expert_param_loaded = True
+                is_expert = True
         return expert_param_loaded

     def load_weights(self, weights: Iterable[tuple[str,
@@ -407,7 +460,9 @@ def load_weights(self, weights: Iterable[tuple[str,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
-           num_experts=self.num_experts)
+           num_experts=self.num_experts,
+           num_redundant_experts=self.num_redundant_experts,
+        )
         expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
@@ -451,18 +506,54 @@ def load_weights(self, weights: Iterable[tuple[str,
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
+               is_expert = False
                loaded_params.add(name)
                break
            else:
-               moe_loaded = self.load_moe_expert_weights(
-                   name,
-                   loaded_weight,
-                   params_dict,
-                   loaded_params,
-                   expert_params_mapping,
-                   fused=fused_experts_params)
-
-               if not moe_loaded:
+               # First try to handle as expert weight
+               is_expert_weight = False
+               for mapping in expert_params_mapping:
+                   param_name, weight_name, expert_id, shard_id = mapping
+                   if weight_name not in name:
+                       continue
+
+                   # Anyway, this is an expert weight and should not be
+                   # attempted to load as other weights later
+                   is_expert_weight = True
+
+                   # Do not modify `name` since the loop may continue here
+                   # Instead, create a new variable
+                   name_mapped = name.replace(weight_name, param_name)
+
+                   if is_pp_missing_parameter(name_mapped, self):
+                       continue
+
+                   # Skip loading extra parameters for GPTQ/modelopt models.
+                   if ((name_mapped.endswith(".bias")
+                        or name_mapped.endswith("_bias"))
+                           and name_mapped not in params_dict):
+                       continue
+
+                   param = params_dict[name_mapped]
+                   weight_loader = param.weight_loader
+                   weight_loader(param,
+                                 loaded_weight,
+                                 name_mapped,
+                                 shard_id=shard_id,
+                                 expert_id=expert_id)
+                   loaded_params.add(name_mapped)
+                   is_expert = True
+                   break
+               else:
+                   # If we've identified this as an expert weight but couldn't
+                   # load it
+                   if is_expert_weight:
+                       # We've checked that this is an expert weight
+                       # However it's not mapped locally to this rank
+                       # So we simply skip it
+                       continue
+
+                   # Not an expert weight, continue with regular loading
                    if is_pp_missing_parameter(name, self):
                        continue

@@ -500,18 +591,20 @@ def load_weights(self, weights: Iterable[tuple[str,
                        # Regular weight loader (handles both
                        # param.weight_loader and default_weight_loader)
                        weight_loader(param, loaded_weight)
+                       is_expert = True
                        loaded_params.add(name)
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
+                   is_expert = False
                    loaded_params.add(name)
        return loaded_params


-class Llama4ForCausalLM(LlamaForCausalLM):
+class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):

     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -525,14 +618,57 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # enable temperature tuning by default when max_model_len > 32K
         default_attn_temperature_tuning = \
             vllm_config.model_config.max_model_len > 32768
-        vllm_config.model_config.hf_config.attn_temperature_tuning \
-            = gen_config.get(
-                "attn_temperature_tuning", default_attn_temperature_tuning)
+        vllm_config.model_config.hf_config.attn_temperature_tuning = \
+            gen_config.get("attn_temperature_tuning",
+                           default_attn_temperature_tuning)

         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
                          layer_type=Llama4DecoderLayer)

+        self.expert_weights = []
+
+        # Set MoE hyperparameters
+        self.moe_layers: list[FusedMoE] = []
+        for layer in self.model.layers:
+            assert isinstance(layer, Llama4DecoderLayer)
+            if isinstance(layer.feed_forward, Llama4MoE):
+                self.moe_layers.append(layer.feed_forward.experts)
+
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_expert_groups = 1
+
+        example_moe = None
+        for layer_idx in range(self.config.num_hidden_layers):
+            layer = self.model.layers[layer_idx]
+            if isinstance(layer.feed_forward, Llama4MoE):
+                example_moe = layer.feed_forward
+                break
+        assert example_moe is not None
+
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+        self.num_shared_experts = 1
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
     def _init_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "",
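
Expert parallelism splits the physical experts, meaning the logical routed experts plus any redundant copies added for EPLB, evenly across EP ranks; the Llama4MoE.__init__ additions above compute the contiguous slice each rank owns. A minimal standalone sketch of that arithmetic, with hypothetical example numbers rather than Llama 4 defaults:

def partition_experts(n_logical: int, n_redundant: int, ep_size: int,
                      ep_rank: int) -> tuple[int, int]:
    """Return the [start, end) range of physical experts owned by ep_rank."""
    n_physical = n_logical + n_redundant
    # Mirrors the diff: assumes n_physical divides evenly by ep_size.
    n_local_physical = n_physical // ep_size
    start = ep_rank * n_local_physical
    end = start + n_local_physical
    return start, end

# Example: 16 routed experts plus 2 redundant copies over 2 EP ranks
# gives 9 physical experts per rank: (0, 9) and (9, 18).
for rank in range(2):
    print(rank, partition_experts(16, 2, ep_size=2, ep_rank=rank))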

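The rewritten load_weights leans on Python's for/else: if the inner loop over expert_params_mapping breaks, the weight was loaded locally; if it finishes without breaking, the else branch either skips an expert weight owned by another expert-parallel rank or falls through to regular loading. A distilled, self-contained sketch of that control flow; the function name, mapping entries, and parameter names here are illustrative, not vLLM's actual structures:

def try_load(name: str,
             expert_params_mapping: list[tuple[str, str]],
             local_params: set[str]) -> str:
    is_expert_weight = False
    for param_name, weight_name in expert_params_mapping:
        if weight_name not in name:
            continue
        # It is an expert weight even if this rank cannot load it.
        is_expert_weight = True
        name_mapped = name.replace(weight_name, param_name)
        if name_mapped not in local_params:
            continue  # owned by another expert-parallel rank
        result = f"loaded {name_mapped}"
        break
    else:
        # Runs only when the loop above finished without `break`.
        if is_expert_weight:
            return "skipped: expert weight not local to this rank"
        result = "not an expert weight; use regular loading"
    return result

mapping = [("experts.w13_weight", "experts.gate_up_proj")]
print(try_load("layers.1.experts.gate_up_proj", mapping,
               {"layers.1.experts.w13_weight"}))
print(try_load("layers.1.experts.gate_up_proj", mapping, set()))
print(try_load("layers.1.self_attn.qkv_proj", mapping, set()))

The three calls exercise the three outcomes: loaded locally, skipped as a non-local expert weight, and regular (non-expert) loading.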