From ded91f8f8543a1f785a96094a9122705198f9083 Mon Sep 17 00:00:00 2001
From: Clément Moutet
Date: Sat, 6 May 2023 18:21:57 +0200
Subject: [PATCH] Update convert_to_hf_gptneox.py

Update the conversion script so that it also works for checkpoints
produced by training runs that used only 1 GPU (a single pipeline
stage, i.e. n_stages == 1).
---
 tools/convert_to_hf_gptneox.py | 27 ++++++++++++++-------------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/tools/convert_to_hf_gptneox.py b/tools/convert_to_hf_gptneox.py
index 751d092..0294e6e 100644
--- a/tools/convert_to_hf_gptneox.py
+++ b/tools/convert_to_hf_gptneox.py
@@ -54,18 +54,20 @@ def load_decentralized_checkpoint(model, checkpoint_path, n_stages=2, n_layer_pe
                 # torch.save(_tmp, os.path.join(output_path, f'pytorch_{j}.pt'))
                 model.gpt_neox.layers[j].load_state_dict(_tmp)
 
-        elif i == n_stages - 1:
-            for j in range(n_layer_per_stage):
+        if i != 0 and i == n_stages - 1 or n_stages == 1:
+            if n_stages != 1:
+                for j in range(n_layer_per_stage):
+                    _tmp = {k[len(f"{j}."):]:v for k,v in checkpoint.items() if k.startswith(f"{j}.")}
+                    if len(_tmp) == 0:
+                        break
+                    # torch.save(_tmp, os.path.join(output_path, f'pytorch_{i*n_layer_per_stage + j}.pt'))
+                    model.gpt_neox.layers[i*n_layer_per_stage + j].load_state_dict(_tmp)
+                    if i*n_layer_per_stage + j == len(model.gpt_neox.layers) - 1:
+                        j += 1
+                        break
                 _tmp = {k[len(f"{j}."):]:v for k,v in checkpoint.items() if k.startswith(f"{j}.")}
-                if len(_tmp) == 0:
-                    break
-                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{i*n_layer_per_stage + j}.pt'))
-                model.gpt_neox.layers[i*n_layer_per_stage + j].load_state_dict(_tmp)
-                if i*n_layer_per_stage + j == len(model.gpt_neox.layers) - 1:
-                    j += 1
-                    break
-
-            _tmp = {k[len(f"{j}."):]:v for k,v in checkpoint.items() if k.startswith(f"{j}.")}
+            else:
+                _tmp = {k[len(f"{n_layer_per_stage+1}."):]:v for k,v in checkpoint.items() if k.startswith(f"{n_layer_per_stage+1}.")} #Added line
             if len(_tmp) == 0:
                 break
             # torch.save(_tmp, os.path.join(output_path, f'pytorch_lm_head.pt'))
@@ -75,7 +77,7 @@ def load_decentralized_checkpoint(model, checkpoint_path, n_stages=2, n_layer_pe
             if 'embed_out.bias' in _tmp:
                 model.embed_out.bias.data[:] = _tmp['embed_out.bias']
 
-        else:
+        elif i != 0:
             for j in range(n_layer_per_stage):
                 _tmp = {k[len(f"{j}."):]:v for k,v in checkpoint.items() if k.startswith(f"{j}.")}
                 if len(_tmp) == 0:
@@ -130,4 +132,3 @@ def load_decentralized_checkpoint(model, checkpoint_path, n_stages=2, n_layer_pe
     print(f'saved HF model to `{args.save_path}`')
     config.save_pretrained(args.save_path)
     tokenizer.save_pretrained(args.save_path)
-
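
Background for the added else branch, not part of the patch itself: with
n_stages == 1 every module's weights sit in one stage checkpoint under
numeric key prefixes, the input embedding under "0.", the transformer
layers under "1." through f"{n_layer_per_stage}.", and the final layer
norm plus LM head under f"{n_layer_per_stage+1}.", which is the prefix
the added line strips. Below is a minimal Python sketch of that prefix
handling under those assumptions; the helper name, the checkpoint
filename prank_0_checkpoint.pt, and the example value of
n_layer_per_stage are illustrative, not taken from the diff.

import torch

def submodule_state_dict(checkpoint, idx):
    # Strip the numeric prefix f"{idx}." from matching keys, mirroring the
    # dict comprehensions used throughout load_decentralized_checkpoint.
    prefix = f"{idx}."
    return {k[len(prefix):]: v for k, v in checkpoint.items() if k.startswith(prefix)}

# Hypothetical usage for a single-stage (1-GPU) checkpoint:
checkpoint = torch.load("prank_0_checkpoint.pt", map_location=torch.device("cpu"))
n_layer_per_stage = 14  # example value; use whatever the training run was configured with

embeddings = submodule_state_dict(checkpoint, 0)                   # embed_in
layer_dicts = [submodule_state_dict(checkpoint, j + 1) for j in range(n_layer_per_stage)]
lm_head = submodule_state_dict(checkpoint, n_layer_per_stage + 1)  # final_layer_norm + embed_out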