File tree Expand file tree Collapse file tree 8 files changed +18
-17
lines changed Expand file tree Collapse file tree 8 files changed +18
-17
lines changed Original file line number Diff line number Diff line change @@ -84,6 +84,7 @@ def __init__(
84
84
head_dim : Optional [int ] = None ,
85
85
tie_word_embeddings : bool = False ,
86
86
is_quantized : bool = False ,
87
+ rms_norm_eps : float = 1e-5 ,
87
88
** kwargs ,
88
89
):
89
90
super ().__init__ (** kwargs )
@@ -123,6 +124,7 @@ def __init__(
123
124
self .dtype = dtype
124
125
self .tie_word_embeddings = tie_word_embeddings
125
126
self .is_quantized = is_quantized
127
+ self .rms_norm_eps = rms_norm_eps
126
128
127
129
128
130
@dataclass
Original file line number Diff line number Diff line change @@ -288,14 +288,14 @@ def __init__(
288
288
self .input_layernorm = RMSNorm (
289
289
fd_config ,
290
290
hidden_size = fd_config .model_config .hidden_size ,
291
- eps = 1e-5 ,
291
+ eps = fd_config . model_config . rms_norm_eps ,
292
292
prefix = f"{ prefix } .input_layernorm" ,
293
293
)
294
294
295
295
self .post_attention_layernorm = RMSNorm (
296
296
fd_config ,
297
297
hidden_size = fd_config .model_config .hidden_size ,
298
- eps = 1e-5 ,
298
+ eps = fd_config . model_config . rms_norm_eps ,
299
299
prefix = f"{ prefix } .post_attention_layernorm" ,
300
300
)
301
301
@@ -366,7 +366,7 @@ def __init__(
366
366
self .norm = RMSNorm (
367
367
fd_config ,
368
368
hidden_size = fd_config .model_config .hidden_size ,
369
- eps = 1e-5 ,
369
+ eps = fd_config . model_config . rms_norm_eps ,
370
370
prefix = f"{ fd_config .model_config .prefix_name } .norm" ,
371
371
)
372
372
Original file line number Diff line number Diff line change @@ -275,14 +275,14 @@ def __init__(
275
275
self .enorm = RMSNorm (
276
276
fd_config ,
277
277
hidden_size = fd_config .model_config .hidden_size ,
278
- eps = 1e-5 ,
278
+ eps = fd_config . model_config . rms_norm_eps ,
279
279
prefix = "ernie.mtp_emb_norm.0" ,
280
280
)
281
281
282
282
self .hnorm = RMSNorm (
283
283
fd_config ,
284
284
hidden_size = fd_config .model_config .hidden_size ,
285
- eps = 1e-5 ,
285
+ eps = fd_config . model_config . rms_norm_eps ,
286
286
prefix = "ernie.mtp_hidden_norm.0" ,
287
287
)
288
288
Original file line number Diff line number Diff line change @@ -273,14 +273,14 @@ def __init__(
273
273
self .input_layernorm = RMSNorm (
274
274
fd_config ,
275
275
hidden_size = fd_config .model_config .hidden_size ,
276
- eps = 1e-5 ,
276
+ eps = fd_config . model_config . rms_norm_eps ,
277
277
prefix = f"{ prefix } .input_layernorm" ,
278
278
)
279
279
280
280
self .post_attention_layernorm = RMSNorm (
281
281
fd_config ,
282
282
hidden_size = fd_config .model_config .hidden_size ,
283
- eps = 1e-5 ,
283
+ eps = fd_config . model_config . rms_norm_eps ,
284
284
prefix = f"{ prefix } .post_attention_layernorm" ,
285
285
)
286
286
@@ -358,7 +358,7 @@ def __init__(
358
358
self .norm = RMSNorm (
359
359
fd_config ,
360
360
hidden_size = fd_config .model_config .hidden_size ,
361
- eps = 1e-5 ,
361
+ eps = fd_config . model_config . rms_norm_eps ,
362
362
prefix = f"{ fd_config .model_config .prefix_name } .norm" ,
363
363
)
364
364
Original file line number Diff line number Diff line change @@ -161,14 +161,14 @@ def __init__(
161
161
self .input_layernorm = RMSNorm (
162
162
fd_config ,
163
163
hidden_size = fd_config .model_config .hidden_size ,
164
- eps = 1e-6 ,
164
+ eps = fd_config . model_config . rms_norm_eps ,
165
165
prefix = f"{ prefix } .input_layernorm" ,
166
166
)
167
167
168
168
self .post_attention_layernorm = RMSNorm (
169
169
fd_config ,
170
170
hidden_size = fd_config .model_config .hidden_size ,
171
- eps = 1e-6 ,
171
+ eps = fd_config . model_config . rms_norm_eps ,
172
172
prefix = f"{ prefix } .post_attention_layernorm" ,
173
173
)
174
174
@@ -248,7 +248,7 @@ def __init__(
248
248
self .norm = RMSNorm (
249
249
fd_config ,
250
250
hidden_size = fd_config .model_config .hidden_size ,
251
- eps = 1e-5 ,
251
+ eps = fd_config . model_config . rms_norm_eps ,
252
252
prefix = f"{ fd_config .model_config .prefix_name } .norm" ,
253
253
)
254
254
Original file line number Diff line number Diff line change @@ -79,12 +79,12 @@ def __init__(self,
79
79
80
80
self .q_norm = RMSNorm (fd_config = fd_config ,
81
81
hidden_size = fd_config .model_config .head_dim ,
82
- eps = 1e-6 ,
82
+ eps = fd_config . model_config . rms_norm_eps ,
83
83
prefix = f"{ prefix } .q_norm" ,
84
84
begin_norm_axis = 2 )
85
85
self .k_norm = RMSNorm (fd_config = fd_config ,
86
86
hidden_size = fd_config .model_config .head_dim ,
87
- eps = 1e-6 ,
87
+ eps = fd_config . model_config . rms_norm_eps ,
88
88
prefix = f"{ prefix } .k_norm" ,
89
89
begin_norm_axis = 2 )
90
90
@@ -183,7 +183,7 @@ def __init__(
183
183
self .norm = RMSNorm (
184
184
fd_config ,
185
185
hidden_size = fd_config .model_config .hidden_size ,
186
- eps = 1e-6 ,
186
+ eps = fd_config . model_config . rms_norm_eps ,
187
187
prefix = f"{ fd_config .model_config .prefix_name } .norm" ,
188
188
)
189
189
Original file line number Diff line number Diff line change @@ -121,12 +121,12 @@ def __init__(self,
121
121
122
122
self .q_norm = RMSNorm (fd_config ,
123
123
hidden_size = self .head_dim ,
124
- eps = 1e-6 ,
124
+ eps = fd_config . model_config . rms_norm_eps ,
125
125
prefix = f"{ prefix } .q_norm" ,
126
126
begin_norm_axis = 2 )
127
127
self .k_norm = RMSNorm (fd_config ,
128
128
hidden_size = self .head_dim ,
129
- eps = 1e-6 ,
129
+ eps = fd_config . model_config . rms_norm_eps ,
130
130
prefix = f"{ prefix } .k_norm" ,
131
131
begin_norm_axis = 2 )
132
132
Original file line number Diff line number Diff line change @@ -594,7 +594,6 @@ def initialize_fd_config(config_or_args) -> FDConfig:
594
594
model_config_dict , _ = ModelConfig .get_config_dict (config_or_args .model_name_or_path )
595
595
596
596
597
-
598
597
# Handle MoE related configs
599
598
if 'num_experts' in model_config_dict :
600
599
model_config_dict ['moe_num_experts' ] = model_config_dict .pop ('num_experts' )
You can’t perform that action at this time.
0 commit comments