
Commit ca39b4b

add OptimizerList for multiple optimizer case (#348)

* add OptimizerList for multiple optimizer case
* refine docstring and type-hint
* refine solver
* remove export code, for it is not supported yet
* remove redundant f-string prefix
* refine docstring and add OptimizerList to api doc
* refine code with flake8

1 parent 4fa8897 commit ca39b4b
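
The Examples block added to the new class (see ppsci/optimizer/optimizer.py below) shows the intended usage; the following is a minimal sketch copied from that docstring rather than additional API:

import ppsci

# Two independent models, each with its own Adam optimizer
# (constructor signatures follow the Examples block added in this commit).
model1 = ppsci.arch.MLP(("x",), ("u",), 5, 20)
opt1 = ppsci.optimizer.Adam(1e-3)((model1,))
model2 = ppsci.arch.MLP(("y",), ("v",), 5, 20)
opt2 = ppsci.optimizer.Adam(1e-3)((model2,))

# Wrap both optimizers so they can be driven as a single object.
opt = ppsci.optimizer.OptimizerList((opt1, opt2))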

File tree

9 files changed (+157, -44 lines)

docs/zh/api/optimizer.md

Lines changed: 1 addition & 0 deletions

@@ -10,5 +10,6 @@
 - AdamW
 - RMSProp
 - LBFGS
+- OptimizerList
 show_root_heading: false
 heading_level: 3

ppsci/geometry/mesh.py

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ def __init__(self, mesh: Union["pymesh.Mesh", str]):
         elif isinstance(mesh, pymesh.Mesh):
             self.py_mesh = mesh
         else:
-            raise ValueError(f"arg `mesh` should be path string or or `pymesh.Mesh`")
+            raise ValueError("arg `mesh` should be path string or or `pymesh.Mesh`")

         self.init_mesh()

ppsci/optimizer/__init__.py

Lines changed: 11 additions & 1 deletion

@@ -20,9 +20,19 @@
 from ppsci.optimizer.optimizer import Adam
 from ppsci.optimizer.optimizer import AdamW
 from ppsci.optimizer.optimizer import Momentum
+from ppsci.optimizer.optimizer import OptimizerList
 from ppsci.optimizer.optimizer import RMSProp

-__all__ = ["LBFGS", "SGD", "Adam", "AdamW", "Momentum", "RMSProp", "lr_scheduler"]
+__all__ = [
+    "LBFGS",
+    "SGD",
+    "Adam",
+    "AdamW",
+    "Momentum",
+    "RMSProp",
+    "OptimizerList",
+    "lr_scheduler",
+]


 def build_lr_scheduler(cfg, epochs, iters_per_epoch):

ppsci/optimizer/optimizer.py

Lines changed: 60 additions & 1 deletion

@@ -12,19 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import TYPE_CHECKING
+from typing import Dict
+from typing import List
 from typing import Optional
 from typing import Tuple
 from typing import Union

+if TYPE_CHECKING:
+    import paddle
+
 from paddle import nn
 from paddle import optimizer as optim
 from paddle import regularizer
 from paddle.incubate import optimizer as incubate_optim
 from typing_extensions import Literal

 from ppsci.utils import logger
+from ppsci.utils import misc

-__all__ = ["SGD", "Momentum", "Adam", "RMSProp", "AdamW", "LBFGS"]
+__all__ = ["SGD", "Momentum", "Adam", "RMSProp", "AdamW", "LBFGS", "OptimizerList"]


 class SGD:
@@ -452,3 +459,55 @@ def __call__(self, model_list: Tuple[nn.Layer, ...]):

     def _apply_decay_param_fun(self, name):
         return name not in self.no_weight_decay_param_name_list
+
+
+class OptimizerList:
+    """OptimizerList which wrap more than one optimizer.
+    NOTE: LBFGS is not supported yet.
+
+    Args:
+        optimizer_list (Tuple[optim.Optimizer, ...]): Optimizers listed in a tuple.
+
+    Examples:
+        >>> import ppsci
+        >>> model1 = ppsci.arch.MLP(("x",), ("u",), 5, 20)
+        >>> opt1 = ppsci.optimizer.Adam(1e-3)((model1,))
+        >>> model2 = ppsci.arch.MLP(("y",), ("v",), 5, 20)
+        >>> opt2 = ppsci.optimizer.Adam(1e-3)((model2,))
+        >>> opt = ppsci.optimizer.OptimizerList((opt1, opt2))
+    """
+
+    def __init__(self, optimizer_list: Tuple[optim.Optimizer, ...]):
+        super().__init__()
+        self._opt_list = optimizer_list
+        if "LBFGS" in set(misc.typename(opt) for opt in optimizer_list):
+            raise ValueError("LBFGS is not supported in OptimizerList yet.")
+
+    def step(self):
+        for opt in self._opt_list:
+            opt.step()
+
+    def clear_grad(self):
+        for opt in self._opt_list:
+            opt.clear_grad()
+
+    def get_lr(self) -> float:
+        """Return learning rate of first optimizer"""
+        return self._opt_list[0].get_lr()
+
+    def set_state_dict(self, state_dicts: List[Dict[str, "paddle.Tensor"]]):
+        for i, opt in enumerate(self._opt_list):
+            opt.set_state_dict(state_dicts[i])
+
+    def state_dict(self) -> List[Dict[str, "paddle.Tensor"]]:
+        state_dicts = [opt.state_dict() for opt in self._opt_list]
+        return state_dicts
+
+    def __len__(self) -> int:
+        return len(self._opt_list)
+
+    def __getitem__(self, idx):
+        return self._opt_list[idx]
+
+    def __setitem__(self, idx, opt):
+        raise NotImplementedError("Can not modify any item in OptimizerList.")
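
In a hand-written training step the wrapper behaves like a single optimizer; the sketch below only illustrates the delegation implemented above, and the loss callable is a placeholder rather than part of this commit:

def train_step(opt, loss_fn):
    """Run one optimization step through an OptimizerList.

    `loss_fn` is a hypothetical callable returning a scalar paddle.Tensor.
    """
    loss = loss_fn()
    loss.backward()   # gradients for the parameters of every wrapped model
    opt.step()        # forwarded to each wrapped optimizer in turn
    opt.clear_grad()  # likewise clears gradients on each wrapped optimizer
    return float(loss)

Checkpointing follows the same pattern: state_dict() returns a list with one entry per wrapped optimizer, and set_state_dict() expects a list of the same length.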

ppsci/solver/eval.py

Lines changed: 7 additions & 3 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.

 import time
+from typing import TYPE_CHECKING

 import paddle
 from paddle import io
@@ -21,8 +22,11 @@
 from ppsci.utils import misc
 from ppsci.utils import profiler

+if TYPE_CHECKING:
+    from ppsci import solver

-def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
+
+def _eval_by_dataset(solver: "solver.Solver", epoch_id: int, log_freq: int) -> float:
     """Evaluate with computing metric on total samples.

     Args:
@@ -148,7 +152,7 @@ def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
     return target_metric


-def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
+def _eval_by_batch(solver: "solver.Solver", epoch_id: int, log_freq: int) -> float:
     """Evaluate with computing metric by batch, which is memory-efficient.

     Args:
@@ -249,7 +253,7 @@ def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
     return target_metric


-def eval_func(solver, epoch_id: int, log_freq: int) -> float:
+def eval_func(solver: "solver.Solver", epoch_id: int, log_freq: int) -> float:
     """Evaluation function.

     Args:
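
The `if TYPE_CHECKING:` guard used here (and in ppsci/solver/train.py and ppsci/utils/expression.py below) is the standard way to annotate with a type whose import would be circular or is only needed by static checkers. A generic sketch of the pattern, with illustrative module names that are not part of this commit:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright, IDEs);
    # skipped at runtime, so no circular import is triggered.
    from mypkg import solver


def evaluate(slv: "solver.Solver") -> float:
    # The quoted annotation is a forward reference: it is kept as a string
    # and never resolved while the program runs.
    return 0.0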

ppsci/solver/solver.py

Lines changed: 20 additions & 23 deletions

@@ -15,7 +15,6 @@
 from __future__ import annotations

 import contextlib
-import copy
 import itertools
 import os
 import sys
@@ -51,7 +50,7 @@ class Solver:
     Args:
         model (nn.Layer): Model.
         constraint (Optional[Dict[str, ppsci.constraint.Constraint]]): Constraint(s) applied on model. Defaults to None.
-        output_dir (str, optional): Output directory. Defaults to "./output/".
+        output_dir (Optional[str]): Output directory. Defaults to "./output/".
         optimizer (Optional[optimizer.Optimizer]): Optimizer object. Defaults to None.
         lr_scheduler (Optional[optimizer.lr.LRScheduler]): Learning rate scheduler. Defaults to None.
         epochs (int, optional): Training epoch(s). Defaults to 5.
@@ -108,7 +107,7 @@ def __init__(
         self,
         model: nn.Layer,
         constraint: Optional[Dict[str, ppsci.constraint.Constraint]] = None,
-        output_dir: str = "./output/",
+        output_dir: Optional[str] = "./output/",
         optimizer: Optional[optim.Optimizer] = None,
         lr_scheduler: Optional[optim.lr.LRScheduler] = None,
         epochs: int = 5,
@@ -371,7 +370,7 @@ def from_config(cfg: Dict[str, Any]) -> Solver:
         )

     def train(self):
-        """Training"""
+        """Training."""
         self.global_step = self.best_metric["epoch"] * self.iters_per_epoch + 1

         for epoch_id in range(self.best_metric["epoch"] + 1, self.epochs + 1):
@@ -446,8 +445,15 @@ def train(self):
             self.vdl_writer.close()

     @misc.run_on_eval_mode
-    def eval(self, epoch_id: int = 0):
-        """Evaluation"""
+    def eval(self, epoch_id: int = 0) -> float:
+        """Evaluation.
+
+        Args:
+            epoch_id (int, optional): Epoch id. Defaults to 0.
+
+        Returns:
+            float: The value of the evaluation, used to judge the quality of the model.
+        """
         # set eval func
         self.eval_func = ppsci.solver.eval.eval_func

@@ -462,8 +468,12 @@ def eval(self, epoch_id: int = 0):

     @misc.run_on_eval_mode
     def visualize(self, epoch_id: int = 0):
-        """Visualization"""
-        # init train func
+        """Visualization.
+
+        Args:
+            epoch_id (int, optional): Epoch id. Defaults to 0.
+        """
+        # set visualize func
         self.visu_func = ppsci.solver.visu.visualize_func

         self.visu_func(self, epoch_id)
@@ -568,21 +578,8 @@ def predict(

     @misc.run_on_eval_mode
     def export(self):
-        """Export to inference model"""
-        pretrained_path = self.cfg["Global"]["pretrained_model"]
-        if pretrained_path is not None:
-            save_load.load_pretrain(self.model, pretrained_path, self.equation)
-
-        self.model.eval()
-
-        input_spec = copy.deepcopy(self.cfg["Export"]["input_shape"])
-        config.replace_shape_with_inputspec_(input_spec)
-        static_model = jit.to_static(self.model, input_spec=input_spec)
-
-        export_dir = self.cfg["Global"]["save_inference_dir"]
-        save_path = os.path.join(export_dir, "inference")
-        jit.save(static_model, save_path)
-        logger.info(f"The inference model has been exported to {export_dir}")
+        """Export to inference model."""
+        raise NotImplementedError("model export is not supported yet.")

     def autocast_context_manager(
         self, enable: bool, level: Literal["O0", "O1", "O2"] = "O1"

ppsci/solver/train.py

Lines changed: 4 additions & 1 deletion

@@ -13,10 +13,13 @@
 # limitations under the License.

 import time
+from typing import TYPE_CHECKING

 from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu

-from ppsci import solver
+if TYPE_CHECKING:
+    from ppsci import solver
+
 from ppsci.solver import printer
 from ppsci.utils import misc
 from ppsci.utils import profiler

ppsci/utils/expression.py

Lines changed: 53 additions & 13 deletions

@@ -18,19 +18,20 @@
 from typing import Optional
 from typing import Tuple

-import paddle
 from paddle import jit
 from paddle import nn

 if TYPE_CHECKING:
+    import paddle
     from ppsci import constraint
     from ppsci import validate

 from ppsci.autodiff import clear


 class ExpressionSolver(nn.Layer):
-    """Expression Solver
+    """Expression computing helper, which compute named result according to corresponding
+    function and related inputs.

     Examples:
         >>> import ppsci
@@ -45,12 +46,26 @@ def __init__(self):
     def train_forward(
         self,
         expr_dicts: Tuple[Dict[str, Callable], ...],
-        input_dicts: Tuple[Dict[str, paddle.Tensor], ...],
+        input_dicts: Tuple[Dict[str, "paddle.Tensor"], ...],
         model: nn.Layer,
         constraint: Dict[str, "constraint.Constraint"],
-        label_dicts: Tuple[Dict[str, paddle.Tensor], ...],
-        weight_dicts: Tuple[Dict[str, paddle.Tensor], ...],
-    ):
+        label_dicts: Tuple[Dict[str, "paddle.Tensor"], ...],
+        weight_dicts: Tuple[Dict[str, "paddle.Tensor"], ...],
+    ) -> Tuple["paddle.Tensor", ...]:
+        """Forward computation for training, including model forward and equation
+        forward.
+
+        Args:
+            expr_dicts (Tuple[Dict[str, Callable], ...]): Tuple of expression dicts.
+            input_dicts (Tuple[Dict[str, paddle.Tensor], ...]): Tuple of input dicts.
+            model (nn.Layer): NN model.
+            constraint (Dict[str, "constraint.Constraint"]): Constraint dict.
+            label_dicts (Tuple[Dict[str, paddle.Tensor], ...]): Tuple of label dicts.
+            weight_dicts (Tuple[Dict[str, paddle.Tensor], ...]): Tuple of weight dicts.
+
+        Returns:
+            Tuple[paddle.Tensor, ...]: Tuple of losses for each constraint.
+        """
         output_dicts = []
         for i, expr_dict in enumerate(expr_dicts):
             # model forward
@@ -90,12 +105,27 @@ def train_forward(
     def eval_forward(
         self,
         expr_dict: Dict[str, Callable],
-        input_dict: Dict[str, paddle.Tensor],
+        input_dict: Dict[str, "paddle.Tensor"],
         model: nn.Layer,
         validator: "validate.Validator",
-        label_dict: Dict[str, paddle.Tensor],
-        weight_dict: Dict[str, paddle.Tensor],
-    ):
+        label_dict: Dict[str, "paddle.Tensor"],
+        weight_dict: Dict[str, "paddle.Tensor"],
+    ) -> Tuple[Dict[str, "paddle.Tensor"], "paddle.Tensor"]:
+        """Forward computation for evaluation, including model forward and equation
+        forward.
+
+        Args:
+            expr_dict (Dict[str, Callable]): Expression dict.
+            input_dict (Dict[str, paddle.Tensor]): Input dict.
+            model (nn.Layer): NN model.
+            validator (validate.Validator): Validator.
+            label_dict (Dict[str, paddle.Tensor]): Label dict.
+            weight_dict (Dict[str, paddle.Tensor]): Weight dict.
+
+        Returns:
+            Tuple[Dict[str, paddle.Tensor], paddle.Tensor]: Result dict and loss for
+                given validator.
+        """
         # model forward
         if callable(next(iter(expr_dict.values()))):
             output_dict = model(input_dict)
@@ -123,9 +153,20 @@ def eval_forward(
     def visu_forward(
         self,
         expr_dict: Optional[Dict[str, Callable]],
-        input_dict: Dict[str, paddle.Tensor],
+        input_dict: Dict[str, "paddle.Tensor"],
         model: nn.Layer,
-    ):
+    ) -> Dict[str, "paddle.Tensor"]:
+        """Forward computation for visualization, including model forward and equation
+        forward.
+
+        Args:
+            expr_dict (Optional[Dict[str, Callable]]): Expression dict.
+            input_dict (Dict[str, paddle.Tensor]): Input dict.
+            model (nn.Layer): NN model.
+
+        Returns:
+            Dict[str, paddle.Tensor]: Result dict for given expression dict.
+        """
        # model forward
         output_dict = model(input_dict)

@@ -140,5 +181,4 @@ def visu_forward(
         # clear differentiation cache
         clear()

-        # compute loss for each validator according to its' own output, label and weight
         return output_dict

ppsci/utils/reader.py

Lines changed: 0 additions & 1 deletion

@@ -18,7 +18,6 @@
 from typing import Dict
 from typing import Optional
 from typing import Tuple
-from typing import Union

 import meshio
 import numpy as np
