
Commit c3df187

[tool]GQA convert support (#454)
* [tools]GQA convert support
* fix readme
1 parent 676a482 commit c3df187

File tree

5 files changed (+62 -31 lines):

- examples_deepspeed/finetune_hf_llama/README.md
- examples_deepspeed/finetune_hf_llama/ds_config.json
- examples_deepspeed/finetune_hf_llama/ds_config_empty.json
- examples_deepspeed/finetune_hf_llama/finetune_llama.sh
- tools/hf2megads_weight_converter.py


examples_deepspeed/finetune_hf_llama/README.md

Lines changed: 2 additions & 2 deletions

@@ -10,9 +10,9 @@ The pre-trained weights can be found at [Hugging Face - LLAMA-7B](https://huggin

 #### 1. Converting Hugging Face Model Weights to Megatron-Deepspeed Model
 ```bash
-bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert
+bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert_hf2mds
 ```
-This command writes the Hugging Face model weights into the Megatron-Deepspeed model and saves it. You can adjust the parallel configuration in the script.
+This command writes the Hugging Face model weights into the Megatron-Deepspeed model and saves it. You can adjust the parallel configuration in the script. ```convert_mds2hf``` can convert a Megatron-Deepspeed model into the Hugging Face format.

 #### 2. Fine-tuning Process
 ```bash
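
The same launcher drives both directions: `convert_hf2mds` imports the Hugging Face weights into a Megatron-DeepSpeed checkpoint, and `convert_mds2hf` converts a Megatron-DeepSpeed checkpoint back to the Hugging Face format (see the updated finetune_llama.sh below).
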
examples_deepspeed/finetune_hf_llama/ds_config.json

Lines changed: 7 additions & 1 deletion

@@ -1,5 +1,11 @@
 {
   "train_batch_size" : 256,
   "train_micro_batch_size_per_gpu": 16,
-  "steps_per_print": 1
+  "steps_per_print": 100,
+  "zero_optimization": {
+    "stage": 0
+  },
+  "bf16": {
+    "enabled": true
+  }
 }

examples_deepspeed/finetune_hf_llama/ds_config_empty.json

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+{
+  "train_batch_size" : 256,
+  "train_micro_batch_size_per_gpu": 16,
+  "steps_per_print": 100
+}
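
A note on the intent of this split (my inference; the commit message does not spell it out): the training config now pins ZeRO to stage 0 and enables bf16 explicitly, while both conversion modes are pointed at the minimal ds_config_empty.json, presumably so the weight converter runs without ZeRO partitioning or bf16 engine wrapping affecting the checkpoint layout.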

examples_deepspeed/finetune_hf_llama/finetune_llama.sh

Lines changed: 9 additions & 1 deletion

@@ -43,6 +43,13 @@ cat <<EOT > $DS_CONFIG
 }
 EOT

+if [ "$1" = "convert_hf2mds" ]; then
+    DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config_empty.json"
+elif [ "$1" = "convert_mds2hf" ]; then
+    DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config_empty.json"
+else
+    DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config.json"
+fi

 covert_hf2mds_args="deepspeed tools/hf2megads_weight_converter.py \
 --hf-ckpt-num-shards 2 \

@@ -69,6 +76,7 @@ comm_args="--tensor-model-parallel-size $TP \
 --num-layers $NUM_LAYERS \
 --hidden-size $HIDDEN_SIZE \
 --num-attention-heads $NUM_HEADS \
+--finetune \
 --ffn-hidden-size $FFN_HIDDEN_SIZE \
 --attention-dropout 0 \
 --hidden-dropout 0 \

@@ -97,7 +105,7 @@ comm_args="--tensor-model-parallel-size $TP \
 --zero-stage 0 \
 --tokenizer-type HFTokenizer \
 --tokenizer-model $HF_LLAMA_PATH \
---deepspeed_config ./examples_deepspeed/finetune_hf_llama/ds_config.json \
+--deepspeed_config $DS_CONFIG_PATH \
 --deepspeed \
 --distributed-backend nccl \
 --num-workers 0 \
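
One behavioral note (standard Megatron flag semantics, not stated in this commit): `--finetune` loads the checkpoint weights without restoring optimizer or RNG state and resets the iteration counter, which is the appropriate mode immediately after a weight conversion.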

tools/hf2megads_weight_converter.py

Lines changed: 39 additions & 27 deletions

@@ -193,28 +193,43 @@ def _qkv_refactor(self, pname, p, hf_layer):
         wk = self.hf_model[hf_wk_name]
         wv = self.hf_model[hf_wv_name]

-        hidden_size = wq.shape[0]
-        per_partition_size, start_index, end_index = compute_partition_range(
-            hidden_size, self.tp_rank, self.tp_size)
-        hidden_size_per_attention_head = divide(hidden_size,
+        query_hidden_size = wq.shape[0]
+        kv_hidden_size = wk.shape[0]
+
+        per_partition_size, start_qindex, end_index = compute_partition_range(
+            query_hidden_size, self.tp_rank, self.tp_size)
+        _, start_kvindex, _ = compute_partition_range(
+            kv_hidden_size, self.tp_rank, self.tp_size)
+
+        hidden_size_per_attention_head = divide(query_hidden_size,
                                                 self.config.num_attention_heads)
         num_attention_heads_per_partition = divide(self.config.num_attention_heads,
                                                    self.tp_size)

-        new_w = torch.zeros((per_partition_size * 3, wq.shape[1]), dtype=wq.dtype)
+        num_kv_heads_per_partition = divide(self.config.num_key_value_heads,
+                                            self.tp_size)
+        qkv_size = (num_attention_heads_per_partition + 2 * num_kv_heads_per_partition) * hidden_size_per_attention_head
+        num_qheads_per_group = divide(self.config.num_attention_heads, self.config.num_key_value_heads)
+        num_groups = divide(num_attention_heads_per_partition, num_qheads_per_group)
+        new_w = torch.zeros((qkv_size, wq.shape[1]), dtype=wq.dtype)
+
+        for i in range(num_groups):
+            query_current_index = start_qindex + i * num_qheads_per_group * hidden_size_per_attention_head
+            query_next_index = query_current_index + num_qheads_per_group * hidden_size_per_attention_head
+            kv_current_index = start_kvindex + i * hidden_size_per_attention_head
+            kv_next_kvindex = kv_current_index + hidden_size_per_attention_head
+
+            new_w_index = i * (num_qheads_per_group + 2) * hidden_size_per_attention_head

-        for i in range(num_attention_heads_per_partition):
-            current_index = start_index + i * hidden_size_per_attention_head
-            next_index = current_index + hidden_size_per_attention_head
-            new_w_index = i * (3 * hidden_size_per_attention_head)
-            new_w[new_w_index: new_w_index + (3 * hidden_size_per_attention_head), :] = \
+            new_w[new_w_index:new_w_index + (num_qheads_per_group + 2) * hidden_size_per_attention_head, :] = \
                 torch.cat([
-                wq[current_index: next_index, :],
-                wk[current_index: next_index, :],
-                wv[current_index: next_index, :]
-            ], dim=0)
+                    wq[query_current_index:query_next_index, :],
+                    wk[kv_current_index:kv_next_kvindex, :],
+                    wv[kv_current_index:kv_next_kvindex, :]
+                ], dim=0)
+
         self.record_mapping_info(
-            f"mega-ds:{pname,p.data.shape}<--hf{hf_wq_name,hf_wk_name,hf_wv_name,} cat q,k,v [{current_index}:{next_index},:] of q,k,v{wq.shape}"
+            f"mega-ds:{pname,p.data.shape}<--hf{hf_wq_name,hf_wk_name,hf_wv_name,} cat q,k,v [{query_current_index}:{query_next_index},:] of q,k,v{wq.shape}"
         )
         return new_w
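
To make the new packing order concrete, here is a minimal runnable sketch of the layout the loop above produces, under simplifying assumptions: a single tensor-parallel rank (tp_size == 1, so the partition start indices are 0) and toy dimensions. `pack_gqa_qkv` and the sizes are illustrative names, not part of the converter.

```python
import torch

def pack_gqa_qkv(wq, wk, wv, num_heads, num_kv_heads):
    """Pack per-projection Q/K/V weights into one GQA-interleaved matrix.

    wq: (num_heads * head_dim, hidden); wk, wv: (num_kv_heads * head_dim, hidden).
    Rows come out as [q_0 .. q_{g-1}, k, v] per KV head, g = num_heads // num_kv_heads;
    num_kv_heads == num_heads gives g == 1, i.e. the classic MHA [q, k, v] interleave.
    """
    head_dim = wq.shape[0] // num_heads
    g = num_heads // num_kv_heads
    rows = []
    for i in range(num_kv_heads):
        rows.append(wq[i * g * head_dim:(i + 1) * g * head_dim])  # the g query heads of group i
        rows.append(wk[i * head_dim:(i + 1) * head_dim])          # key head i
        rows.append(wv[i * head_dim:(i + 1) * head_dim])          # value head i
    return torch.cat(rows, dim=0)

# Toy shapes: 8 query heads sharing 2 KV heads, head_dim 4, hidden 16.
num_heads, num_kv_heads, head_dim, hidden = 8, 2, 4, 16
wq = torch.randn(num_heads * head_dim, hidden)
wk = torch.randn(num_kv_heads * head_dim, hidden)
wv = torch.randn(num_kv_heads * head_dim, hidden)
packed = pack_gqa_qkv(wq, wk, wv, num_heads, num_kv_heads)
assert packed.shape == ((num_heads + 2 * num_kv_heads) * head_dim, hidden)
```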

@@ -383,17 +398,18 @@ def _qkv_refactor_to_hf(self, pname, ds_w, hf_layer):
         hidden_size = oldshape[-1]
         hidden_size_per_attention_head = divide(hidden_size,
                                                 self.config.num_attention_heads)
-        num_attention_heads_per_partition = divide(self.config.num_attention_heads,
-                                                   self.tp_size)
-        newshape = (self.tp_size, num_attention_heads_per_partition, 3, hidden_size_per_attention_head, hidden_size)
+        # MHA & GQA
+        group = divide(self.config.num_attention_heads, self.config.num_key_value_heads)
+        newshape = (self.config.num_key_value_heads, group + 2, hidden_size_per_attention_head, hidden_size)
         ds_w_out = ds_w_all_rank.reshape(*newshape)
-        self.hf_dict[hf_q_name] = copy.deepcopy(ds_w_out[:, :, 0, :, :].reshape(-1, oldshape[-1]))
-        self.hf_dict[hf_k_name] = copy.deepcopy(ds_w_out[:, :, 1, :, :].reshape(-1, oldshape[-1]))
-        self.hf_dict[hf_v_name] = copy.deepcopy(ds_w_out[:, :, 2, :, :].reshape(-1, oldshape[-1]))
+        query_weight, key_weight, value_weight = torch.split(ds_w_out, [group, 1, 1], dim=1)
+        self.hf_dict[hf_q_name] = copy.deepcopy(query_weight.reshape(-1, hidden_size))
+        self.hf_dict[hf_k_name] = copy.deepcopy(key_weight.reshape(-1, hidden_size))
+        self.hf_dict[hf_v_name] = copy.deepcopy(value_weight.reshape(-1, hidden_size))
+        del query_weight, key_weight, value_weight


     def transform_from_megads_to_hf(self):
-        use_gqa = True if self.num_attention_heads != self.num_key_value_heads else False

         for pname, p in self.ds_model.named_parameters():
             if pname in [

@@ -411,11 +427,7 @@ def transform_from_megads_to_hf(self):
             subname = mobj.group(2)
             hf_layer = layer_num - self.offset_num
             if subname in ["self_attention.query_key_value.weight"]:
-                if not use_gqa:
-                    self._qkv_refactor_to_hf(pname, p, hf_layer)
-                else:
-                    #TODO(billishyahao): Not impl yet ...
-                    assert False
+                self._qkv_refactor_to_hf(pname, p, hf_layer)
             elif subname in ["mlp.dense_h_to_4h.weight"]:
                 self._mlphto4h_dense_refactor_to_hf(pname, p, hf_layer)
             elif subname in [