@@ -164,7 +164,6 @@ def __init__(
 
         self.num_layers = fd_config.model_config.num_layers
         fd_config.model_config.prefix_name = "model"
-        fd_config.model_config.tie_word_embeddings = True
 
         self.embeddings = VocabParallelEmbedding(
             fd_config=fd_config,
@@ -240,14 +239,13 @@ def __init__(self, fd_config: FDConfig):
         self.model = Qwen3Model(fd_config=fd_config)
 
         self.ori_vocab_size = fd_config.model_config.ori_vocab_size
-
+        self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
         self.lm_head = ParallelLMHead(
             fd_config=fd_config,
             embedding_dim=fd_config.model_config.hidden_size,
             num_embeddings=fd_config.model_config.vocab_size,
-            prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"),
+            prefix="lm_head",
         )
-        self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
 
     @classmethod
     def name(self):
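
Note on the prefix change above: the `prefix` argument presumably controls which checkpoint keys the head loads, so the head previously resolved to the `model.embed_tokens` keys even when weights were not tied; it now owns the `lm_head` prefix, and the tie flag is recorded before the head is built instead of after. A tiny sketch of prefix-keyed lookup, with a hypothetical `pop_prefixed` helper (assumed for illustration, not the FastDeploy API):

# Hypothetical illustration of prefix-keyed weight lookup; `pop_prefixed`
# is an assumed helper for this sketch, not FastDeploy code.
def pop_prefixed(state_dict: dict, prefix: str) -> dict:
    keys = [k for k in state_dict if k.startswith(prefix + ".")]
    return {k[len(prefix) + 1:]: state_dict.pop(k) for k in keys}

ckpt = {"model.embed_tokens.weight": 0, "lm_head.weight": 1}
# With the old prefix "model.embed_tokens" the head would grab the
# embedding's keys; with the new prefix "lm_head" it gets its own weight.
assert pop_prefixed(dict(ckpt), "lm_head") == {"weight": 1}
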
@@ -269,7 +267,8 @@ def set_state_dict(self, state_dict):
         if self.tie_word_embeddings:
             self.lm_head.out_linear.weight.set_value(
                 self.model.embeddings.word_embeddings.weight.transpose([1, 0]))
-        self.lm_head.load_state_dict(state_dict)
+        else:
+            self.lm_head.load_state_dict(state_dict)
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         """
@@ -324,6 +323,7 @@ def get_tensor_parallel_split_mappings(num_layers):
 
             base_actions = {
                 # Row Linear
+                "lm_head.weight": partial(fn, is_column=True),
                 "embed_tokens.weight": partial(fn, is_column=False),
                 "layers.0.self_attn.o_proj.weight": partial(fn,
                                                             is_column=False),
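
The new `"lm_head.weight"` entry declares the head weight column-split under tensor parallelism, i.e. sharded along the vocabulary axis so each rank scores a slice of the vocabulary. A rough sketch of what a column split means, assuming a `[hidden_size, vocab_size]` layout; `split_columns` stands in for the real split function `fn`, which this diff does not show:

import numpy as np

def split_columns(weight: np.ndarray, tp_rank: int, tp_size: int) -> np.ndarray:
    # is_column=True: partition along the last axis, one shard per rank.
    return np.split(weight, tp_size, axis=1)[tp_rank]

weight = np.arange(8 * 4, dtype=np.float32).reshape(8, 4)  # [hidden, vocab]
shard = split_columns(weight, tp_rank=0, tp_size=2)
assert shard.shape == (8, 2)  # each rank keeps vocab_size / tp_size columns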