Skip to content

Commit 0d03403

Browse files
[Fix] Fix mm ep weight init. (#2855)
* fix_45t_mm * Update load_weight_utils.py * Update load_weight_utils.py
1 parent 0253381 commit 0d03403

File tree

1 file changed

+32
-5
lines changed

1 file changed

+32
-5
lines changed

fastdeploy/model_executor/load_weight_utils.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,38 @@ def load_ep_checkpoint(model_path: str,
4343
filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k}
4444
num_local_ffn_keys = []
4545

46-
for i in range(config.moe_layer_start_index, config.num_hidden_layers):
47-
for j in range(
48-
config.num_experts_start_offset,
49-
config.num_experts_start_offset + config.num_experts_per_rank,
50-
):
46+
from itertools import chain
47+
def get_expert_ranges(config):
    """Return the expert indices that belong to the current EP rank.

    Used by MoE checkpoint loading to decide which per-expert weight keys
    this rank must read.

    Args:
        config: Configuration object exposing ``num_experts_start_offset``,
            ``num_experts_per_rank`` and ``moe_num_experts``.

    Returns:
        If ``config.moe_num_experts`` is a list, an ``itertools.chain`` over
        two ranges: the base range
        ``[num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)``
        and the same range shifted by ``moe_num_experts[0]``.
        Otherwise, just the base range.
    """
    start = config.num_experts_start_offset
    stop = start + config.num_experts_per_rank
    local_experts = range(start, stop)

    experts_cfg = config.moe_num_experts
    if not isinstance(experts_cfg, list):
        return local_experts

    # Multimodal case: a second expert group follows the first one, offset
    # by the size of the first group (moe_num_experts[0]).
    shift = experts_cfg[0]
    return chain(local_experts, range(start + shift, stop + shift))
75+
76+
for i in range(config.moe_layer_start_index, config.num_layers):
77+
for j in get_expert_ranges(config):
5178
up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
5279
down_proj_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
5380

0 commit comments

Comments
 (0)