Skip to content

Commit f404ccb

Browse files
correct LBFGS code
1 parent: 8d0ca82 · commit: f404ccb

File tree

4 files changed

+28
-26
lines changed

4 files changed

+28
-26
lines changed

ppsci/solver/eval.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ def eval_func(solver, epoch_id: int, log_freq: int) -> float:
6666
evaluator.add_target_expr(output_formula, output_name)
6767

6868
# forward
69-
with solver._autocast_context_manager():
69+
with solver.autocast_context_manager():
7070
output_dict = evaluator(input_dict)
7171
validator_loss = _validator.loss(output_dict, label_dict, weight_dict)
7272
loss_dict[f"loss({_validator.name})"] = float(validator_loss)

ppsci/solver/solver.py

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,8 @@
2727
import visualdl as vdl
2828
from packaging import version
2929
from paddle import amp
30-
from paddle import incubate
3130
from paddle import nn
32-
from paddle import optimizer
31+
from paddle import optimizer as optim
3332
from paddle.distributed import fleet
3433
from typing_extensions import Literal
3534

@@ -100,8 +99,8 @@ def __init__(
10099
model: nn.Layer,
101100
constraint: Optional[Dict[str, ppsci.constraint.Constraint]] = None,
102101
output_dir: str = "./output/",
103-
optimizer: Optional[optimizer.Optimizer] = None,
104-
lr_scheduler: Optional[optimizer.lr.LRScheduler] = None,
102+
optimizer: Optional[optim.Optimizer] = None,
103+
lr_scheduler: Optional[optim.lr.LRScheduler] = None,
105104
epochs: int = 5,
106105
iters_per_epoch: int = 20,
107106
update_freq: int = 1,
@@ -205,10 +204,10 @@ def __init__(
205204
self.best_metric.update(loaded_metric)
206205

207206
# choosing an appropriate training function for different optimizers
208-
if not isinstance(self.optimizer, incubate.optimizer.LBFGS):
209-
self.train_epoch_func = ppsci.solver.train.train_epoch_func
210-
else:
207+
if isinstance(self.optimizer, optim.LBFGS):
211208
self.train_epoch_func = ppsci.solver.train.train_LBFGS_epoch_func
209+
else:
210+
self.train_epoch_func = ppsci.solver.train.train_epoch_func
212211

213212
# decorate model(s) and optimizer(s) for AMP
214213
if self.use_amp:
@@ -471,7 +470,7 @@ def predict(
471470
batch_input_dict[key].stop_gradient = False
472471

473472
# forward
474-
with self._autocast_context_manager():
473+
with self.autocast_context_manager():
475474
batch_output_dict = self.model(batch_input_dict)
476475

477476
# collect batch data
@@ -501,7 +500,7 @@ def export(self):
501500
paddle.jit.save(static_model, save_path)
502501
logger.info(f"The inference model has been exported to {export_dir}.")
503502

504-
def _autocast_context_manager(self) -> contextlib.AbstractContextManager:
503+
def autocast_context_manager(self) -> contextlib.AbstractContextManager:
505504
"""Autocast context manager for Auto Mix Precision.
506505
507506
Returns:

ppsci/solver/train.py

Lines changed: 18 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
5959
evaluator.add_target_expr(output_formula, output_name)
6060

6161
# forward for every constraint
62-
with solver._autocast_context_manager():
62+
with solver.autocast_context_manager():
6363
output_dict = evaluator(input_dict)
6464
constraint_loss = _constraint.loss(output_dict, label_dict, weight_dict)
6565
total_loss += constraint_loss
@@ -114,16 +114,16 @@ def train_LBFGS_epoch_func(solver, epoch_id: int, log_freq: int):
114114
batch_tic = time.perf_counter()
115115

116116
for iter_id in range(1, solver.iters_per_epoch + 1):
117-
reader_cost = 0
118-
batch_cost = 0
119117
loss_dict = misc.Prettydefaultdict(float)
120118
loss_dict["loss"] = 0.0
119+
total_batch_size = []
120+
reader_cost = 0
121+
batch_cost = 0
122+
reader_tic = time.perf_counter()
123+
121124
input_dict_list = []
122125
label_dict_list = []
123126
weight_dict_list = []
124-
batch_cost = 0
125-
total_batch_size = []
126-
reader_tic = time.perf_counter()
127127
for _, _constraint in solver.constraint.items():
128128
input_dict, label_dict, weight_dict = next(_constraint.data_iter)
129129
reader_cost += time.perf_counter() - reader_tic
@@ -133,10 +133,9 @@ def train_LBFGS_epoch_func(solver, epoch_id: int, log_freq: int):
133133
label_dict_list.append(label_dict)
134134
weight_dict_list.append(weight_dict)
135135
total_batch_size.append(next(iter(input_dict.values())).shape[0])
136-
total_batch_size = sum(total_batch_size)
137136

138137
def closure():
139-
"""Closure function for LBFGS optimizer.
138+
"""Forward-backward closure function for LBFGS optimizer.
140139
141140
Returns:
142141
Tensor: Computed loss.
@@ -149,21 +148,25 @@ def closure():
149148
for output_name, output_formula in _constraint.output_expr.items():
150149
evaluator.add_target_expr(output_formula, output_name)
151150

152-
# forward for every constraint
153-
output_dict_i = evaluator(input_dict_list[i])
154-
constraint_loss = _constraint.loss(
155-
output_dict_i, label_dict_list[i], weight_dict_list[i]
156-
)
157-
total_loss += constraint_loss
151+
# forward for every batched data dict
152+
with solver.autocast_context_manager():
153+
output_dict_i = evaluator(input_dict_list[i])
154+
constraint_loss = _constraint.loss(
155+
output_dict_i, label_dict_list[i], weight_dict_list[i]
156+
)
157+
total_loss += constraint_loss
158158

159159
loss_dict[_constraint.name] += float(constraint_loss)
160160

161161
total_loss.backward()
162+
loss_dict["loss"] = float(total_loss)
162163

163164
return total_loss
164165

166+
reader_tic = time.perf_counter()
167+
165168
solver.optimizer.step(closure)
166-
if not getattr(solver.lr_scheduler, "by_epoch", False):
169+
if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
167170
solver.lr_scheduler.step()
168171

169172
batch_cost += time.perf_counter() - batch_tic

ppsci/solver/visu.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ def visualize_func(solver, epoch_id: int):
6363
evaluator.add_target_expr(output_expr, output_key)
6464

6565
# forward
66-
with solver._autocast_context_manager():
66+
with solver.autocast_context_manager():
6767
batch_output_dict = evaluator(batch_input_dict)
6868

6969
# collect batch data

0 commit comments

Comments
 (0)