Skip to content

Commit 99a2a80

Browse files
Merge pull request #290 from HydrogenSulfate/opt_train_amp
optimize amp context code in train.py/eval.py
2 parents 543be8f + a36702b commit 99a2a80

File tree

16 files changed

+46
-52
lines changed

16 files changed

+46
-52
lines changed

ppsci/arch/activation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from typing import Callable
1616

1717
import paddle
18-
import paddle.nn as nn
1918
import paddle.nn.functional as F
19+
from paddle import nn
2020

2121
act_func_dict = {
2222
"elu": F.elu,

ppsci/arch/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import numpy as np
1919
import paddle
20-
import paddle.nn as nn
20+
from paddle import nn
2121

2222
from ppsci.utils import logger
2323

ppsci/arch/embedding_koopman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import numpy as np
2323
import paddle
24-
import paddle.nn as nn
24+
from paddle import nn
2525
from paddle.nn.initializer import Constant
2626
from paddle.nn.initializer import Uniform
2727

ppsci/arch/mlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing import Tuple
1616
from typing import Union
1717

18-
import paddle.nn as nn
18+
from paddle import nn
1919

2020
from ppsci.arch import activation as act_mod
2121
from ppsci.arch import base

ppsci/data/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from functools import partial
1818

1919
import numpy as np
20-
import paddle.device as device
2120
import paddle.distributed as dist
22-
import paddle.io as io
21+
from paddle import device
22+
from paddle import io
2323

2424
from ppsci.data import dataloader
2525
from ppsci.data import dataset

ppsci/equation/pde/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from typing import Tuple
1919

2020
import paddle
21-
import paddle.nn as nn
2221
import sympy
22+
from paddle import nn
2323

2424

2525
class PDE:

ppsci/loss/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import paddle.nn as nn
15+
from paddle import nn
1616

1717

1818
class Loss(nn.Layer):

ppsci/metric/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import paddle.nn as nn
15+
from paddle import nn
1616

1717

1818
class Metric(nn.Layer):

ppsci/solver/eval.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
import time
1616

1717
import paddle
18-
import paddle.amp as amp
19-
import paddle.io as io
18+
from paddle import io
2019

2120
from ppsci.solver import printer
2221
from ppsci.utils import expression
@@ -67,14 +66,7 @@ def eval_func(solver, epoch_id: int, log_freq: int) -> float:
6766
evaluator.add_target_expr(output_formula, output_name)
6867

6968
# forward
70-
if solver.use_amp:
71-
with amp.auto_cast(level=solver.amp_level):
72-
output_dict = evaluator(input_dict)
73-
validator_loss = _validator.loss(
74-
output_dict, label_dict, weight_dict
75-
)
76-
loss_dict[f"loss({_validator.name})"] = float(validator_loss)
77-
else:
69+
with solver._autocast_context_manager():
7870
output_dict = evaluator(input_dict)
7971
validator_loss = _validator.loss(output_dict, label_dict, weight_dict)
8072
loss_dict[f"loss({_validator.name})"] = float(validator_loss)

ppsci/solver/solver.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,22 @@
1414

1515
from __future__ import annotations
1616

17+
import contextlib
1718
import copy
1819
import os
20+
import sys
1921
from typing import Any
2022
from typing import Dict
2123
from typing import Optional
2224

2325
import paddle
24-
import paddle.amp as amp
2526
import paddle.distributed as dist
26-
import paddle.incubate as incubate
27-
import paddle.nn as nn
28-
import paddle.optimizer as optimizer
2927
import visualdl as vdl
3028
from packaging import version
29+
from paddle import amp
30+
from paddle import incubate
31+
from paddle import nn
32+
from paddle import optimizer
3133
from paddle.distributed import fleet
3234
from typing_extensions import Literal
3335

@@ -57,7 +59,7 @@ class Solver:
5759
eval_freq (int, optional): Evaluation frequency. Defaults to 1.
5860
seed (int, optional): Random seed. Defaults to 42.
5961
vdl_writer (Optional[vdl.LogWriter]): VisualDL writer object. Defaults to None.
60-
device (Literal["cpu", "gpu", "xpu"], optional): _description_. Defaults to "gpu".
62+
device (Literal["cpu", "gpu", "xpu"], optional): Runtime device. Defaults to "gpu".
6163
equation (Optional[Dict[str, ppsci.equation.PDE]]): Equation dict. Defaults to None.
6264
geom (Optional[Dict[str, ppsci.geometry.Geometry]]): Geometry dict. Defaults to None.
6365
validator (Optional[Dict[str, ppsci.validate.Validator]]): Validator dict. Defaults to None.
@@ -469,10 +471,7 @@ def predict(
469471
batch_input_dict[key].stop_gradient = False
470472

471473
# forward
472-
if self.use_amp:
473-
with amp.auto_cast(level=self.amp_level):
474-
batch_output_dict = self.model(batch_input_dict)
475-
else:
474+
with self._autocast_context_manager():
476475
batch_output_dict = self.model(batch_input_dict)
477476

478477
# collect batch data
@@ -501,3 +500,20 @@ def export(self):
501500
save_path = os.path.join(export_dir, "inference")
502501
paddle.jit.save(static_model, save_path)
503502
logger.info(f"The inference model has been exported to {export_dir}.")
503+
504+
def _autocast_context_manager(self) -> contextlib.AbstractContextManager:
505+
"""Autocast context manager for Auto Mix Precision.
506+
507+
Returns:
508+
Union[contextlib.AbstractContextManager]: Context manager.
509+
"""
510+
if self.use_amp:
511+
ctx_manager = amp.auto_cast(level=self.amp_level)
512+
else:
513+
ctx_manager = (
514+
contextlib.nullcontext()
515+
if sys.version_info >= (3, 7)
516+
else contextlib.suppress()
517+
)
518+
519+
return ctx_manager

0 commit comments

Comments
 (0)