
Commit f3f554c

Merge pull request #330 from HydrogenSulfate/opt_code2
optimize code
2 parents c4b345a + aa3255c commit f3f554c

15 files changed: +170 -113 lines changed

README.md

Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
 
 ## 快速开始
 
-参考 [**快速开始**](https://paddlescience-docs.readthedocs.io/zh/latest/zh/quickstart/)
+请参考 [**快速开始**](https://paddlescience-docs.readthedocs.io/zh/latest/zh/quickstart/)
 
 ## 经典案例

ppsci/metric/anomaly_coef.py

Lines changed: 4 additions & 9 deletions

@@ -63,12 +63,11 @@ def __init__(
         unlog: bool = False,
         scale: float = 1e-5,
     ):
-        super().__init__()
+        super().__init__(keep_batch)
         self.num_lat = num_lat
         self.mean = (
             None if mean is None else paddle.to_tensor(mean, paddle.get_default_dtype())
         )
-        self.keep_batch = keep_batch
         self.variable_dict = variable_dict
         self.unlog = unlog
         self.scale = scale

@@ -110,14 +109,10 @@ def forward(self, output_dict, label_dict):
                     if self.keep_batch:
                         metric_dict[f"{key}.{variable_name}"] = rmse[:, idx]
                     else:
-                        metric_dict[f"{key}.{variable_name}"] = float(
-                            rmse[:, idx].mean()
-                        )
+                        metric_dict[f"{key}.{variable_name}"] = rmse[:, idx].mean()
             else:
                 if self.keep_batch:
-                    rmse = rmse.mean(axis=1)
-                    metric_dict[key] = rmse
+                    metric_dict[key] = rmse.mean(axis=1)
                 else:
-                    rmse = rmse.mean()
-                    metric_dict[key] = float(rmse)
+                    metric_dict[key] = rmse.mean()
         return metric_dict
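
The pattern here, repeated across the metric classes below, is that `forward` now returns paddle Tensors and defers `float()` conversion to the evaluation loop, with `keep_batch` selecting between a per-sample result and a fully reduced one. A toy illustration of the two branches (shapes are made up for the example, not from the PR):

```python
import paddle

# pretend per-sample, per-variable metric values: [batch=4, num_vars=3]
rmse = paddle.rand([4, 3])

keep_batch = True
if keep_batch:
    result = rmse.mean(axis=1)  # shape [4]: one value per sample
else:
    result = rmse.mean()        # 0-D tensor: fully reduced

print(result.shape)  # [4] here; [] if keep_batch were False
```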

ppsci/metric/base.py

Lines changed: 2 additions & 1 deletion

@@ -18,5 +18,6 @@
 class Metric(nn.Layer):
     """Base class for metric."""
 
-    def __init__(self):
+    def __init__(self, keep_batch: bool = False):
         super().__init__()
+        self.keep_batch = keep_batch
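
Since `keep_batch` now lives on the base class, a subclass only needs to forward the flag instead of assigning it itself. A minimal sketch of a hypothetical subclass (the `MaxAbsError` name and its error function are invented for illustration):

```python
import paddle

from ppsci.metric import base


class MaxAbsError(base.Metric):
    """Hypothetical metric showing the new keep_batch wiring."""

    def __init__(self, keep_batch: bool = False):
        super().__init__(keep_batch)  # stored once by base.Metric

    @paddle.no_grad()
    def forward(self, output_dict, label_dict):
        metric_dict = {}
        for key in label_dict:
            err = paddle.abs(output_dict[key] - label_dict[key])
            if self.keep_batch:
                # reduce all non-batch axes, keep one value per sample
                metric_dict[key] = err.max(axis=tuple(range(1, err.ndim)))
            else:
                metric_dict[key] = err.max()
        return metric_dict
```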

ppsci/metric/l2_rel.py

Lines changed: 8 additions & 3 deletions

@@ -24,13 +24,18 @@ class L2Rel(base.Metric):
     metric = \dfrac{\Vert x-y \Vert_2}{\Vert y \Vert_2}
     $$
 
+    Args:
+        keep_batch (bool, optional): Whether keep batch axis. Defaults to False.
+
     Examples:
         >>> import ppsci
         >>> metric = ppsci.metric.L2Rel()
     """
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, keep_batch: bool = False):
+        if keep_batch:
+            raise ValueError(f"keep_batch should be False, but got {keep_batch}.")
+        super().__init__(keep_batch)
 
     @paddle.no_grad()
     def forward(self, output_dict, label_dict):

@@ -39,6 +44,6 @@ def forward(self, output_dict, label_dict):
             rel_l2 = paddle.norm(label_dict[key] - output_dict[key]) / paddle.norm(
                 label_dict[key]
             )
-            metric_dict[key] = float(rel_l2)
+            metric_dict[key] = rel_l2
 
         return metric_dict
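
`paddle.norm` with no axis argument takes the norm over the flattened tensor, batch included, so the relative L2 is inherently a single value; that is presumably why `L2Rel` rejects `keep_batch=True`. A small standalone check (toy data):

```python
import paddle

y_pred = paddle.rand([4, 3])
y_true = paddle.rand([4, 3])

# Frobenius norm over all elements; no per-sample axis survives
rel_l2 = paddle.norm(y_true - y_pred) / paddle.norm(y_true)
print(float(rel_l2))  # the eval loop now does this conversion, not the metric
```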

ppsci/metric/mae.py

Lines changed: 2 additions & 3 deletions

@@ -34,8 +34,7 @@ class MAE(base.Metric):
     """
 
     def __init__(self, keep_batch: bool = False):
-        super().__init__()
-        self.keep_batch = keep_batch
+        super().__init__(keep_batch)
 
     @paddle.no_grad()
     def forward(self, output_dict, label_dict):

@@ -45,6 +44,6 @@ def forward(self, output_dict, label_dict):
             if self.keep_batch:
                 metric_dict[key] = mae.mean(axis=tuple(range(1, mae.ndim)))
             else:
-                metric_dict[key] = float(mae.mean())
+                metric_dict[key] = mae.mean()
 
         return metric_dict
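
The `tuple(range(1, ndim))` idiom reduces every axis except axis 0, so `keep_batch=True` yields one value per sample regardless of the label's rank. For example (toy shapes):

```python
import paddle

mae = paddle.rand([4, 16, 16])  # elementwise |error| for [batch, H, W] fields
per_sample = mae.mean(axis=tuple(range(1, mae.ndim)))
print(per_sample.shape)  # [4]: H and W reduced, batch axis kept
```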

ppsci/metric/mse.py

Lines changed: 10 additions & 4 deletions

@@ -25,19 +25,25 @@ class MSE(base.Metric):
     metric = \dfrac{1}{N}\sum\limits_{i=1}^{N}{(x_i-y_i)^2}
     $$
 
+    Args:
+        keep_batch (bool, optional): Whether keep batch axis. Defaults to False.
+
     Examples:
         >>> import ppsci
         >>> metric = ppsci.metric.MSE()
     """
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, keep_batch: bool = False):
+        super().__init__(keep_batch)
 
     @paddle.no_grad()
     def forward(self, output_dict, label_dict):
         metric_dict = {}
         for key in label_dict:
-            mse = F.mse_loss(output_dict[key], label_dict[key], "mean")
-            metric_dict[key] = float(mse)
+            mse = F.mse_loss(output_dict[key], label_dict[key], "none")
+            if self.keep_batch:
+                metric_dict[key] = mse.mean(axis=tuple(range(1, mse.ndim)))
+            else:
+                metric_dict[key] = mse.mean()
 
         return metric_dict
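
Switching `F.mse_loss` to reduction `"none"` keeps the elementwise squared errors so both branches can be derived from one call; fully reducing reproduces the previous `"mean"` behaviour. A quick equivalence check (toy data):

```python
import paddle
import paddle.nn.functional as F

pred = paddle.rand([4, 3])
label = paddle.rand([4, 3])

mse = F.mse_loss(pred, label, "none")  # shape [4, 3], nothing reduced yet
per_sample = mse.mean(axis=tuple(range(1, mse.ndim)))  # keep_batch=True path
overall = mse.mean()                                   # keep_batch=False path

# the fully reduced branch matches the old single-scalar result
assert paddle.allclose(overall, F.mse_loss(pred, label, "mean"))
```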

ppsci/metric/rmse.py

Lines changed: 13 additions & 18 deletions

@@ -31,20 +31,25 @@ class RMSE(base.Metric):
     metric = \sqrt{\dfrac{1}{N}\sum\limits_{i=1}^{N}{(x_i-y_i)^2}}
     $$
 
+    Args:
+        keep_batch (bool, optional): Whether keep batch axis. Defaults to False.
+
     Examples:
         >>> import ppsci
         >>> metric = ppsci.metric.RMSE()
     """
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, keep_batch: bool = False):
+        if keep_batch:
+            raise ValueError(f"keep_batch should be False, but got {keep_batch}.")
+        super().__init__(keep_batch)
 
     @paddle.no_grad()
     def forward(self, output_dict, label_dict):
         metric_dict = {}
         for key in label_dict:
             rmse = F.mse_loss(output_dict[key], label_dict[key], "mean") ** 0.5
-            metric_dict[key] = float(rmse)
+            metric_dict[key] = rmse
 
         return metric_dict
 

@@ -88,18 +93,16 @@ def __init__(
         unlog: bool = False,
         scale: float = 1e-5,
     ):
-        super().__init__()
+        super().__init__(keep_batch)
         self.num_lat = num_lat
         self.std = (
             None
             if std is None
             else paddle.to_tensor(std, paddle.get_default_dtype()).reshape((1, -1))
         )
-        self.keep_batch = keep_batch
         self.variable_dict = variable_dict
         self.unlog = unlog
         self.scale = scale
-
         self.weight = self.get_latitude_weight(num_lat)
 
     def get_latitude_weight(self, num_lat: int = 720):

@@ -127,18 +130,10 @@ def forward(self, output_dict, label_dict):
             rmse = rmse * self.std
             if self.variable_dict is not None:
                 for variable_name, idx in self.variable_dict.items():
-                    if self.keep_batch:
-                        metric_dict[f"{key}.{variable_name}"] = rmse[:, idx]
-                    else:
-                        metric_dict[f"{key}.{variable_name}"] = float(
-                            rmse[:, idx].mean()
-                        )
+                    metric_dict[f"{key}.{variable_name}"] = (
+                        rmse[:, idx] if self.keep_batch else rmse[:, idx].mean()
+                    )
             else:
-                if self.keep_batch:
-                    rmse = rmse.mean(axis=1)
-                    metric_dict[key] = rmse
-                else:
-                    rmse = rmse.mean()
-                    metric_dict[key] = float(rmse)
+                metric_dict[key] = rmse.mean(axis=1) if self.keep_batch else rmse.mean()
 
         return metric_dict
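
The hunk references `get_latitude_weight` without showing its body. For latitude-weighted scores of gridded weather fields, the weight is commonly cosine-of-latitude normalized to unit mean; a sketch under that assumption (the actual ppsci implementation may differ in detail):

```python
import math

import paddle


def latitude_weight(num_lat: int = 720) -> paddle.Tensor:
    # Assumed form: w(lat) = cos(lat) / mean(cos(lat)); illustration only.
    lat = paddle.linspace(-90.0, 90.0, num_lat) * math.pi / 180.0
    w = paddle.cos(lat)
    return w / w.mean()
```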

ppsci/solver/eval.py

Lines changed: 17 additions & 19 deletions

@@ -24,7 +24,7 @@
 
 
 def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
-    """Evaluation program by dataset.
+    """Evaluate with computing metric on total samples.
 
     Args:
         solver (solver.Solver): Main Solver.

@@ -96,12 +96,12 @@ def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
             batch_cost = time.perf_counter() - batch_tic
             solver.eval_time_info["reader_cost"].update(reader_cost)
             solver.eval_time_info["batch_cost"].update(batch_cost)
-            total_batch_size = sum([v.shape[0] for v in input_dict.values()])
-            printer.update_eval_loss(solver, loss_dict, total_batch_size)
+            batch_size = next(iter(input_dict.values())).shape[0]
+            printer.update_eval_loss(solver, loss_dict, batch_size)
             if iter_id == 1 or iter_id % log_freq == 0:
                 printer.log_eval_info(
                     solver,
-                    total_batch_size,
+                    batch_size,
                     epoch_id,
                     len(_validator.data_loader),
                     iter_id,

@@ -110,7 +110,7 @@ def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
             reader_tic = time.perf_counter()
             batch_tic = time.perf_counter()
 
-        # gather all data
+        # concatenate all data and discard padded sample(s)
         for key in all_input:
             all_input[key] = paddle.concat(all_input[key])
             if len(all_input[key]) > num_samples:

@@ -134,22 +134,22 @@ def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
                    solver.eval_output_info[metric_str] = misc.AverageMeter(
                        metric_str, ".5f"
                    )
-                solver.eval_output_info[metric_str].update(metric_value, num_samples)
+                solver.eval_output_info[metric_str].update(
+                    float(metric_value), num_samples
+                )
 
+    # use the first metric for return value
     if target_metric is None:
         tmp = metric
         while isinstance(tmp, dict):
             tmp = next(iter(tmp.values()))
-        assert isinstance(
-            tmp, (int, float)
-        ), f"Target metric({type(tmp)}) should be a number"
-        target_metric = tmp
+        target_metric = float(tmp)
 
     return target_metric
 
 
 def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
-    """Evaluation program by batch.
+    """Evaluate with computing metric by batch, which is memory-efficient.
 
     Args:
         solver (solver.Solver): Main Solver.

@@ -179,7 +179,7 @@ def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
             for key in solver.eval_time_info:
                 solver.eval_time_info[key].reset()
             reader_cost = time.perf_counter() - reader_tic
-            total_batch_size = next(iter(input_dict.values())).shape[0]
+            batch_size = next(iter(input_dict.values())).shape[0]
 
             for v in input_dict.values():
                 v.stop_gradient = False

@@ -211,11 +211,11 @@ def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
             batch_cost = time.perf_counter() - batch_tic
             solver.eval_time_info["reader_cost"].update(reader_cost)
             solver.eval_time_info["batch_cost"].update(batch_cost)
-            printer.update_eval_loss(solver, loss_dict, total_batch_size)
+            printer.update_eval_loss(solver, loss_dict, batch_size)
             if iter_id == 1 or iter_id % log_freq == 0:
                 printer.log_eval_info(
                     solver,
-                    total_batch_size,
+                    batch_size,
                     epoch_id,
                     len(_validator.data_loader),
                     iter_id,

@@ -224,7 +224,7 @@ def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
             reader_tic = time.perf_counter()
             batch_tic = time.perf_counter()
 
-        # gather all metric
+        # concatenate all metric and discard metric of padded sample(s)
        for metric_name, metric_dict in metric.items():
            for var_name, metric_value in metric_dict.items():
                metric_value = paddle.concat(metric_value)[:num_samples]

@@ -237,20 +237,18 @@ def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
                )
            solver.eval_output_info[metric_str].update(metric_value, num_samples)
 
+    # use the first metric for return value
    if target_metric is None:
        tmp = metric
        while isinstance(tmp, dict):
            tmp = next(iter(tmp.values()))
-        assert isinstance(
-            tmp, (int, float)
-        ), f"Target metric({type(tmp)}) should be a number"
        target_metric = tmp
 
    return target_metric
 
 
 def eval_func(solver, epoch_id: int, log_freq: int) -> float:
-    """Evaluation program
+    """Evaluation function.
 
     Args:
         solver (solver.Solver): Main Solver.
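
Both evaluation paths pick the return value from the first leaf of the (possibly nested) metric dict; with metrics now returning tensors, `_eval_by_dataset` replaces the old `isinstance(tmp, (int, float))` assert with a `float()` coercion. The walk itself, in isolation:

```python
# toy nested metric dict, mirroring {metric_name: {variable_name: value}}
metric = {"RMSE": {"u": 0.123, "v": 0.456}}

tmp = metric
while isinstance(tmp, dict):
    tmp = next(iter(tmp.values()))  # descend into the first entry at each level
target_metric = float(tmp)          # first leaf -> 0.123
```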

ppsci/solver/solver.py

Lines changed: 6 additions & 2 deletions

@@ -29,6 +29,7 @@
 import visualdl as vdl
 from packaging import version
 from paddle import amp
+from paddle import jit
 from paddle import nn
 from paddle import optimizer as optim
 from paddle.distributed import fleet

@@ -218,6 +219,9 @@ def __init__(
         # choosing an appropriate training function for different optimizers
         if isinstance(self.optimizer, optim.LBFGS):
             self.train_epoch_func = ppsci.solver.train.train_LBFGS_epoch_func
+            if self.update_freq != 1:
+                self.update_freq = 1
+                logger.warning("Set update_freq to 1 when using L-BFGS optimizer.")
         else:
             self.train_epoch_func = ppsci.solver.train.train_epoch_func
 

@@ -511,11 +515,11 @@ def export(self):
 
         input_spec = copy.deepcopy(self.cfg["Export"]["input_shape"])
         config.replace_shape_with_inputspec_(input_spec)
-        static_model = paddle.jit.to_static(self.model, input_spec=input_spec)
+        static_model = jit.to_static(self.model, input_spec=input_spec)
 
         export_dir = self.cfg["Global"]["save_inference_dir"]
         save_path = os.path.join(export_dir, "inference")
-        paddle.jit.save(static_model, save_path)
+        jit.save(static_model, save_path)
         logger.info(f"The inference model has been exported to {export_dir}")
 
     def autocast_context_manager(self) -> contextlib.AbstractContextManager:
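
For context, the export path traces the model with an `InputSpec` and saves the static graph. A minimal standalone sketch (the `nn.Linear` model, shapes, and paths are placeholders, not from the PR):

```python
from paddle import jit, nn
from paddle.static import InputSpec

model = nn.Linear(3, 1)  # placeholder network
static_model = jit.to_static(
    model, input_spec=[InputSpec(shape=[None, 3], dtype="float32", name="x")]
)
jit.save(static_model, "inference/model")  # writes model.pdmodel / model.pdiparams
```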

0 commit comments