Commit 02d5fd2

optimize code
1 parent c4b345a commit 02d5fd2

3 files changed: +30 −24

ppsci/solver/solver.py

Lines changed: 8 additions & 2 deletions
@@ -29,6 +29,7 @@
 import visualdl as vdl
 from packaging import version
 from paddle import amp
+from paddle import jit
 from paddle import nn
 from paddle import optimizer as optim
 from paddle.distributed import fleet
@@ -218,6 +219,11 @@ def __init__(
         # choosing an appropriate training function for different optimizers
         if isinstance(self.optimizer, optim.LBFGS):
             self.train_epoch_func = ppsci.solver.train.train_LBFGS_epoch_func
+            if self.update_freq != 1:
+                logger.warning(
+                    f"Set update_freq from {self.update_freq} to 1 when using L-BFGS optimizer."
+                )
+                self.update_freq = 1
         else:
             self.train_epoch_func = ppsci.solver.train.train_epoch_func
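Why this guard exists: L-BFGS evaluates its loss closure, possibly several times for the line search, inside a single step() call, so deferring parameter updates across iterations via update_freq > 1 has no meaning for it. Below is a minimal sketch of the same check lifted out of the Solver class; resolve_update_freq is a hypothetical helper, not part of ppsci:

import logging

from paddle import optimizer as optim

logger = logging.getLogger(__name__)


def resolve_update_freq(opt, update_freq: int) -> int:
    """Return an update frequency compatible with the given optimizer."""
    if isinstance(opt, optim.LBFGS) and update_freq != 1:
        # log the old value before overwriting it
        logger.warning(f"Set update_freq from {update_freq} to 1 when using L-BFGS optimizer.")
        return 1
    return update_freq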

@@ -511,11 +517,11 @@ def export(self):

         input_spec = copy.deepcopy(self.cfg["Export"]["input_shape"])
         config.replace_shape_with_inputspec_(input_spec)
-        static_model = paddle.jit.to_static(self.model, input_spec=input_spec)
+        static_model = jit.to_static(self.model, input_spec=input_spec)

         export_dir = self.cfg["Global"]["save_inference_dir"]
         save_path = os.path.join(export_dir, "inference")
-        paddle.jit.save(static_model, save_path)
+        jit.save(static_model, save_path)
         logger.info(f"The inference model has been exported to {export_dir}")

     def autocast_context_manager(self) -> contextlib.AbstractContextManager:
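The two renamed calls are Paddle's standard dynamic-to-static export pair; only the import style changed. A self-contained sketch of the same flow, assuming a plain nn.Layer and an invented [None, 3] input shape (the real shape comes from cfg["Export"]["input_shape"]):

import os

from paddle import jit, nn
from paddle.static import InputSpec

model = nn.Linear(3, 1)
# illustrative spec; None marks a variable batch dimension
input_spec = [InputSpec(shape=[None, 3], dtype="float32", name="x")]

static_model = jit.to_static(model, input_spec=input_spec)  # trace to a static graph
save_path = os.path.join("./inference_dir", "inference")
jit.save(static_model, save_path)  # writes inference.pdmodel / inference.pdiparams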

ppsci/solver/train.py

Lines changed: 21 additions & 21 deletions
@@ -34,7 +34,7 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
     total_loss = 0
     loss_dict = misc.Prettydefaultdict(float)
     loss_dict["loss"] = 0.0
-    total_batch_size = []
+    total_batch_size = 0
     reader_cost = 0
     batch_cost = 0
     reader_tic = time.perf_counter()
@@ -48,7 +48,7 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
         for key in solver.train_time_info:
             solver.train_time_info[key].reset()
         reader_cost += time.perf_counter() - reader_tic
-        total_batch_size.append(next(iter(input_dict.values())).shape[0])
+        total_batch_size += next(iter(input_dict.values())).shape[0]

         for v in input_dict.values():
             v.stop_gradient = False
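The total_batch_size change is a micro-simplification: keep a running integer instead of collecting per-constraint sizes into a list and summing it once at the end (the sum(...) line removed further down). In miniature:

# before: build a list, reduce it later
sizes = []
for n in (32, 32, 16):
    sizes.append(n)
total = sum(sizes)  # 80, via an intermediate list

# after: same total, no list to build and reduce
total = 0
for n in (32, 32, 16):
    total += n  # 80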
@@ -80,21 +80,20 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
             if iter_id % solver.update_freq == 0:
                 solver.scaler.minimize(solver.optimizer, total_loss_scaled)
                 solver.optimizer.clear_grad()
-                if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
-                    solver.lr_scheduler.step()
         else:
             total_loss.backward()
             if iter_id % solver.update_freq == 0:
                 solver.optimizer.step()
                 solver.optimizer.clear_grad()
-                if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
-                    solver.lr_scheduler.step()
+
+        # update learning rate by step
+        if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
+            solver.lr_scheduler.step()

         batch_cost += time.perf_counter() - batch_tic

         # update and log training information
         solver.global_step += 1
-        total_batch_size = sum(total_batch_size)
         solver.train_time_info["reader_cost"].update(reader_cost)
         solver.train_time_info["batch_cost"].update(batch_cost)
         printer.update_train_loss(solver, loss_dict, total_batch_size)
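The per-step scheduler call was previously duplicated inside both the AMP (scaler.minimize) and non-AMP (optimizer.step) branches, and only ran on accumulation boundaries; hoisting it makes a step-based scheduler advance exactly once per iteration on either path. A runnable sketch of the resulting loop shape with stock Paddle pieces (ppsci schedulers carry the by_epoch flag checked above; a plain StepDecay stands in here):

import paddle
from paddle import nn
from paddle import optimizer as optim

model = nn.Linear(2, 1)
sched = optim.lr.StepDecay(learning_rate=1e-3, step_size=10, gamma=0.5)
opt = optim.SGD(learning_rate=sched, parameters=model.parameters())
update_freq = 2  # accumulate gradients over two iterations

for iter_id in range(1, 9):
    loss = model(paddle.randn([4, 2])).mean()
    loss.backward()
    if iter_id % update_freq == 0:
        # parameters move only at the accumulation boundary
        opt.step()
        opt.clear_grad()
    # hoisted: one scheduler call per iteration, shared by AMP and non-AMP paths
    sched.step()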
@@ -117,23 +116,26 @@ def train_LBFGS_epoch_func(solver, epoch_id: int, log_freq: int):
     for iter_id in range(1, solver.iters_per_epoch + 1):
         loss_dict = misc.Prettydefaultdict(float)
         loss_dict["loss"] = 0.0
-        total_batch_size = []
+        total_batch_size = 0
         reader_cost = 0
         batch_cost = 0
         reader_tic = time.perf_counter()

-        input_dict_list = []
-        label_dict_list = []
-        weight_dict_list = []
+        input_dicts = []
+        label_dicts = []
+        weight_dicts = []
         for _, _constraint in solver.constraint.items():
             input_dict, label_dict, weight_dict = next(_constraint.data_iter)
             reader_cost += time.perf_counter() - reader_tic
             for v in input_dict.values():
                 v.stop_gradient = False
-            input_dict_list.append(input_dict)
-            label_dict_list.append(label_dict)
-            weight_dict_list.append(weight_dict)
-            total_batch_size.append(next(iter(input_dict.values())).shape[0])
+
+            # gather all constraint data into lists
+            input_dicts.append(input_dict)
+            label_dicts.append(label_dict)
+            weight_dicts.append(weight_dict)
+            total_batch_size += next(iter(input_dict.values())).shape[0]
+            reader_tic = time.perf_counter()

         def closure():
             """Forward-backward closure function for LBFGS optimizer.
@@ -147,13 +149,14 @@ def closure():
                     _constraint.input_keys, _constraint.output_keys, solver.model
                 )
                 for output_name, output_formula in _constraint.output_expr.items():
-                    evaluator.add_target_expr(output_formula, output_name)
+                    if output_name in label_dict:
+                        evaluator.add_target_expr(output_formula, output_name)

                 # forward for every batched data dict
                 with solver.autocast_context_manager():
-                    output_dict_i = evaluator(input_dict_list[i])
+                    output_dict_i = evaluator(input_dicts[i])
                     constraint_loss = _constraint.loss(
-                        output_dict_i, label_dict_list[i], weight_dict_list[i]
+                        output_dict_i, label_dicts[i], weight_dicts[i]
                     )
                     total_loss += constraint_loss
@@ -165,8 +168,6 @@ def closure():

             return total_loss

-        reader_tic = time.perf_counter()
-
         solver.optimizer.step(closure)
         if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
             solver.lr_scheduler.step()
@@ -175,7 +176,6 @@ def closure():

         # update and log training information
         solver.global_step += 1
-        total_batch_size = sum(total_batch_size)
         solver.train_time_info["reader_cost"].update(reader_cost)
         solver.train_time_info["batch_cost"].update(batch_cost)
         printer.update_train_loss(solver, loss_dict, total_batch_size)
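For reference, the closure pattern train_LBFGS_epoch_func is built around, reduced to a toy regression: L-BFGS may re-invoke the closure several times within one step() call for its line search, which is also why update_freq is pinned to 1 in solver.py above. The shapes and quadratic loss here are invented:

import paddle
from paddle import nn
from paddle import optimizer as optim

paddle.seed(42)
model = nn.Linear(2, 1)
opt = optim.LBFGS(learning_rate=1.0, max_iter=20, parameters=model.parameters())
x = paddle.randn([16, 2])
y = paddle.randn([16, 1])


def closure():
    # L-BFGS calls this repeatedly, so clear stale gradients each time
    opt.clear_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    return loss


for _ in range(5):
    opt.step(closure)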

ppsci/utils/misc.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def combine_array_with_time(x, t):
     return tx


-def set_random_seed(seed):
+def set_random_seed(seed: int):
     paddle.seed(seed)
     np.random.seed(seed)
     random.seed(seed)
