Skip to content

Commit 848f2bf

Browse files
add eval wrapper to optimize eval code in solver
1 parent f3f554c commit 848f2bf

File tree

2 files changed

+45
-23
lines changed

2 files changed

+45
-23
lines changed

ppsci/solver/solver.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,9 @@ def __init__(
227227

228228
# decorate model(s) and optimizer(s) for AMP
229229
if self.use_amp:
230-
self.model = amp.decorate(self.model, self.optimizer, self.amp_level)
230+
self.model, self.optimizer = amp.decorate(
231+
self.model, self.optimizer, self.amp_level
232+
)
231233

232234
# wrap model and optimizer to parallel object
233235
self.rank = dist.get_rank()
@@ -247,8 +249,9 @@ def __init__(
247249
if version.Version(paddle.__version__) != version.Version("0.0.0")
248250
else f"develop({paddle.version.commit[:7]})"
249251
)
250-
if logger._logger is not None:
251-
logger.info(f"Using paddlepaddle {paddle_version} on device {self.device}")
252+
if logger._logger is None:
253+
logger.init_logger("ppsci")
254+
logger.info(f"Using paddlepaddle {paddle_version} on device {self.device}")
252255

253256
@staticmethod
254257
def from_config(cfg: Dict[str, Any]) -> Solver:
@@ -379,6 +382,7 @@ def train(self):
379382
)
380383
logger.scaler("eval_metric", cur_metric, epoch_id, self.vdl_writer)
381384

385+
# visualize after evaluation
382386
if self.visualizer is not None:
383387
self.visualize(epoch_id)
384388

@@ -398,7 +402,7 @@ def train(self):
398402
self.equation,
399403
)
400404

401-
# always save the latest model for convenient resume training
405+
# save the latest model for convenient resume training
402406
save_load.save_checkpoint(
403407
self.model,
404408
self.optimizer,
@@ -413,12 +417,9 @@ def train(self):
413417
if self.vdl_writer is not None:
414418
self.vdl_writer.close()
415419

420+
@misc.run_on_eval_mode
416421
def eval(self, epoch_id: int = 0):
417422
"""Evaluation"""
418-
train_state = self.model.training
419-
if train_state:
420-
self.model.eval()
421-
422423
# set eval func
423424
self.eval_func = ppsci.solver.eval.eval_func
424425

@@ -429,26 +430,19 @@ def eval(self, epoch_id: int = 0):
429430
logger.info(f"[Eval][Epoch {epoch_id}][Avg] {metric_msg}")
430431
self.eval_output_info.clear()
431432

432-
if train_state:
433-
self.model.train()
434433
return result
435434

435+
@misc.run_on_eval_mode
436436
def visualize(self, epoch_id: int = 0):
437437
"""Visualization"""
438-
train_state = self.model.training
439-
if train_state:
440-
self.model.eval()
441-
442438
# init train func
443439
self.visu_func = ppsci.solver.visu.visualize_func
444440

445441
self.visu_func(self, epoch_id)
446442
logger.info(f"[Visualize][Epoch {epoch_id}] Finished visualization")
447443

448-
if train_state:
449-
self.model.train()
450-
451444
@paddle.no_grad()
445+
@misc.run_on_eval_mode
452446
def predict(
453447
self,
454448
input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
@@ -463,10 +457,6 @@ def predict(
463457
Returns:
464458
Dict[str, paddle.Tensor]: Prediction in dict.
465459
"""
466-
train_state = self.model.training
467-
if train_state:
468-
self.model.eval()
469-
470460
if self.world_size > 1:
471461
raise NotImplementedError(
472462
"Solver.predict only support single device yet, "
@@ -501,10 +491,9 @@ def predict(
501491

502492
pred_dict = {key: paddle.concat(value) for key, value in pred_dict.items()}
503493

504-
if train_state:
505-
self.model.train()
506494
return pred_dict
507495

496+
@misc.run_on_eval_mode
508497
def export(self):
509498
"""Export to inference model"""
510499
pretrained_path = self.cfg["Global"]["pretrained_model"]

ppsci/utils/misc.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
# limitations under the License.
1414

1515
import collections
16+
import functools
1617
import random
18+
from typing import Callable
1719
from typing import Dict
1820
from typing import List
1921
from typing import Tuple
@@ -32,6 +34,7 @@
3234
"stack_dict_list",
3335
"combine_array_with_time",
3436
"set_random_seed",
37+
"run_on_eval_mode",
3538
]
3639

3740

@@ -229,3 +232,33 @@ def set_random_seed(seed: int):
229232
paddle.seed(seed)
230233
np.random.seed(seed)
231234
random.seed(seed)
235+
236+
237+
def run_on_eval_mode(func: Callable) -> Callable:
238+
"""A decorator automatically running given class method in eval mode and keep
239+
training state unchanged after function finished.
240+
241+
Args:
242+
func (Callable): Class method which is expected running in eval mode.
243+
244+
Returns:
245+
Callable: Decorated class method.
246+
"""
247+
248+
@functools.wraps(func)
249+
def function_with_eval_state(self, *args, **kwargs):
250+
# log original state
251+
train_state = self.model.training
252+
253+
# switch to eval mode
254+
if train_state:
255+
self.model.eval()
256+
257+
# run func in eval mode
258+
func()
259+
260+
# restore state
261+
if train_state:
262+
self.model.train()
263+
264+
return function_with_eval_state

0 commit comments

Comments
 (0)