@@ -43,11 +43,38 @@ def load_ep_checkpoint(model_path: str,
     filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k}
     num_local_ffn_keys = []

-    for i in range(config.moe_layer_start_index, config.num_hidden_layers):
-        for j in range(
-            config.num_experts_start_offset,
-            config.num_experts_start_offset + config.num_experts_per_rank,
-        ):
+    from itertools import chain
+    def get_expert_ranges(config):
+        """
+        Generate expert index ranges based on configuration parameters.
+
+        This function is primarily used in Mixture-of-Experts (MoE) models to generate
+        expert index ranges according to configuration parameters. When moe_num_experts
+        is a list in the config, it returns a chained combination of two ranges; otherwise
+        it returns a single range.
+
+        Args:
+            config: Configuration object
+
+        Returns:
+            If moe_num_experts is a list:
+                Returns a chained combination (chain object) of two ranges:
+                1. Base range: [num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)
+                2. Offset range: [base_range.start + moe_num_experts[0], base_range.stop + moe_num_experts[0])
+            Else:
+                Returns single range: [num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)
+        """
+        base_range = range(
+            config.num_experts_start_offset,
+            config.num_experts_start_offset + config.num_experts_per_rank
+        )
+        if isinstance(config.moe_num_experts, list):
+            return chain(base_range,
+                         range(base_range.start + config.moe_num_experts[0], base_range.stop + config.moe_num_experts[0]))
+        return base_range
+
+    for i in range(config.moe_layer_start_index, config.num_layers):
+        for j in get_expert_ranges(config):
             up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
             down_proj_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
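For context, the following is a minimal, self-contained sketch of how the added get_expert_ranges helper behaves. The SimpleNamespace config and its field values (start offset, experts per rank, two-group expert layout) are hypothetical placeholders for illustration, not values taken from this commit.

    from itertools import chain
    from types import SimpleNamespace

    def get_expert_ranges(config):
        # Same logic as the helper added in the diff above.
        base_range = range(
            config.num_experts_start_offset,
            config.num_experts_start_offset + config.num_experts_per_rank,
        )
        if isinstance(config.moe_num_experts, list):
            return chain(
                base_range,
                range(base_range.start + config.moe_num_experts[0],
                      base_range.stop + config.moe_num_experts[0]),
            )
        return base_range

    # Hypothetical EP rank holding 4 experts starting at expert 8, with a
    # two-group expert layout [64, 64] (all values illustrative).
    cfg = SimpleNamespace(
        num_experts_start_offset=8,
        num_experts_per_rank=4,
        moe_num_experts=[64, 64],
    )
    print(list(get_expert_ranges(cfg)))  # [8, 9, 10, 11, 72, 73, 74, 75]

When moe_num_experts is a list, the second range simply shifts the local expert indices by the size of the first expert group, so the same rank also loads its slice of the second group.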