From b3bd8e7af732aa0b3ebda359f558d73f1eb34058 Mon Sep 17 00:00:00 2001
From: Eric
Date: Mon, 27 Mar 2023 15:28:01 +0800
Subject: [PATCH]
 =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=AF=8F=E4=B8=AAcheckpoint?=
 =?UTF-8?q?=E7=9A=84config=E6=96=87=E4=BB=B6=E7=9A=84=E5=AD=98=E5=82=A8?=
 =?UTF-8?q?=EF=BC=8C=E9=98=B2=E6=AD=A2=E5=8A=A0=E8=BD=BD=E9=94=99=E8=AF=AF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 finetune.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/finetune.py b/finetune.py
index 54bb1ca..888887f 100644
--- a/finetune.py
+++ b/finetune.py
@@ -111,7 +111,16 @@ def save_model(self, output_dir=None, _internal_call=False):
         k: v.to("cpu") for k, v in self.model.named_parameters() if v.requires_grad
     }
     torch.save(saved_params, os.path.join(output_dir, "adapter_model.bin"))
-
+    if self.model.peft_config.base_model_name_or_path is None:
+        self.model.peft_config.base_model_name_or_path = (
+            self.model.base_model.__dict__.get("name_or_path", None)
+            if isinstance(self.model.peft_config, PromptLearningConfig)
+            else self.model.base_model.model.__dict__.get("name_or_path", None)
+        )
+    inference_mode = self.model.peft_config.inference_mode
+    self.model.peft_config.inference_mode = True
+    self.model.peft_config.save_pretrained(output_dir)
+    self.model.peft_config.inference_mode = inference_mode
 
 def main():
     finetune_args, training_args = HfArgumentParser(