 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Any, Optional, Union, cast

 import torch
 from torch import nn
 from transformers import Llama4TextConfig

 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)

+from .interfaces import MixtureOfExperts
 from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
                     is_pp_missing_parameter)
@@ -57,21 +58,46 @@ def custom_routing_function(
         router_scores = torch.sigmoid(router_scores.float())
         return (router_scores, router_indices.to(torch.int32))

-    def __init__(self,
-                 config: Llama4TextConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        enable_eplb: bool = False,
+    ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok

+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts = config.num_local_experts
+
         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(config.hidden_size,
                                        config.num_local_experts,
                                        bias=False,
                                        quant_config=None,
                                        prefix=f"{prefix}.router")

+        # Load balancing
+
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_logical_experts = self.n_routed_experts
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
         self.experts = FusedMoE(
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
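
The EPLB bookkeeping added above splits the physical experts (logical experts plus redundant replicas) evenly across expert-parallel ranks. A minimal standalone sketch of that arithmetic, with hypothetical sizes (16 logical experts, 2 redundant replicas, EP size 2) and assuming `n_physical_experts` divides evenly by `ep_size`:

```python
# Standalone sketch of the EPLB expert partitioning above (hypothetical sizes).
n_logical_experts = 16     # config.num_local_experts
n_redundant_experts = 2    # parallel_config.num_redundant_experts
ep_size = 2                # expert-parallel world size

n_physical_experts = n_logical_experts + n_redundant_experts   # 18
n_local_physical_experts = n_physical_experts // ep_size       # 9 per rank

for ep_rank in range(ep_size):
    start = ep_rank * n_local_physical_experts
    end = start + n_local_physical_experts
    print(f"rank {ep_rank}: physical experts [{start}, {end})")
# rank 0: physical experts [0, 9)
# rank 1: physical experts [9, 18)
```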
@@ -82,7 +108,10 @@ def __init__(self,
             reduce_results=False,
             renormalize=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.experts")
+            prefix=f"{prefix}.experts",
+            enable_eplb=enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
+        )

         self.shared_expert = LlamaMLP(
             hidden_size=config.hidden_size,
@@ -229,7 +258,8 @@ def forward(
             k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)

         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
-        # to NoPE layers, where the inference-time temperature tuning function
+        # to NoPE layers, where the inference-time temperature tuning
+        # function
         # is customized to not affect short context
         # while working at very long context
         # https://arxiv.org/abs/2501.19399
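
For reference, the temperature tuning the comment refers to scales the query by a position-dependent factor so that short contexts are left essentially untouched while very long contexts get a higher attention temperature. A hedged sketch of the commonly used form of this scaling; the `floor_scale` and `attn_scale` defaults are assumptions drawn from typical Llama 4 configs, not values read from this diff:

```python
import torch


def attn_temperature_scale(positions: torch.Tensor,
                           floor_scale: float = 8192.0,
                           attn_scale: float = 0.1) -> torch.Tensor:
    """Position-dependent query scaling for NoPE layers (sketch).

    For positions well below floor_scale the factor stays ~1.0, so short
    contexts are unaffected; it grows logarithmically at long context.
    """
    floor = torch.floor((positions.float() + 1.0) / floor_scale)
    return torch.log(floor + 1.0) * attn_scale + 1.0


# Example usage inside attention (illustrative):
# q = (q * attn_temperature_scale(positions).unsqueeze(-1)).to(q.dtype)
```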
@@ -252,6 +282,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()

@@ -278,10 +309,11 @@ def __init__(
         is_moe_layer = config.interleave_moe_layer_step > 0 and (
             self.layer_idx + 1) % config.interleave_moe_layer_step == 0
         if is_moe_layer:
-            self.feed_forward = Llama4MoE(
+            self.feed_forward: Union[Llama4MoE, LlamaMLP] = Llama4MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.feed_forward",
+                enable_eplb=enable_eplb,
             )
         else:
             self.feed_forward = LlamaMLP(
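
The `is_moe_layer` check above interleaves dense and MoE feed-forward blocks: with `interleave_moe_layer_step = N`, every N-th layer (1-indexed) gets a `Llama4MoE` block and the remaining layers get a dense `LlamaMLP`. A small sketch with a hypothetical step of 2 and 8 layers:

```python
# Sketch: which decoder layers become MoE for a hypothetical config.
interleave_moe_layer_step = 2   # hypothetical; normally read from the HF config
num_hidden_layers = 8

moe_layer_indices = [
    layer_idx for layer_idx in range(num_hidden_layers)
    if interleave_moe_layer_step > 0
    and (layer_idx + 1) % interleave_moe_layer_step == 0
]
print(moe_layer_indices)   # [1, 3, 5, 7] -> every second layer is MoE
```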
@@ -329,9 +361,26 @@ def __init__(self,
                  prefix: str = "",
                  layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
         self.num_experts = vllm_config.model_config.hf_config.num_local_experts
+        self.num_redundant_experts = (
+            vllm_config.parallel_config.num_redundant_experts)
+        self.enable_eplb = vllm_config.parallel_config.enable_eplb
+
+        # Layers must be created with the enable_eplb flag, so keep the
+        # original layer_type and wrap it in a factory function.
+        original_layer_type = layer_type
+
+        def create_layer(prefix):
+            config = cast(Llama4TextConfig, vllm_config.model_config.hf_config)
+            return original_layer_type(config=config,
+                                       cache_config=vllm_config.cache_config,
+                                       quant_config=vllm_config.quant_config,
+                                       prefix=prefix,
+                                       enable_eplb=self.enable_eplb)
+
+        # Call the parent init with our custom layer factory.
         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
-                         layer_type=layer_type)
+                         layer_type=cast(type[nn.Module], create_layer))

     def load_moe_expert_weights(
         self,
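
The parent `__init__` only knows how to call `layer_type` with `config`, `cache_config`, `quant_config`, and `prefix`; the closure lets an extra keyword argument ride along without changing the parent's signature. A generic, self-contained sketch of the same pattern (all names here are illustrative, not vLLM APIs):

```python
from typing import Callable


class Layer:
    def __init__(self, *, prefix: str, extra_flag: bool = False):
        self.prefix = prefix
        self.extra_flag = extra_flag


def build_layers(make_layer: Callable[..., Layer], n: int) -> list[Layer]:
    # Parent-style builder: it only knows about the `prefix` argument.
    return [make_layer(prefix=f"layers.{i}") for i in range(n)]


def create_layer(prefix: str) -> Layer:
    # Factory closing over the extra flag the builder does not know about.
    return Layer(prefix=prefix, extra_flag=True)


layers = build_layers(create_layer, n=3)
assert all(layer.extra_flag for layer in layers)
```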
@@ -370,8 +419,11 @@ def load_moe_expert_weights(
                 new_loaded_weight = new_loaded_weight.transpose(-1, -2)
             layer_idx = extract_layer_index(name)
             # EP mapping
-            expert_map = self.layers[
-                layer_idx].feed_forward.experts.expert_map
+            feed_forward = self.layers[layer_idx].feed_forward
+            if hasattr(feed_forward, 'experts'):
+                expert_map = feed_forward.experts.expert_map
+            else:
+                expert_map = None
             if expert_map is not None:
                 local_expert_indices = (expert_map != -1) \
                     .nonzero() \
@@ -390,6 +442,7 @@ def load_moe_expert_weights(

             loaded_params.add(full_param_name)
             expert_param_loaded = True
+            is_expert = True
         return expert_param_loaded

     def load_weights(self, weights: Iterable[tuple[str,
@@ -407,7 +460,9 @@ def load_weights(self, weights: Iterable[tuple[str,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.num_experts)
+            num_experts=self.num_experts,
+            num_redundant_experts=self.num_redundant_experts,
+        )
         expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_up_proj",
             ckpt_down_proj_name="down_proj",
@@ -451,18 +506,54 @@ def load_weights(self, weights: Iterable[tuple[str,
                     weight_loader(param, loaded_weight)
                 else:
                     weight_loader(param, loaded_weight, shard_id)
+                is_expert = False
                 loaded_params.add(name)
                 break
             else:
-                moe_loaded = self.load_moe_expert_weights(
-                    name,
-                    loaded_weight,
-                    params_dict,
-                    loaded_params,
-                    expert_params_mapping,
-                    fused=fused_experts_params)
-
-                if not moe_loaded:
+                # First, try to handle the weight as an expert weight.
+                is_expert_weight = False
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+
+                    # This is an expert weight; it must not be loaded
+                    # as a regular (non-expert) weight later.
+                    is_expert_weight = True
+
+                    # Do not modify `name` since the loop may continue here.
+                    # Instead, create a new variable.
+                    name_mapped = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name_mapped, self):
+                        continue
+
+                    # Skip loading extra parameters for GPTQ/modelopt models.
+                    if ((name_mapped.endswith(".bias")
+                         or name_mapped.endswith("_bias"))
+                            and name_mapped not in params_dict):
+                        continue
+
+                    param = params_dict[name_mapped]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name_mapped,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    loaded_params.add(name_mapped)
+                    is_expert = True
+                    break
+                else:
+                    # We identified this as an expert weight above but
+                    # could not load it in the loop.
+                    if is_expert_weight:
+                        # The name matched an expert mapping, but the expert
+                        # is not mapped locally to this rank,
+                        # so simply skip it.
+                        continue
+
+                    # Not an expert weight; continue with regular loading.
                     if is_pp_missing_parameter(name, self):
                         continue

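
The replacement above relies on Python's `for`/`else`: the `else` block runs only when the loop finishes without hitting `break`, i.e. when no expert mapping matched (or every match was skipped with `continue`). A tiny standalone illustration of that control flow:

```python
def classify(name: str, expert_keys: list[str]) -> str:
    result = None
    for key in expert_keys:
        if key not in name:
            continue
        # A mapping matched: handle it as an expert weight and stop searching.
        result = f"expert weight via {key!r}"
        break
    else:
        # Runs only when the loop finished without `break`,
        # i.e. no expert mapping matched this weight name.
        result = "regular weight"
    return result


print(classify("feed_forward.experts.gate_up_proj", ["experts."]))  # expert weight
print(classify("self_attn.qkv_proj.weight", ["experts."]))          # regular weight
```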
@@ -500,18 +591,20 @@ def load_weights(self, weights: Iterable[tuple[str,
                        # Regular weight loader (handles both
                        # param.weight_loader and default_weight_loader)
                        weight_loader(param, loaded_weight)
+                        is_expert = True
                        loaded_params.add(name)
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
+                    is_expert = False
                    loaded_params.add(name)
        return loaded_params


-class Llama4ForCausalLM(LlamaForCausalLM):
+class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):

     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -525,14 +618,57 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # enable temperature tuning by default when max_model_len > 32K
         default_attn_temperature_tuning = \
             vllm_config.model_config.max_model_len > 32768
-        vllm_config.model_config.hf_config.attn_temperature_tuning \
-            = gen_config.get(
-                "attn_temperature_tuning", default_attn_temperature_tuning)
+        vllm_config.model_config.hf_config.attn_temperature_tuning = \
+            gen_config.get("attn_temperature_tuning",
+                           default_attn_temperature_tuning)

         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
                          layer_type=Llama4DecoderLayer)

+        self.expert_weights = []
+
+        # Set MoE hyperparameters
+        self.moe_layers: list[FusedMoE] = []
+        for layer in self.model.layers:
+            assert isinstance(layer, Llama4DecoderLayer)
+            if isinstance(layer.feed_forward, Llama4MoE):
+                self.moe_layers.append(layer.feed_forward.experts)
+
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_expert_groups = 1
+
+        example_moe = None
+        for layer_idx in range(self.config.num_hidden_layers):
+            layer = self.model.layers[layer_idx]
+            if isinstance(layer.feed_forward, Llama4MoE):
+                example_moe = layer.feed_forward
+                break
+        assert example_moe is not None
+
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+        self.num_shared_experts = 1
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
     def _init_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "",
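
With the `MixtureOfExperts` interface in place, the EPLB machinery is expected to allocate shared load/mapping tensors and hand them to the model via `set_eplb_state`. A hedged sketch of how a caller might wire that up; the tensor shapes and sizes here are illustrative assumptions, not the actual vLLM EPLB scheduler:

```python
import torch

# Illustrative sizes only (see the partitioning sketch earlier).
num_moe_layers = 2
num_physical_experts = 18
num_logical_experts = 16
max_replicas = 2

# Per-layer load counters and logical->physical routing tables that the
# model's MoE layers update and read; the EPLB controller owns these tensors.
expert_load_view = torch.zeros(num_moe_layers, num_physical_experts,
                               dtype=torch.int64)
logical_to_physical_map = torch.full(
    (num_moe_layers, num_logical_experts, max_replicas), -1, dtype=torch.int64)
logical_replica_count = torch.ones(num_moe_layers, num_logical_experts,
                                   dtype=torch.int64)

# Hypothetical hookup, assuming `model` is a loaded Llama4ForCausalLM:
# model.set_eplb_state(expert_load_view,
#                      logical_to_physical_map,
#                      logical_replica_count)
```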