Skip to content

Commit 5210c3f

Browse files
committed
[P0] Fixing trainer saving due to FSDP integration
1 parent 5a36985 commit 5210c3f

File tree

1 file changed

+27
-7
lines changed

1 file changed

+27
-7
lines changed

pyreft/reft_trainer.py

Lines changed: 27 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -66,13 +66,33 @@ def make_dataloader(
66 66

67 67
class ReftTrainer(Trainer):
68 68
def save_model(self, output_dir, _internal_call=False):
69-
if dist.get_rank() == 0:
70-
if not os.path.exists(output_dir):
71-
os.makedirs(output_dir)
72-
self.model.save_intervention(
73-
save_directory=f"{output_dir}/intervenable_model",
74-
include_model=True
75-
)
69+
# Handle CPU training and non-distributed cases
70+
try:
71+
is_main_process = not dist.is_initialized() or dist.get_rank() == 0
72+
except (RuntimeError, AttributeError) as e: # Catches case when torch.distributed is not available or other dist errors
73+
logger.error(f"Error checking distributed training status: {str(e)}")
74+
is_main_process = True
75+
76+
if is_main_process:
77+
target_dir = f"{output_dir}/intervenable_model"
78+
# Log warning if target directory exists and has content
79+
if os.path.exists(target_dir) and os.listdir(target_dir):
80+
logger.warning(
81+
f"Directory {target_dir} already exists and contains files. "
82+
"Skipping save to prevent overwriting existing model."
83+
)
84+
return
85+
86+
try:
87+
if not os.path.exists(output_dir):
88+
os.makedirs(output_dir)
89+
self.model.save_intervention(
90+
save_directory=target_dir,
91+
include_model=True
92+
)
93+
except Exception as e:
94+
logger.error(f"Error saving model to {target_dir}: {str(e)}")
95+
raise # Re-raise the exception after logging
76 96

77 97
def _load_best_model(self):
78 98
logger.warning(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")

0 commit comments

Comments
 (0)