@@ -193,28 +193,43 @@ def _qkv_refactor(self, pname, p, hf_layer):
        wk = self.hf_model[hf_wk_name]
        wv = self.hf_model[hf_wv_name]

-        hidden_size = wq.shape[0]
-        per_partition_size, start_index, end_index = compute_partition_range(
-            hidden_size, self.tp_rank, self.tp_size)
-        hidden_size_per_attention_head = divide(hidden_size,
+        query_hidden_size = wq.shape[0]
+        kv_hidden_size = wk.shape[0]
+
+        per_partition_size, start_qindex, end_index = compute_partition_range(
+            query_hidden_size, self.tp_rank, self.tp_size)
+        _, start_kvindex, _ = compute_partition_range(
+            kv_hidden_size, self.tp_rank, self.tp_size)
+
+        hidden_size_per_attention_head = divide(query_hidden_size,
                                                self.config.num_attention_heads)
        num_attention_heads_per_partition = divide(self.config.num_attention_heads,
                                                   self.tp_size)

-        new_w = torch.zeros((per_partition_size * 3, wq.shape[1]), dtype=wq.dtype)
+        num_kv_heads_per_partition = divide(self.config.num_key_value_heads,
+                                            self.tp_size)
+        qkv_size = (num_attention_heads_per_partition + 2 * num_kv_heads_per_partition) * hidden_size_per_attention_head
+        num_qheads_per_group = divide(self.config.num_attention_heads, self.config.num_key_value_heads)
+        num_groups = divide(num_attention_heads_per_partition, num_qheads_per_group)
+        new_w = torch.zeros((qkv_size, wq.shape[1]), dtype=wq.dtype)
+
+        for i in range(num_groups):
+            query_current_index = start_qindex + i * num_qheads_per_group * hidden_size_per_attention_head
+            query_next_index = query_current_index + num_qheads_per_group * hidden_size_per_attention_head
+            kv_current_index = start_kvindex + i * hidden_size_per_attention_head
+            kv_next_kvindex = kv_current_index + hidden_size_per_attention_head
+
+            new_w_index = i * (num_qheads_per_group + 2) * hidden_size_per_attention_head

-        for i in range(num_attention_heads_per_partition):
-            current_index = start_index + i * hidden_size_per_attention_head
-            next_index = current_index + hidden_size_per_attention_head
-            new_w_index = i * (3 * hidden_size_per_attention_head)
-            new_w[new_w_index: new_w_index + (3 * hidden_size_per_attention_head), :] = \
+            new_w[new_w_index:new_w_index + (num_qheads_per_group + 2) * hidden_size_per_attention_head, :] = \
                torch.cat([
-                    wq[current_index: next_index, :],
-                    wk[current_index: next_index, :],
-                    wv[current_index: next_index, :]
-                ], dim=0)
+                    wq[query_current_index:query_next_index, :],
+                    wk[kv_current_index:kv_next_kvindex, :],
+                    wv[kv_current_index:kv_next_kvindex, :]
+                ], dim=0)
+
        self.record_mapping_info(
-            f"mega-ds:{pname,p.data.shape}<--hf{hf_wq_name,hf_wk_name,hf_wv_name,} cat q,k,v [{current_index}:{next_index},:] of q,k,v{wq.shape}"
+            f"mega-ds:{pname,p.data.shape}<--hf{hf_wq_name,hf_wk_name,hf_wv_name,} cat q,k,v [{query_current_index}:{query_next_index},:] of q,k,v{wq.shape}"
        )
        return new_w

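Note (not part of the patch): a minimal standalone sketch of the fused layout the new loop builds, using toy sizes, assuming tp_size == 1, and with all variable names purely illustrative. Each group of query heads is immediately followed by the single key and value head they share.

```python
import torch

# Illustrative only: toy GQA config, tp_size assumed to be 1.
num_attention_heads = 8       # query heads
num_key_value_heads = 2       # 4 query heads share each k/v head
head_dim = 4
hidden = num_attention_heads * head_dim

wq = torch.randn(num_attention_heads * head_dim, hidden)
wk = torch.randn(num_key_value_heads * head_dim, hidden)
wv = torch.randn(num_key_value_heads * head_dim, hidden)

q_per_group = num_attention_heads // num_key_value_heads  # 4
rows = []
for g in range(num_key_value_heads):
    # all query heads of this group, then the group's one k head and one v head
    rows.append(wq[g * q_per_group * head_dim:(g + 1) * q_per_group * head_dim])
    rows.append(wk[g * head_dim:(g + 1) * head_dim])
    rows.append(wv[g * head_dim:(g + 1) * head_dim])
fused = torch.cat(rows, dim=0)
# (num_attention_heads + 2 * num_key_value_heads) * head_dim rows in total
print(fused.shape)  # torch.Size([48, 32])
```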
@@ -383,17 +398,18 @@ def _qkv_refactor_to_hf(self, pname, ds_w, hf_layer):
        hidden_size = oldshape[-1]
        hidden_size_per_attention_head = divide(hidden_size,
                                                self.config.num_attention_heads)
-        num_attention_heads_per_partition = divide(self.config.num_attention_heads,
-                                                   self.tp_size)
-        newshape = (self.tp_size, num_attention_heads_per_partition, 3, hidden_size_per_attention_head, hidden_size)
+        # MHA & GQA
+        group = divide(self.config.num_attention_heads, self.config.num_key_value_heads)
+        newshape = (self.config.num_key_value_heads, group + 2, hidden_size_per_attention_head, hidden_size)
        ds_w_out = ds_w_all_rank.reshape(*newshape)
-        self.hf_dict[hf_q_name] = copy.deepcopy(ds_w_out[:, :, 0, :, :].reshape(-1, oldshape[-1]))
-        self.hf_dict[hf_k_name] = copy.deepcopy(ds_w_out[:, :, 1, :, :].reshape(-1, oldshape[-1]))
-        self.hf_dict[hf_v_name] = copy.deepcopy(ds_w_out[:, :, 2, :, :].reshape(-1, oldshape[-1]))
+        query_weight, key_weight, value_weight = torch.split(ds_w_out, [group, 1, 1], dim=1)
+        self.hf_dict[hf_q_name] = copy.deepcopy(query_weight.reshape(-1, hidden_size))
+        self.hf_dict[hf_k_name] = copy.deepcopy(key_weight.reshape(-1, hidden_size))
+        self.hf_dict[hf_v_name] = copy.deepcopy(value_weight.reshape(-1, hidden_size))
+        del query_weight, key_weight, value_weight


    def transform_from_megads_to_hf(self):
-        use_gqa = True if self.num_attention_heads != self.num_key_value_heads else False

        for pname, p in self.ds_model.named_parameters():
            if pname in [
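Note (not part of the patch): a minimal sketch of the reverse direction under the same toy sizes and illustrative names, mirroring the reshape-and-split that _qkv_refactor_to_hf now performs: view the fused matrix as (num_key_value_heads, group + 2, head_dim, hidden) and split along dim=1 into [group, 1, 1].

```python
import torch

# Illustrative only: same toy GQA config as the earlier sketch.
num_attention_heads, num_key_value_heads, head_dim = 8, 2, 4
hidden = num_attention_heads * head_dim
group = num_attention_heads // num_key_value_heads  # query heads per k/v head

fused = torch.randn((num_attention_heads + 2 * num_key_value_heads) * head_dim, hidden)
# one block per kv head: [group q-head slices | 1 k-head slice | 1 v-head slice]
blocks = fused.reshape(num_key_value_heads, group + 2, head_dim, hidden)
q, k, v = torch.split(blocks, [group, 1, 1], dim=1)
print(q.reshape(-1, hidden).shape)  # (num_attention_heads * head_dim, hidden)
print(k.reshape(-1, hidden).shape)  # (num_key_value_heads * head_dim, hidden)
print(v.reshape(-1, hidden).shape)  # (num_key_value_heads * head_dim, hidden)
```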
@@ -411,11 +427,7 @@ def transform_from_megads_to_hf(self):
                subname = mobj.group(2)
                hf_layer = layer_num - self.offset_num
                if subname in ["self_attention.query_key_value.weight"]:
-                    if not use_gqa:
-                        self._qkv_refactor_to_hf(pname, p, hf_layer)
-                    else:
-                        #TODO(billishyahao): Not impl yet ...
-                        assert False
+                    self._qkv_refactor_to_hf(pname, p, hf_layer)
                elif subname in ["mlp.dense_h_to_4h.weight"]:
                    self._mlphto4h_dense_refactor_to_hf(pname, p, hf_layer)
                elif subname in [