@@ -37,20 +37,67 @@ def unwrap_model(model):
     return model.module if hasattr(model, 'module') else model


-def get_state_dict(model):
-    return unwrap_model(model).state_dict()
+def get_state_dict(model, unwrap_fn=unwrap_model):
+    return unwrap_fn(model).state_dict()
+
+
+class ApexScaler:
+    state_dict_key = "amp"
+
+    def __call__(self, loss, optimizer):
+        with amp.scale_loss(loss, optimizer) as scaled_loss:
+            scaled_loss.backward()
+        optimizer.step()
+
+    def state_dict(self):
+        if 'state_dict' in amp.__dict__:
+            return amp.state_dict()
+
+    def load_state_dict(self, state_dict):
+        if 'load_state_dict' in amp.__dict__:
+            amp.load_state_dict(state_dict)
+
+
+class NativeScaler:
+    state_dict_key = "amp_scaler"
+
+    def __init__(self):
+        self._scaler = torch.cuda.amp.GradScaler()
+
+    def __call__(self, loss, optimizer):
+        self._scaler.scale(loss).backward()
+        self._scaler.step(optimizer)
+        self._scaler.update()
+
+    def state_dict(self):
+        return self._scaler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._scaler.load_state_dict(state_dict)


 class CheckpointSaver:
     def __init__(
             self,
+            model,
+            optimizer,
+            args=None,
+            model_ema=None,
+            amp_scaler=None,
             checkpoint_prefix='checkpoint',
             recovery_prefix='recovery',
             checkpoint_dir='',
             recovery_dir='',
             decreasing=False,
             max_history=10,
-            save_amp=False):
+            unwrap_fn=unwrap_model):
+
+        # objects to save state_dicts of
+        self.model = model
+        self.optimizer = optimizer
+        self.args = args
+        self.model_ema = model_ema
+        self.amp_scaler = amp_scaler

         # state
         self.checkpoint_files = []  # (filename, metric) tuples in order of decreasing betterness
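A usage sketch (not part of the diff) of the NativeScaler class added above, showing how its __call__(loss, optimizer) interface pairs with the standard torch.cuda.amp.autocast context. The toy model, loss, and optimizer are hypothetical placeholders for illustration only.

import torch
import torch.nn as nn

# hypothetical toy setup; NativeScaler is the class added in the hunk above
model = nn.Linear(10, 2).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scaler = NativeScaler()

input = torch.randn(4, 10, device='cuda')
target = torch.randint(0, 2, (4,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():           # run the forward pass in mixed precision
    loss = criterion(model(input), target)
loss_scaler(loss, optimizer)              # scale the loss, backward, optimizer step, update the scale factor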
@@ -68,14 +115,14 @@ def __init__(
         self.decreasing = decreasing  # a lower metric is better if True
         self.cmp = operator.lt if decreasing else operator.gt  # True if lhs better than rhs
         self.max_history = max_history
-        self.save_apex_amp = save_amp  # save APEX amp state
+        self.unwrap_fn = unwrap_fn
         assert self.max_history >= 1

-    def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
+    def save_checkpoint(self, epoch, metric=None):
         assert epoch >= 0
         tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
         last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
-        self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric)
+        self._save(tmp_save_path, epoch, metric)
         if os.path.exists(last_save_path):
             os.unlink(last_save_path)  # required for Windows support.
         os.rename(tmp_save_path, last_save_path)
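A sketch (not part of the diff) of the new CheckpointSaver calling convention: the objects to checkpoint are handed over once at construction, and save_checkpoint now only takes the epoch and an optional metric. The args, loss_scaler, num_epochs, eval_metric, and directory names below are hypothetical placeholders.

saver = CheckpointSaver(
    model=model,
    optimizer=optimizer,
    args=args,                      # optional; stored in the checkpoint when not None
    model_ema=None,
    amp_scaler=loss_scaler,         # saved under loss_scaler.state_dict_key
    checkpoint_dir='./output',
    decreasing=False,               # a higher eval metric is better
    max_history=3)

for epoch in range(num_epochs):
    ...  # train and validate, producing eval_metric
    best_metric, best_epoch = saver.save_checkpoint(epoch, metric=eval_metric)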
@@ -107,19 +154,21 @@ def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=

         return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

-    def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
+    def _save(self, save_path, epoch, metric=None):
         save_state = {
             'epoch': epoch,
-            'arch': args.model,
-            'state_dict': get_state_dict(model),
-            'optimizer': optimizer.state_dict(),
-            'args': args,
+            'arch': type(self.model).__name__.lower(),
+            'state_dict': get_state_dict(self.model, self.unwrap_fn),
+            'optimizer': self.optimizer.state_dict(),
             'version': 2,  # version < 2 increments epoch before save
         }
-        if self.save_apex_amp and 'state_dict' in amp.__dict__:
-            save_state['amp'] = amp.state_dict()
-        if model_ema is not None:
-            save_state['state_dict_ema'] = get_state_dict(model_ema)
+        if self.args is not None:
+            save_state['arch'] = self.args.model
+            save_state['args'] = self.args
+        if self.amp_scaler is not None:
+            save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
+        if self.model_ema is not None:
+            save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
         if metric is not None:
             save_state['metric'] = metric
         torch.save(save_state, save_path)
@@ -138,11 +187,11 @@ def _cleanup_checkpoints(self, trim=0):
                 _logger.error("Exception '{}' while deleting checkpoint".format(e))
         self.checkpoint_files = self.checkpoint_files[:delete_index]

-    def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
+    def save_recovery(self, epoch, batch_idx=0):
         assert epoch >= 0
         filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
         save_path = os.path.join(self.recovery_dir, filename)
-        self._save(save_path, model, optimizer, args, epoch, model_ema)
+        self._save(save_path, epoch)
         if os.path.exists(self.last_recovery_file):
             try:
                 _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
@@ -336,3 +385,16 @@ def add_bool_arg(parser, name, default=False, help=''):
     group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
     group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
     parser.set_defaults(**{dest_name: default})
+
+
+def set_jit_legacy():
+    """ Set JIT executor to legacy w/ support for op fusion
+    This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
+    in the JIT executor. These APIs are not supported so could change.
+    """
+    #
+    assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
+    torch._C._jit_set_profiling_executor(False)
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_override_can_fuse_on_gpu(True)
+    #torch._C._jit_set_texpr_fuser_enabled(True)
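A sketch (not part of the diff) of how set_jit_legacy might be used: call it once before scripting a model on the affected PyTorch versions so the legacy executor with GPU op fusion handles the subsequent runs. The small Sequential model is only a stand-in.

import torch
import torch.nn as nn

set_jit_legacy()  # switch to the legacy JIT executor before any scripting/benchmarking

model = torch.jit.script(nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).cuda().eval())
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224, device='cuda'))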