@@ -227,7 +227,9 @@ def __init__(
227
227
228
228
# decorate model(s) and optimizer(s) for AMP
229
229
if self .use_amp :
230
- self .model = amp .decorate (self .model , self .optimizer , self .amp_level )
230
+ self .model , self .optimizer = amp .decorate (
231
+ self .model , self .optimizer , self .amp_level
232
+ )
231
233
232
234
# wrap model and optimizer to parallel object
233
235
self .rank = dist .get_rank ()
@@ -247,8 +249,9 @@ def __init__(
247
249
if version .Version (paddle .__version__ ) != version .Version ("0.0.0" )
248
250
else f"develop({ paddle .version .commit [:7 ]} )"
249
251
)
250
- if logger ._logger is not None :
251
- logger .info (f"Using paddlepaddle { paddle_version } on device { self .device } " )
252
+ if logger ._logger is None :
253
+ logger .init_logger ("ppsci" )
254
+ logger .info (f"Using paddlepaddle { paddle_version } on device { self .device } " )
252
255
253
256
@staticmethod
254
257
def from_config (cfg : Dict [str , Any ]) -> Solver :
@@ -379,6 +382,7 @@ def train(self):
379
382
)
380
383
logger .scaler ("eval_metric" , cur_metric , epoch_id , self .vdl_writer )
381
384
385
+ # visualize after evaluation
382
386
if self .visualizer is not None :
383
387
self .visualize (epoch_id )
384
388
@@ -398,7 +402,7 @@ def train(self):
398
402
self .equation ,
399
403
)
400
404
401
- # always save the latest model for convenient resume training
405
+ # save the latest model for convenient resume training
402
406
save_load .save_checkpoint (
403
407
self .model ,
404
408
self .optimizer ,
@@ -413,12 +417,9 @@ def train(self):
413
417
if self .vdl_writer is not None :
414
418
self .vdl_writer .close ()
415
419
420
+ @misc .run_on_eval_mode
416
421
def eval (self , epoch_id : int = 0 ):
417
422
"""Evaluation"""
418
- train_state = self .model .training
419
- if train_state :
420
- self .model .eval ()
421
-
422
423
# set eval func
423
424
self .eval_func = ppsci .solver .eval .eval_func
424
425
@@ -429,26 +430,19 @@ def eval(self, epoch_id: int = 0):
429
430
logger .info (f"[Eval][Epoch { epoch_id } ][Avg] { metric_msg } " )
430
431
self .eval_output_info .clear ()
431
432
432
- if train_state :
433
- self .model .train ()
434
433
return result
435
434
435
+ @misc .run_on_eval_mode
436
436
def visualize (self , epoch_id : int = 0 ):
437
437
"""Visualization"""
438
- train_state = self .model .training
439
- if train_state :
440
- self .model .eval ()
441
-
442
438
# init train func
443
439
self .visu_func = ppsci .solver .visu .visualize_func
444
440
445
441
self .visu_func (self , epoch_id )
446
442
logger .info (f"[Visualize][Epoch { epoch_id } ] Finished visualization" )
447
443
448
- if train_state :
449
- self .model .train ()
450
-
451
444
@paddle .no_grad ()
445
+ @misc .run_on_eval_mode
452
446
def predict (
453
447
self ,
454
448
input_dict : Dict [str , Union [np .ndarray , paddle .Tensor ]],
@@ -463,10 +457,6 @@ def predict(
463
457
Returns:
464
458
Dict[str, paddle.Tensor]: Prediction in dict.
465
459
"""
466
- train_state = self .model .training
467
- if train_state :
468
- self .model .eval ()
469
-
470
460
if self .world_size > 1 :
471
461
raise NotImplementedError (
472
462
"Solver.predict only support single device yet, "
@@ -501,10 +491,9 @@ def predict(
501
491
502
492
pred_dict = {key : paddle .concat (value ) for key , value in pred_dict .items ()}
503
493
504
- if train_state :
505
- self .model .train ()
506
494
return pred_dict
507
495
496
+ @misc .run_on_eval_mode
508
497
def export (self ):
509
498
"""Export to inference model"""
510
499
pretrained_path = self .cfg ["Global" ]["pretrained_model" ]
0 commit comments