Skip to content

Commit 3d4137b

Browse files
Merge pull request #336 from HydrogenSulfate/add_eval_decorator_2
add eval wrapper to optimize eval code in solver
2 parents 0f397b5 + 91e3f3c commit 3d4137b

File tree

2 files changed

+42
-21
lines changed

2 files changed

+42
-21
lines changed

ppsci/solver/solver.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,9 @@ def __init__(
227227

228228
# decorate model(s) and optimizer(s) for AMP
229229
if self.use_amp:
230-
self.model = amp.decorate(self.model, self.optimizer, self.amp_level)
230+
self.model, self.optimizer = amp.decorate(
231+
self.model, self.optimizer, self.amp_level
232+
)
231233

232234
# wrap model and optimizer to parallel object
233235
self.rank = dist.get_rank()
@@ -379,6 +381,7 @@ def train(self):
379381
)
380382
logger.scaler("eval_metric", cur_metric, epoch_id, self.vdl_writer)
381383

384+
# visualize after evaluation
382385
if self.visualizer is not None:
383386
self.visualize(epoch_id)
384387

@@ -398,7 +401,7 @@ def train(self):
398401
self.equation,
399402
)
400403

401-
# always save the latest model for convenient resume training
404+
# save the latest model for convenient resume training
402405
save_load.save_checkpoint(
403406
self.model,
404407
self.optimizer,
@@ -413,12 +416,9 @@ def train(self):
413416
if self.vdl_writer is not None:
414417
self.vdl_writer.close()
415418

419+
@misc.run_on_eval_mode
416420
def eval(self, epoch_id: int = 0):
417421
"""Evaluation"""
418-
train_state = self.model.training
419-
if train_state:
420-
self.model.eval()
421-
422422
# set eval func
423423
self.eval_func = ppsci.solver.eval.eval_func
424424

@@ -429,26 +429,19 @@ def eval(self, epoch_id: int = 0):
429429
logger.info(f"[Eval][Epoch {epoch_id}][Avg] {metric_msg}")
430430
self.eval_output_info.clear()
431431

432-
if train_state:
433-
self.model.train()
434432
return result
435433

434+
@misc.run_on_eval_mode
436435
def visualize(self, epoch_id: int = 0):
437436
"""Visualization"""
438-
train_state = self.model.training
439-
if train_state:
440-
self.model.eval()
441-
442437
# init train func
443438
self.visu_func = ppsci.solver.visu.visualize_func
444439

445440
self.visu_func(self, epoch_id)
446441
logger.info(f"[Visualize][Epoch {epoch_id}] Finished visualization")
447442

448-
if train_state:
449-
self.model.train()
450-
451443
@paddle.no_grad()
444+
@misc.run_on_eval_mode
452445
def predict(
453446
self,
454447
input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
@@ -463,10 +456,6 @@ def predict(
463456
Returns:
464457
Dict[str, paddle.Tensor]: Prediction in dict.
465458
"""
466-
train_state = self.model.training
467-
if train_state:
468-
self.model.eval()
469-
470459
if self.world_size > 1:
471460
raise NotImplementedError(
472461
"Solver.predict only support single device yet, "
@@ -501,10 +490,9 @@ def predict(
501490

502491
pred_dict = {key: paddle.concat(value) for key, value in pred_dict.items()}
503492

504-
if train_state:
505-
self.model.train()
506493
return pred_dict
507494

495+
@misc.run_on_eval_mode
508496
def export(self):
509497
"""Export to inference model"""
510498
pretrained_path = self.cfg["Global"]["pretrained_model"]

ppsci/utils/misc.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
# limitations under the License.
1414

1515
import collections
16+
import functools
1617
import random
18+
from typing import Callable
1719
from typing import Dict
1820
from typing import List
1921
from typing import Tuple
@@ -32,6 +34,7 @@
3234
"stack_dict_list",
3335
"combine_array_with_time",
3436
"set_random_seed",
37+
"run_on_eval_mode",
3538
]
3639

3740

@@ -229,3 +232,33 @@ def set_random_seed(seed: int):
229232
paddle.seed(seed)
230233
np.random.seed(seed)
231234
random.seed(seed)
235+
236+
237+
def run_on_eval_mode(func: Callable) -> Callable:
238+
"""A decorator automatically running given class method in eval mode and keep
239+
training state unchanged after function finished.
240+
241+
Args:
242+
func (Callable): Class method which is expected running in eval mode.
243+
244+
Returns:
245+
Callable: Decorated class method.
246+
"""
247+
248+
@functools.wraps(func)
249+
def function_with_eval_state(self, *args, **kwargs):
250+
# log original state
251+
train_state = self.model.training
252+
253+
# switch to eval mode
254+
if train_state:
255+
self.model.eval()
256+
257+
# run func in eval mode
258+
func(self, *args, **kwargs)
259+
260+
# restore state
261+
if train_state:
262+
self.model.train()
263+
264+
return function_with_eval_state

0 commit comments

Comments
 (0)