@@ -34,11 +34,11 @@ def _mkdir_if_not_exist(path):
34
34
if not os .path .exists (path ):
35
35
try :
36
36
os .makedirs (path )
37
- except OSError as e :
38
- if e .errno == errno .EEXIST and os .path .isdir (path ):
39
- logger .warning (f"{ path } already created" )
37
+ except OSError as os_err :
38
+ if os_err .errno == errno .EEXIST and os .path .isdir (path ):
39
+ logger .warning (f"{ path } already created. " )
40
40
else :
41
- raise OSError (f"Failed to mkdir { path } " )
41
+ raise OSError (f"Failed to mkdir { path } . " )
42
42
43
43
44
44
def _load_pretrain_from_path (model , path , equation = None ):
@@ -49,17 +49,20 @@ def _load_pretrain_from_path(model, path, equation=None):
49
49
path (str, optional): Pretrained model path.
50
50
equation (Optional[Dict[str, ppsci.equation.PDE]]): Equations. Defaults to None.
51
51
"""
52
- if not (os .path .isdir (path ) or os .path .exists (path + " .pdparams" )):
52
+ if not (os .path .isdir (path ) or os .path .exists (f" { path } .pdparams" )):
53
53
raise FileNotFoundError (
54
54
f"Pretrained model path { path } .pdparams does not exists."
55
55
)
56
56
57
- param_state_dict = paddle .load (path + " .pdparams" )
57
+ param_state_dict = paddle .load (f" { path } .pdparams" )
58
58
model .set_dict (param_state_dict )
59
59
if equation is not None :
60
- equation_dict = paddle .load (path + ".pdeqn" )
61
- for name , _equation in equation .items ():
62
- _equation .set_state_dict (equation_dict [name ])
60
+ if not os .path .exists (f"{ path } .pdeqn" ):
61
+ logger .warning (f"{ path } .pdeqn not found." )
62
+ else :
63
+ equation_dict = paddle .load (f"{ path } .pdeqn" )
64
+ for name , _equation in equation .items ():
65
+ _equation .set_state_dict (equation_dict [name ])
63
66
64
67
logger .info (f"Finish loading pretrained model from { path } " )
65
68
@@ -92,28 +95,32 @@ def load_checkpoint(
92
95
Returns:
93
96
Dict[str, Any]: Loaded metric information.
94
97
"""
95
- if not os .path .exists (path + " .pdparams" ):
98
+ if not os .path .exists (f" { path } .pdparams" ):
96
99
raise FileNotFoundError (f"{ path } .pdparams not exist." )
97
- if not os .path .exists (path + " .pdopt" ):
100
+ if not os .path .exists (f" { path } .pdopt" ):
98
101
raise FileNotFoundError (f"{ path } .pdopt not exist." )
99
- if grad_scaler is not None and not os .path .exists (path + " .pdscaler" ):
102
+ if grad_scaler is not None and not os .path .exists (f" { path } .pdscaler" ):
100
103
raise FileNotFoundError (f"{ path } .scaler not exist." )
101
104
102
105
# load state dict
103
- param_dict = paddle .load (path + " .pdparams" )
104
- optim_dict = paddle .load (path + " .pdopt" )
105
- metric_dict = paddle .load (path + " .pdstates" )
106
+ param_dict = paddle .load (f" { path } .pdparams" )
107
+ optim_dict = paddle .load (f" { path } .pdopt" )
108
+ metric_dict = paddle .load (f" { path } .pdstates" )
106
109
if grad_scaler is not None :
107
- scaler_dict = paddle .load (path + " .pdscaler" )
110
+ scaler_dict = paddle .load (f" { path } .pdscaler" )
108
111
if equation is not None :
109
- equation_dict = paddle .load (path + ".pdeqn" )
112
+ if not os .path .exists (f"{ path } .pdeqn" ):
113
+ logger .warning (f"{ path } .pdeqn not found." )
114
+ equation_dict = None
115
+ else :
116
+ equation_dict = paddle .load (f"{ path } .pdeqn" )
110
117
111
118
# set state dict
112
119
model .set_state_dict (param_dict )
113
120
optimizer .set_state_dict (optim_dict )
114
121
if grad_scaler is not None :
115
122
grad_scaler .load_state_dict (scaler_dict )
116
- if equation is not None :
123
+ if equation is not None and equation_dict is not None :
117
124
for name , _equation in equation .items ():
118
125
_equation .set_state_dict (equation_dict [name ])
119
126
@@ -141,15 +148,15 @@ def save_checkpoint(
141
148
_mkdir_if_not_exist (model_dir )
142
149
model_path = os .path .join (model_dir , prefix )
143
150
144
- paddle .save (model .state_dict (), model_path + " .pdparams" )
145
- paddle .save (optimizer .state_dict (), model_path + " .pdopt" )
146
- paddle .save (metric , model_path + " .pdstates" )
151
+ paddle .save (model .state_dict (), f" { model_path } .pdparams" )
152
+ paddle .save (optimizer .state_dict (), f" { model_path } .pdopt" )
153
+ paddle .save (metric , f" { model_path } .pdstates" )
147
154
if grad_scaler is not None :
148
- paddle .save (grad_scaler .state_dict (), model_path + " .pdscaler" )
155
+ paddle .save (grad_scaler .state_dict (), f" { model_path } .pdscaler" )
149
156
if equation is not None :
150
157
paddle .save (
151
158
{key : eq .state_dict () for key , eq in equation .items ()},
152
- model_path + " .pdeqn" ,
159
+ f" { model_path } .pdeqn" ,
153
160
)
154
161
155
162
logger .info (f"Finish saving checkpoint to { model_path } " )
0 commit comments