@@ -62,14 +62,20 @@ def __init__(
62
62
63
63
def save_checkpoint (self , model , optimizer , args , epoch , model_ema = None , metric = None , use_amp = False ):
64
64
assert epoch >= 0
65
+ tmp_save_path = os .path .join (self .checkpoint_dir , 'tmp' + self .extension )
66
+ last_save_path = os .path .join (self .checkpoint_dir , 'last' + self .extension )
67
+ self ._save (tmp_save_path , model , optimizer , args , epoch , model_ema , metric , use_amp )
68
+ if os .path .exists (last_save_path ):
69
+ os .unlink (last_save_path ) # required for Windows support.
70
+ os .rename (tmp_save_path , last_save_path )
65
71
worst_file = self .checkpoint_files [- 1 ] if self .checkpoint_files else None
66
72
if (len (self .checkpoint_files ) < self .max_history
67
73
or metric is None or self .cmp (metric , worst_file [1 ])):
68
74
if len (self .checkpoint_files ) >= self .max_history :
69
75
self ._cleanup_checkpoints (1 )
70
76
filename = '-' .join ([self .save_prefix , str (epoch )]) + self .extension
71
77
save_path = os .path .join (self .checkpoint_dir , filename )
72
- self . _save ( save_path , model , optimizer , args , epoch , model_ema , metric , use_amp )
78
+ os . link ( last_save_path , save_path )
73
79
self .checkpoint_files .append ((save_path , metric ))
74
80
self .checkpoint_files = sorted (
75
81
self .checkpoint_files , key = lambda x : x [1 ],
@@ -83,7 +89,10 @@ def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=
83
89
if metric is not None and (self .best_metric is None or self .cmp (metric , self .best_metric )):
84
90
self .best_epoch = epoch
85
91
self .best_metric = metric
86
- shutil .copyfile (save_path , os .path .join (self .checkpoint_dir , 'model_best' + self .extension ))
92
+ best_save_path = os .path .join (self .checkpoint_dir , 'model_best' + self .extension )
93
+ if os .path .exists (best_save_path ):
94
+ os .unlink (best_save_path )
95
+ os .link (last_save_path , best_save_path )
87
96
88
97
return (None , None ) if self .best_metric is None else (self .best_metric , self .best_epoch )
89
98
0 commit comments