@@ -66,13 +66,33 @@ def make_dataloader(
 
 class ReftTrainer(Trainer):
     def save_model(self, output_dir, _internal_call=False):
-        if dist.get_rank() == 0:
-            if not os.path.exists(output_dir):
-                os.makedirs(output_dir)
-            self.model.save_intervention(
-                save_directory=f"{output_dir}/intervenable_model",
-                include_model=True
-            )
+        # Handle CPU training and non-distributed cases
+        try:
+            is_main_process = not dist.is_initialized() or dist.get_rank() == 0
+        except (RuntimeError, AttributeError) as e:  # Catches case when torch.distributed is not available or other dist errors
+            logger.error(f"Error checking distributed training status: {str(e)}")
+            is_main_process = True
+
+        if is_main_process:
+            target_dir = f"{output_dir}/intervenable_model"
+            # Log warning if target directory exists and has content
+            if os.path.exists(target_dir) and os.listdir(target_dir):
+                logger.warning(
+                    f"Directory {target_dir} already exists and contains files. "
+                    "Skipping save to prevent overwriting existing model."
+                )
+                return
+
+            try:
+                if not os.path.exists(output_dir):
+                    os.makedirs(output_dir)
+                self.model.save_intervention(
+                    save_directory=target_dir,
+                    include_model=True
+                )
+            except Exception as e:
+                logger.error(f"Error saving model to {target_dir}: {str(e)}")
+                raise  # Re-raise the exception after logging
 
     def _load_best_model(self):
         logger.warning(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")