Commit d0a78a1

Merge pull request #302 from HydrogenSulfate/add_LBFGS_train
correct LBFGS code
2 parents 508086c + 3f8bfeb commit d0a78a1

4 files changed: +36, -30 lines

ppsci/solver/eval.py (2 additions, 2 deletions)

@@ -66,7 +66,7 @@ def eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
             evaluator.add_target_expr(output_formula, output_name)

         # forward
-        with solver._autocast_context_manager(), solver._no_grad_context_manager():
+        with solver.autocast_context_manager(), solver.no_grad_context_manager():
             output_dict = evaluator(input_dict)
             validator_loss = _validator.loss(output_dict, label_dict, weight_dict)

@@ -189,7 +189,7 @@ def eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
             evaluator.add_target_expr(output_formula, output_name)

         # forward
-        with solver._autocast_context_manager(), solver._no_grad_context_manager():
+        with solver.autocast_context_manager(), solver.no_grad_context_manager():
             output_dict = evaluator(input_dict)
             validator_loss = _validator.loss(output_dict, label_dict, weight_dict)
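Note: both evaluation paths now stack the two renamed (made-public) context managers. A minimal sketch of what such helpers could look like, assuming AMP level "O1" and Paddle's paddle.amp.auto_cast / paddle.no_grad; the helper signatures here are hypothetical, not ppsci's actual implementation:

import contextlib

import paddle

def autocast_context_manager(use_amp: bool, amp_level: str = "O1"):
    # Mixed-precision casting when AMP is enabled, otherwise a no-op context.
    return paddle.amp.auto_cast(level=amp_level) if use_amp else contextlib.nullcontext()

def no_grad_context_manager(eval_with_no_grad: bool):
    # Skip autograd bookkeeping during evaluation when requested.
    return paddle.no_grad() if eval_with_no_grad else contextlib.nullcontext()

# Usage mirroring the call sites above:
with autocast_context_manager(False), no_grad_context_manager(True):
    pass  # output_dict = evaluator(input_dict)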

ppsci/solver/solver.py (15 additions, 12 deletions)

@@ -21,15 +21,16 @@
 from typing import Any
 from typing import Dict
 from typing import Optional
+from typing import Union

+import numpy as np
 import paddle
 import paddle.distributed as dist
 import visualdl as vdl
 from packaging import version
 from paddle import amp
-from paddle import incubate
 from paddle import nn
-from paddle import optimizer
+from paddle import optimizer as optim
 from paddle.distributed import fleet
 from typing_extensions import Literal

@@ -103,8 +104,8 @@ def __init__(
         model: nn.Layer,
         constraint: Optional[Dict[str, ppsci.constraint.Constraint]] = None,
         output_dir: str = "./output/",
-        optimizer: Optional[optimizer.Optimizer] = None,
-        lr_scheduler: Optional[optimizer.lr.LRScheduler] = None,
+        optimizer: Optional[optim.Optimizer] = None,
+        lr_scheduler: Optional[optim.lr.LRScheduler] = None,
         epochs: int = 5,
         iters_per_epoch: int = 20,
         update_freq: int = 1,
@@ -215,10 +216,10 @@ def __init__(
             self.best_metric.update(loaded_metric)

         # choosing an appropriate training function for different optimizers
-        if not isinstance(self.optimizer, incubate.optimizer.LBFGS):
-            self.train_epoch_func = ppsci.solver.train.train_epoch_func
-        else:
+        if isinstance(self.optimizer, optim.LBFGS):
             self.train_epoch_func = ppsci.solver.train.train_LBFGS_epoch_func
+        else:
+            self.train_epoch_func = ppsci.solver.train.train_epoch_func

         # decorate model(s) and optimizer(s) for AMP
         if self.use_amp:
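Note: the dispatch now tests against the stable paddle.optimizer.LBFGS instead of paddle.incubate.optimizer.LBFGS. A standalone sketch of the selection logic, assuming Paddle >= 2.5 (where LBFGS graduated out of incubate); the returned names stand in for the ppsci train functions:

import paddle
from paddle import optimizer as optim

def pick_train_epoch_func(opt):
    # LBFGS needs a closure-driven training loop; everything else
    # uses the plain forward/backward/step loop.
    if isinstance(opt, optim.LBFGS):
        return "train_LBFGS_epoch_func"
    return "train_epoch_func"

model = paddle.nn.Linear(1, 1)
print(pick_train_epoch_func(optim.LBFGS(parameters=model.parameters())))  # LBFGS branch
print(pick_train_epoch_func(optim.Adam(parameters=model.parameters())))   # default branch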
@@ -445,12 +446,14 @@ def visualize(self, epoch_id: int = 0):

     @paddle.no_grad()
     def predict(
-        self, input_dict: Dict[str, paddle.Tensor], batch_size: int = 64
+        self,
+        input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
+        batch_size: int = 64,
     ) -> Dict[str, paddle.Tensor]:
         """Pure prediction using model.forward(...), support single device prediction yet.

         Args:
-            input_dict (Dict[str, paddle.Tensor]): Input data in dict.
+            input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Input data in dict.
             batch_size (int, optional): Predicting by batch size. Defaults to 64.

         Returns:

@@ -485,7 +488,7 @@ def predict(
                 batch_input_dict[key].stop_gradient = False

             # forward
-            with self._autocast_context_manager():
+            with self.autocast_context_manager():
                 batch_output_dict = self.model(batch_input_dict)

             # collect batch data
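Note: predict now advertises Dict[str, Union[np.ndarray, paddle.Tensor]] inputs. A hedged sketch of the kind of input normalization this implies; the helper below is hypothetical, not ppsci's code:

import numpy as np
import paddle

def to_tensor_dict(input_dict):
    # Convert numpy arrays to tensors; pass live tensors through unchanged.
    return {
        key: paddle.to_tensor(val) if isinstance(val, np.ndarray) else val
        for key, val in input_dict.items()
    }

batch_input = to_tensor_dict({"x": np.random.rand(8, 2).astype("float32")})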
@@ -515,7 +518,7 @@ def export(self):
         paddle.jit.save(static_model, save_path)
         logger.info(f"The inference model has been exported to {export_dir}.")

-    def _autocast_context_manager(self) -> contextlib.AbstractContextManager:
+    def autocast_context_manager(self) -> contextlib.AbstractContextManager:
         """Autocast context manager for Auto Mix Precision.

         Returns:

@@ -532,7 +535,7 @@ def _autocast_context_manager(self) -> contextlib.AbstractContextManager:

         return ctx_manager

-    def _no_grad_context_manager(self) -> contextlib.AbstractContextManager:
+    def no_grad_context_manager(self) -> contextlib.AbstractContextManager:
         """No grad manager.

         Returns:

ppsci/solver/train.py (18 additions, 15 deletions)

@@ -59,7 +59,7 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
             evaluator.add_target_expr(output_formula, output_name)

         # forward for every constraint
-        with solver._autocast_context_manager():
+        with solver.autocast_context_manager():
             output_dict = evaluator(input_dict)
             constraint_loss = _constraint.loss(output_dict, label_dict, weight_dict)
             total_loss += constraint_loss
@@ -114,16 +114,16 @@ def train_LBFGS_epoch_func(solver, epoch_id: int, log_freq: int):
     batch_tic = time.perf_counter()

     for iter_id in range(1, solver.iters_per_epoch + 1):
-        reader_cost = 0
-        batch_cost = 0
         loss_dict = misc.Prettydefaultdict(float)
         loss_dict["loss"] = 0.0
+        total_batch_size = []
+        reader_cost = 0
+        batch_cost = 0
+        reader_tic = time.perf_counter()
+
         input_dict_list = []
         label_dict_list = []
         weight_dict_list = []
-        batch_cost = 0
-        total_batch_size = []
-        reader_tic = time.perf_counter()
         for _, _constraint in solver.constraint.items():
             input_dict, label_dict, weight_dict = next(_constraint.data_iter)
             reader_cost += time.perf_counter() - reader_tic
@@ -133,10 +133,9 @@ def train_LBFGS_epoch_func(solver, epoch_id: int, log_freq: int):
             label_dict_list.append(label_dict)
             weight_dict_list.append(weight_dict)
             total_batch_size.append(next(iter(input_dict.values())).shape[0])
-        total_batch_size = sum(total_batch_size)

         def closure():
-            """Closure function for LBFGS optimizer.
+            """Forward-backward closure function for LBFGS optimizer.

             Returns:
                 Tensor: Computed loss.
@@ -149,21 +148,25 @@ def closure():
                 for output_name, output_formula in _constraint.output_expr.items():
                     evaluator.add_target_expr(output_formula, output_name)

-                # forward for every constraint
-                output_dict_i = evaluator(input_dict_list[i])
-                constraint_loss = _constraint.loss(
-                    output_dict_i, label_dict_list[i], weight_dict_list[i]
-                )
-                total_loss += constraint_loss
+                # forward for every batched data dict
+                with solver.autocast_context_manager():
+                    output_dict_i = evaluator(input_dict_list[i])
+                    constraint_loss = _constraint.loss(
+                        output_dict_i, label_dict_list[i], weight_dict_list[i]
+                    )
+                    total_loss += constraint_loss

                 loss_dict[_constraint.name] += float(constraint_loss)

             total_loss.backward()
+            loss_dict["loss"] = float(total_loss)

             return total_loss

+        reader_tic = time.perf_counter()
+
         solver.optimizer.step(closure)
-        if not getattr(solver.lr_scheduler, "by_epoch", False):
+        if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
             solver.lr_scheduler.step()

         batch_cost += time.perf_counter() - batch_tic
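Note: the corrected train_LBFGS_epoch_func builds on Paddle's closure API, where optimizer.step(closure) may re-evaluate the loss several times during one line search. A self-contained toy sketch of that pattern, assuming Paddle >= 2.5; the model and data are made up:

import paddle

# Toy regression problem: fit y = 3x + 1.
x = paddle.rand([64, 1])
y = 3.0 * x + 1.0
model = paddle.nn.Linear(1, 1)
opt = paddle.optimizer.LBFGS(learning_rate=1.0, parameters=model.parameters())

def closure():
    # The closure must clear stale gradients and redo forward + backward,
    # because LBFGS may call it multiple times within a single step().
    opt.clear_grad()
    loss = paddle.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

for _ in range(10):
    loss = opt.step(closure)  # the line search invokes closure as needed
print(float(loss))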

ppsci/solver/visu.py (1 addition, 1 deletion)

@@ -63,7 +63,7 @@ def visualize_func(solver, epoch_id: int):
             evaluator.add_target_expr(output_expr, output_key)

         # forward
-        with solver._autocast_context_manager():
+        with solver.autocast_context_manager():
             batch_output_dict = evaluator(batch_input_dict)

         # collect batch data
