Skip to content

Commit d9958d2

Browse files
refine code and remove redundant mkdir_if_not_exist function
1 parent 62d3879 commit d9958d2

File tree

8 files changed

+26
-47
lines changed

8 files changed

+26
-47
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
5858

5959
## 快速开始
6060

61-
参考 [**快速开始**](https://paddlescience-docs.readthedocs.io/zh/latest/zh/quickstart/)
61+
请参考 [**快速开始**](https://paddlescience-docs.readthedocs.io/zh/latest/zh/quickstart/)
6262

6363
## 经典案例
6464

ppsci/metric/l2_rel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class L2Rel(base.Metric):
3333
"""
3434

3535
def __init__(self, keep_batch: bool = False):
36-
if keep_batch is not False:
36+
if keep_batch:
3737
raise ValueError(f"keep_batch should be False, but got {keep_batch}.")
3838
super().__init__(keep_batch)
3939

ppsci/metric/rmse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class RMSE(base.Metric):
4040
"""
4141

4242
def __init__(self, keep_batch: bool = False):
43-
if keep_batch is not False:
43+
if keep_batch:
4444
raise ValueError(f"keep_batch should be False, but got {keep_batch}.")
4545
super().__init__(keep_batch)
4646

ppsci/solver/eval.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
27-
"""Evaluation program by dataset.
27+
"""Evaluate with computing metric on total samples.
2828
2929
Args:
3030
solver (solver.Solver): Main Solver.
@@ -96,12 +96,12 @@ def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
9696
batch_cost = time.perf_counter() - batch_tic
9797
solver.eval_time_info["reader_cost"].update(reader_cost)
9898
solver.eval_time_info["batch_cost"].update(batch_cost)
99-
total_batch_size = sum([v.shape[0] for v in input_dict.values()])
100-
printer.update_eval_loss(solver, loss_dict, total_batch_size)
99+
batch_size = next(iter(input_dict.values())).shape[0]
100+
printer.update_eval_loss(solver, loss_dict, batch_size)
101101
if iter_id == 1 or iter_id % log_freq == 0:
102102
printer.log_eval_info(
103103
solver,
104-
total_batch_size,
104+
batch_size,
105105
epoch_id,
106106
len(_validator.data_loader),
107107
iter_id,
@@ -110,7 +110,7 @@ def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
110110
reader_tic = time.perf_counter()
111111
batch_tic = time.perf_counter()
112112

113-
# gather all data
113+
# concatenate all data and discard padded sample(s)
114114
for key in all_input:
115115
all_input[key] = paddle.concat(all_input[key])
116116
if len(all_input[key]) > num_samples:
@@ -138,20 +138,18 @@ def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
138138
float(metric_value), num_samples
139139
)
140140

141+
# use the first metric for return value
141142
if target_metric is None:
142143
tmp = metric
143144
while isinstance(tmp, dict):
144145
tmp = next(iter(tmp.values()))
145-
assert isinstance(
146-
tmp, (int, float)
147-
), f"Target metric({type(tmp)}) should be a number"
148-
target_metric = tmp
146+
target_metric = float(tmp)
149147

150148
return target_metric
151149

152150

153151
def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
154-
"""Evaluation program by batch.
152+
"""Evaluate with computing metric by batch, which is memory-efficient.
155153
156154
Args:
157155
solver (solver.Solver): Main Solver.
@@ -181,7 +179,7 @@ def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
181179
for key in solver.eval_time_info:
182180
solver.eval_time_info[key].reset()
183181
reader_cost = time.perf_counter() - reader_tic
184-
total_batch_size = next(iter(input_dict.values())).shape[0]
182+
batch_size = next(iter(input_dict.values())).shape[0]
185183

186184
for v in input_dict.values():
187185
v.stop_gradient = False
@@ -213,11 +211,11 @@ def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
213211
batch_cost = time.perf_counter() - batch_tic
214212
solver.eval_time_info["reader_cost"].update(reader_cost)
215213
solver.eval_time_info["batch_cost"].update(batch_cost)
216-
printer.update_eval_loss(solver, loss_dict, total_batch_size)
214+
printer.update_eval_loss(solver, loss_dict, batch_size)
217215
if iter_id == 1 or iter_id % log_freq == 0:
218216
printer.log_eval_info(
219217
solver,
220-
total_batch_size,
218+
batch_size,
221219
epoch_id,
222220
len(_validator.data_loader),
223221
iter_id,
@@ -226,7 +224,7 @@ def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
226224
reader_tic = time.perf_counter()
227225
batch_tic = time.perf_counter()
228226

229-
# gather all metric
227+
# concatenate all metrics and discard metrics of padded sample(s)
230228
for metric_name, metric_dict in metric.items():
231229
for var_name, metric_value in metric_dict.items():
232230
metric_value = paddle.concat(metric_value)[:num_samples]
@@ -239,20 +237,18 @@ def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
239237
)
240238
solver.eval_output_info[metric_str].update(metric_value, num_samples)
241239

240+
# use the first metric for return value
242241
if target_metric is None:
243242
tmp = metric
244243
while isinstance(tmp, dict):
245244
tmp = next(iter(tmp.values()))
246-
assert isinstance(
247-
tmp, (int, float)
248-
), f"Target metric({type(tmp)}) should be a number"
249245
target_metric = tmp
250246

251247
return target_metric
252248

253249

254250
def eval_func(solver, epoch_id: int, log_freq: int) -> float:
255-
"""Evaluation program
251+
"""Evaluation function.
256252
257253
Args:
258254
solver (solver.Solver): Main Solver.

ppsci/solver/solver.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,7 @@ def __init__(
221221
self.train_epoch_func = ppsci.solver.train.train_LBFGS_epoch_func
222222
if self.update_freq != 1:
223223
self.update_freq = 1
224-
logger.warning(
225-
f"Set update_freq from {self.update_freq} to 1 when using L-BFGS optimizer."
226-
)
224+
logger.warning("Set update_freq to to 1 when using L-BFGS optimizer.")
227225
else:
228226
self.train_epoch_func = ppsci.solver.train.train_epoch_func
229227

ppsci/utils/save_load.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import errno
1615
import os
1716
from typing import Any
1817
from typing import Dict
@@ -25,22 +24,6 @@
2524
__all__ = ["load_checkpoint", "save_checkpoint", "load_pretrain"]
2625

2726

28-
def _mkdir_if_not_exist(path):
29-
"""mkdir if not exists, ignore the exception when multiprocess mkdir together
30-
31-
Args:
32-
path (str): Path for makedir
33-
"""
34-
if not os.path.exists(path):
35-
try:
36-
os.makedirs(path)
37-
except OSError as os_err:
38-
if os_err.errno == errno.EEXIST and os.path.isdir(path):
39-
logger.warning(f"{path} already created.")
40-
else:
41-
raise OSError(f"Failed to mkdir {path}.")
42-
43-
4427
def _load_pretrain_from_path(model, path, equation=None):
4528
"""Load pretrained model from given path.
4629
@@ -137,15 +120,15 @@ def save_checkpoint(
137120
model (nn.Layer): Model with parameters.
138121
optimizer (optimizer.Optimizer): Optimizer for model.
139122
grad_scaler (Optional[amp.GradScaler]): GradScaler for AMP. Defaults to None.
140-
metric (Dict[str, Any]): Metric information, such as {"RMSE": ...}.
123+
metric (Dict[str, float]): Metric information, such as {"RMSE": ...}.
141124
model_dir (str): Directory for checkpoint storage.
142125
prefix (str, optional): Prefix for storage. Defaults to "ppsci".
143126
equation (Optional[Dict[str, ppsci.equation.PDE]]): Equations. Defaults to None.
144127
"""
145128
if paddle.distributed.get_rank() != 0:
146129
return
147130
model_dir = os.path.join(model_dir, "checkpoints")
148-
_mkdir_if_not_exist(model_dir)
131+
os.makedirs(model_dir, exist_ok=True)
149132
model_path = os.path.join(model_dir, prefix)
150133

151134
paddle.save(model.state_dict(), f"{model_path}.pdparams")

ppsci/validate/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from ppsci import data
2121
from ppsci import loss
22+
from ppsci import metric
2223

2324

2425
class Validator:
@@ -28,7 +29,7 @@ class Validator:
2829
dataset (io.Dataset): Dataset for validator.
2930
dataloader_cfg (Dict[str, Any]): Dataloader config.
3031
loss (loss.Loss): Loss functor.
31-
metric (Dict[str, Any]): Named metric functors in dict.
32+
metric (Dict[str, metric.Metric]): Named metric functors in dict.
3233
name (str): Name of validator.
3334
"""
3435

@@ -37,7 +38,7 @@ def __init__(
3738
dataset: io.Dataset,
3839
dataloader_cfg: Dict[str, Any],
3940
loss: loss.Loss,
40-
metric: Dict[str, Any],
41+
metric: Dict[str, metric.Metric],
4142
name: str,
4243
):
4344
self.data_loader = data.build_dataloader(dataset, dataloader_cfg)

ppsci/validate/geo_validator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from ppsci import geometry
2828
from ppsci import loss
29+
from ppsci import metric
2930
from ppsci.data import dataset
3031
from ppsci.validate import base
3132

@@ -46,7 +47,7 @@ class GeometryValidator(base.Validator):
4647
geometry. Defaults to "pseudo".
4748
criteria (Optional[Callable]): Criteria for refining specified domain. Defaults to None.
4849
evenly (bool, optional): Whether to use evenly distribution sampling. Defaults to False.
49-
metric (Optional[Dict[str, Any]]): Named metric functors in dict. Defaults to None.
50+
metric (Optional[Dict[str, metric.Metric]]): Named metric functors in dict. Defaults to None.
5051
with_initial (bool, optional): Whether the data contains time t0. Defaults to False.
5152
name (Optional[str]): Name of validator. Defaults to None.
5253
@@ -77,7 +78,7 @@ def __init__(
7778
random: Literal["pseudo", "LHS"] = "pseudo",
7879
criteria: Optional[Callable] = None,
7980
evenly: bool = False,
80-
metric: Optional[Dict[str, Any]] = None,
81+
metric: Optional[Dict[str, metric.Metric]] = None,
8182
with_initial: bool = False,
8283
name: Optional[str] = None,
8384
):

0 commit comments

Comments
 (0)