Skip to content

Commit 4dcb801

Browse files
optimize amp context code in train.py/eval.py
1 parent aa85dec commit 4dcb801

File tree

3 files changed

+22
-19
lines changed

3 files changed

+22
-19
lines changed

ppsci/solver/eval.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,7 @@ def eval_func(solver, epoch_id: int, log_freq: int) -> float:
6767
evaluator.add_target_expr(output_formula, output_name)
6868

6969
# forward
70-
if solver.use_amp:
71-
with amp.auto_cast(level=solver.amp_level):
72-
output_dict = evaluator(input_dict)
73-
validator_loss = _validator.loss(
74-
output_dict, label_dict, weight_dict
75-
)
76-
loss_dict[f"loss({_validator.name})"] = float(validator_loss)
77-
else:
70+
with solver._autocast_context_manager():
7871
output_dict = evaluator(input_dict)
7972
validator_loss = _validator.loss(output_dict, label_dict, weight_dict)
8073
loss_dict[f"loss({_validator.name})"] = float(validator_loss)

ppsci/solver/solver.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414

1515
from __future__ import annotations
1616

17+
import contextlib
1718
import copy
1819
import os
20+
import sys
1921
from typing import Any
2022
from typing import Dict
2123
from typing import Optional
24+
from typing import Union
2225

2326
import paddle
2427
import paddle.amp as amp
@@ -501,3 +504,20 @@ def export(self):
501504
save_path = os.path.join(export_dir, "inference")
502505
paddle.jit.save(static_model, save_path)
503506
logger.info(f"The inference model has been exported to {export_dir}.")
507+
508+
def _autocast_context_manager(self) -> contextlib.AbstractContextManager:
509+
"""Autocast context manager for Auto Mix Precision.
510+
511+
Returns:
512+
Union[contextlib.AbstractContextManager]: Context manager.
513+
"""
514+
if self.use_amp:
515+
ctx_manager = amp.auto_cast(level=self.amp_level)
516+
else:
517+
ctx_manager = (
518+
contextlib.nullcontext()
519+
if sys.version_info >= (3, 7)
520+
else contextlib.suppress()
521+
)
522+
523+
return ctx_manager

ppsci/solver/train.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414

1515
import time
1616

17-
import paddle.amp as amp
18-
19-
from ppsci import solver
2017
from ppsci.solver import printer
2118
from ppsci.utils import expression
2219
from ppsci.utils import misc
@@ -62,14 +59,7 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
6259
evaluator.add_target_expr(output_formula, output_name)
6360

6461
# forward for every constraint
65-
if solver.use_amp:
66-
with amp.auto_cast(level=solver.amp_level):
67-
output_dict = evaluator(input_dict)
68-
constraint_loss = _constraint.loss(
69-
output_dict, label_dict, weight_dict
70-
)
71-
total_loss += constraint_loss
72-
else:
62+
with solver._autocast_context_manager():
7363
output_dict = evaluator(input_dict)
7464
constraint_loss = _constraint.loss(output_dict, label_dict, weight_dict)
7565
total_loss += constraint_loss

0 commit comments

Comments
 (0)