@@ -227,7 +227,9 @@ def __init__(
227
227
228
228
# decorate model(s) and optimizer(s) for AMP
229
229
if self .use_amp :
230
- self .model = amp .decorate (self .model , self .optimizer , self .amp_level )
230
+ self .model , self .optimizer = amp .decorate (
231
+ self .model , self .optimizer , self .amp_level
232
+ )
231
233
232
234
# wrap model and optimizer to parallel object
233
235
self .rank = dist .get_rank ()
@@ -379,6 +381,7 @@ def train(self):
379
381
)
380
382
logger .scaler ("eval_metric" , cur_metric , epoch_id , self .vdl_writer )
381
383
384
+ # visualize after evaluation
382
385
if self .visualizer is not None :
383
386
self .visualize (epoch_id )
384
387
@@ -398,7 +401,7 @@ def train(self):
398
401
self .equation ,
399
402
)
400
403
401
- # always save the latest model for convenient resume training
404
+ # save the latest model for convenient resume training
402
405
save_load .save_checkpoint (
403
406
self .model ,
404
407
self .optimizer ,
@@ -413,12 +416,9 @@ def train(self):
413
416
if self .vdl_writer is not None :
414
417
self .vdl_writer .close ()
415
418
419
+ @misc .run_on_eval_mode
416
420
def eval (self , epoch_id : int = 0 ):
417
421
"""Evaluation"""
418
- train_state = self .model .training
419
- if train_state :
420
- self .model .eval ()
421
-
422
422
# set eval func
423
423
self .eval_func = ppsci .solver .eval .eval_func
424
424
@@ -429,26 +429,19 @@ def eval(self, epoch_id: int = 0):
429
429
logger .info (f"[Eval][Epoch { epoch_id } ][Avg] { metric_msg } " )
430
430
self .eval_output_info .clear ()
431
431
432
- if train_state :
433
- self .model .train ()
434
432
return result
435
433
434
+ @misc .run_on_eval_mode
436
435
def visualize (self , epoch_id : int = 0 ):
437
436
"""Visualization"""
438
- train_state = self .model .training
439
- if train_state :
440
- self .model .eval ()
441
-
442
437
# init train func
443
438
self .visu_func = ppsci .solver .visu .visualize_func
444
439
445
440
self .visu_func (self , epoch_id )
446
441
logger .info (f"[Visualize][Epoch { epoch_id } ] Finished visualization" )
447
442
448
- if train_state :
449
- self .model .train ()
450
-
451
443
@paddle .no_grad ()
444
+ @misc .run_on_eval_mode
452
445
def predict (
453
446
self ,
454
447
input_dict : Dict [str , Union [np .ndarray , paddle .Tensor ]],
@@ -463,10 +456,6 @@ def predict(
463
456
Returns:
464
457
Dict[str, paddle.Tensor]: Prediction in dict.
465
458
"""
466
- train_state = self .model .training
467
- if train_state :
468
- self .model .eval ()
469
-
470
459
if self .world_size > 1 :
471
460
raise NotImplementedError (
472
461
"Solver.predict only support single device yet, "
@@ -501,10 +490,9 @@ def predict(
501
490
502
491
pred_dict = {key : paddle .concat (value ) for key , value in pred_dict .items ()}
503
492
504
- if train_state :
505
- self .model .train ()
506
493
return pred_dict
507
494
495
+ @misc .run_on_eval_mode
508
496
def export (self ):
509
497
"""Export to inference model"""
510
498
pretrained_path = self .cfg ["Global" ]["pretrained_model" ]
0 commit comments