
Commit 62d3879

unify return type of metric module
1 parent 02d5fd2 commit 62d3879

10 files changed: +119 additions, -47 deletions


ppsci/metric/anomaly_coef.py

Lines changed: 4 additions & 9 deletions
@@ -63,12 +63,11 @@ def __init__(
         unlog: bool = False,
         scale: float = 1e-5,
     ):
-        super().__init__()
+        super().__init__(keep_batch)
         self.num_lat = num_lat
         self.mean = (
             None if mean is None else paddle.to_tensor(mean, paddle.get_default_dtype())
         )
-        self.keep_batch = keep_batch
         self.variable_dict = variable_dict
         self.unlog = unlog
         self.scale = scale
@@ -110,14 +109,10 @@ def forward(self, output_dict, label_dict):
                     if self.keep_batch:
                         metric_dict[f"{key}.{variable_name}"] = rmse[:, idx]
                     else:
-                        metric_dict[f"{key}.{variable_name}"] = float(
-                            rmse[:, idx].mean()
-                        )
+                        metric_dict[f"{key}.{variable_name}"] = rmse[:, idx].mean()
             else:
                 if self.keep_batch:
-                    rmse = rmse.mean(axis=1)
-                    metric_dict[key] = rmse
+                    metric_dict[key] = rmse.mean(axis=1)
                 else:
-                    rmse = rmse.mean()
-                    metric_dict[key] = float(rmse)
+                    metric_dict[key] = rmse.mean()
         return metric_dict
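
The point of this hunk is that every branch of the metric now returns a paddle.Tensor instead of sometimes returning a Python float. A minimal, self-contained sketch of the two reductions (the shapes below are illustrative, not taken from the commit):

    import paddle

    rmse = paddle.rand([8, 5])         # hypothetical (batch, channel) metric values
    per_batch = rmse.mean(axis=1)      # keep_batch=True  -> tensor of shape [8]
    scalar = rmse.mean()               # keep_batch=False -> 0-D tensor, not float
    # Both results are paddle.Tensor, so downstream code handles them uniformly.
    print(type(per_batch) is type(scalar))  # True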

ppsci/metric/base.py

Lines changed: 2 additions & 1 deletion
@@ -18,5 +18,6 @@
 class Metric(nn.Layer):
     """Base class for metric."""

-    def __init__(self):
+    def __init__(self, keep_batch: bool = False):
         super().__init__()
+        self.keep_batch = keep_batch
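
With this change the base class owns the keep_batch flag, and every subclass forwards it through super().__init__(keep_batch). The whole class after the patch is small enough to show in one piece:

    import paddle.nn as nn

    class Metric(nn.Layer):
        """Base class for metric."""

        def __init__(self, keep_batch: bool = False):
            super().__init__()
            self.keep_batch = keep_batch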

ppsci/metric/l2_rel.py

Lines changed: 8 additions & 3 deletions
@@ -24,13 +24,18 @@ class L2Rel(base.Metric):
     metric = \dfrac{\Vert x-y \Vert_2}{\Vert y \Vert_2}
     $$

+    Args:
+        keep_batch (bool, optional): Whether to keep batch axis. Defaults to False.
+
     Examples:
         >>> import ppsci
         >>> metric = ppsci.metric.L2Rel()
     """

-    def __init__(self):
-        super().__init__()
+    def __init__(self, keep_batch: bool = False):
+        if keep_batch is not False:
+            raise ValueError(f"keep_batch should be False, but got {keep_batch}.")
+        super().__init__(keep_batch)

     @paddle.no_grad()
     def forward(self, output_dict, label_dict):
@@ -39,6 +44,6 @@ def forward(self, output_dict, label_dict):
             rel_l2 = paddle.norm(label_dict[key] - output_dict[key]) / paddle.norm(
                 label_dict[key]
             )
-            metric_dict[key] = float(rel_l2)
+            metric_dict[key] = rel_l2

         return metric_dict
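
L2Rel reduces over the whole tensor by construction, so it cannot honor keep_batch=True; the new guard makes that explicit instead of silently ignoring the flag. Assuming a standard ppsci install, usage looks like this:

    import ppsci

    metric = ppsci.metric.L2Rel()  # fine: keep_batch defaults to False
    try:
        ppsci.metric.L2Rel(keep_batch=True)
    except ValueError as err:
        print(err)  # keep_batch should be False, but got True.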

ppsci/metric/mae.py

Lines changed: 2 additions & 3 deletions
@@ -34,8 +34,7 @@ class MAE(base.Metric):
     """

     def __init__(self, keep_batch: bool = False):
-        super().__init__()
-        self.keep_batch = keep_batch
+        super().__init__(keep_batch)

     @paddle.no_grad()
     def forward(self, output_dict, label_dict):
@@ -45,6 +44,6 @@ def forward(self, output_dict, label_dict):
             if self.keep_batch:
                 metric_dict[key] = mae.mean(axis=tuple(range(1, mae.ndim)))
             else:
-                metric_dict[key] = float(mae.mean())
+                metric_dict[key] = mae.mean()

         return metric_dict
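
The keep_batch branch of MAE reduces every axis except the batch axis. A short sketch of that reduction pattern (the tensor shape is hypothetical):

    import paddle

    mae = paddle.rand([4, 3, 64, 64])  # hypothetical per-element |x - y| values
    # Reduce all axes except axis 0, keeping one value per sample:
    per_sample = mae.mean(axis=tuple(range(1, mae.ndim)))
    print(per_sample.shape)  # [4]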

ppsci/metric/mse.py

Lines changed: 10 additions & 4 deletions
@@ -25,19 +25,25 @@ class MSE(base.Metric):
     metric = \dfrac{1}{N}\sum\limits_{i=1}^{N}{(x_i-y_i)^2}
     $$

+    Args:
+        keep_batch (bool, optional): Whether to keep batch axis. Defaults to False.
+
     Examples:
         >>> import ppsci
         >>> metric = ppsci.metric.MSE()
     """

-    def __init__(self):
-        super().__init__()
+    def __init__(self, keep_batch: bool = False):
+        super().__init__(keep_batch)

     @paddle.no_grad()
     def forward(self, output_dict, label_dict):
         metric_dict = {}
         for key in label_dict:
-            mse = F.mse_loss(output_dict[key], label_dict[key], "mean")
-            metric_dict[key] = float(mse)
+            mse = F.mse_loss(output_dict[key], label_dict[key], "none")
+            if self.keep_batch:
+                metric_dict[key] = mse.mean(axis=tuple(range(1, mse.ndim)))
+            else:
+                metric_dict[key] = mse.mean()

         return metric_dict
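
Switching F.mse_loss from "mean" to "none" is what makes keep_batch possible here: the per-element errors survive until the metric chooses its own reduction. A runnable sketch:

    import paddle
    import paddle.nn.functional as F

    x = paddle.rand([4, 3, 8])
    y = paddle.rand([4, 3, 8])
    mse = F.mse_loss(x, y, "none")  # per-element squared error, shape [4, 3, 8]
    print(mse.mean(axis=tuple(range(1, mse.ndim))).shape)  # keep_batch=True: [4]
    print(mse.mean().shape)  # keep_batch=False: scalar tensor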

ppsci/metric/rmse.py

Lines changed: 13 additions & 18 deletions
@@ -31,20 +31,25 @@ class RMSE(base.Metric):
     metric = \sqrt{\dfrac{1}{N}\sum\limits_{i=1}^{N}{(x_i-y_i)^2}}
     $$

+    Args:
+        keep_batch (bool, optional): Whether to keep batch axis. Defaults to False.
+
     Examples:
         >>> import ppsci
         >>> metric = ppsci.metric.RMSE()
     """

-    def __init__(self):
-        super().__init__()
+    def __init__(self, keep_batch: bool = False):
+        if keep_batch is not False:
+            raise ValueError(f"keep_batch should be False, but got {keep_batch}.")
+        super().__init__(keep_batch)

     @paddle.no_grad()
     def forward(self, output_dict, label_dict):
         metric_dict = {}
         for key in label_dict:
             rmse = F.mse_loss(output_dict[key], label_dict[key], "mean") ** 0.5
-            metric_dict[key] = float(rmse)
+            metric_dict[key] = rmse

         return metric_dict

@@ -88,18 +93,16 @@ def __init__(
         unlog: bool = False,
         scale: float = 1e-5,
     ):
-        super().__init__()
+        super().__init__(keep_batch)
         self.num_lat = num_lat
         self.std = (
             None
             if std is None
             else paddle.to_tensor(std, paddle.get_default_dtype()).reshape((1, -1))
         )
-        self.keep_batch = keep_batch
         self.variable_dict = variable_dict
         self.unlog = unlog
         self.scale = scale
-
         self.weight = self.get_latitude_weight(num_lat)

     def get_latitude_weight(self, num_lat: int = 720):
@@ -127,18 +130,10 @@ def forward(self, output_dict, label_dict):
             rmse = rmse * self.std
         if self.variable_dict is not None:
             for variable_name, idx in self.variable_dict.items():
-                if self.keep_batch:
-                    metric_dict[f"{key}.{variable_name}"] = rmse[:, idx]
-                else:
-                    metric_dict[f"{key}.{variable_name}"] = float(
-                        rmse[:, idx].mean()
-                    )
+                metric_dict[f"{key}.{variable_name}"] = (
+                    rmse[:, idx] if self.keep_batch else rmse[:, idx].mean()
+                )
         else:
-            if self.keep_batch:
-                rmse = rmse.mean(axis=1)
-                metric_dict[key] = rmse
-            else:
-                rmse = rmse.mean()
-                metric_dict[key] = float(rmse)
+            metric_dict[key] = rmse.mean(axis=1) if self.keep_batch else rmse.mean()

         return metric_dict
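
The last hunk collapses two if/else blocks, which previously mixed float and tensor results, into single conditional expressions that always yield tensors. A sketch of the simplified branch (shapes are illustrative):

    import paddle

    rmse = paddle.rand([4, 20])  # hypothetical (batch, channel) values
    for keep_batch in (True, False):
        value = rmse.mean(axis=1) if keep_batch else rmse.mean()
        print(keep_batch, value.shape)  # [4] when kept, scalar when fully reduced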

ppsci/solver/eval.py

Lines changed: 3 additions & 1 deletion
@@ -134,7 +134,9 @@ def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
                 solver.eval_output_info[metric_str] = misc.AverageMeter(
                     metric_str, ".5f"
                 )
-            solver.eval_output_info[metric_str].update(metric_value, num_samples)
+            solver.eval_output_info[metric_str].update(
+                float(metric_value), num_samples
+            )

     if target_metric is None:
         tmp = metric
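
Since metrics now hand back tensors, the float() cast happens exactly once, at the consumer. The meter's contract can be sketched like this (the real misc.AverageMeter also carries a name and a format string; this stripped-down version is only illustrative):

    import paddle

    class AverageMeter:
        """Minimal sketch of a running-average meter."""

        def __init__(self):
            self.sum = 0.0
            self.count = 0

        def update(self, value: float, n: int = 1):
            self.sum += value * n
            self.count += n

        @property
        def avg(self) -> float:
            return self.sum / self.count

    meter = AverageMeter()
    metric_value = paddle.to_tensor(0.125)   # what a metric now returns
    meter.update(float(metric_value), 32)    # cast once, at the call site
    print(meter.avg)  # 0.125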

ppsci/solver/train.py

Lines changed: 2 additions & 0 deletions
@@ -169,6 +169,8 @@ def closure():
             return total_loss

         solver.optimizer.step(closure)
+
+        # update learning rate by step
         if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
             solver.lr_scheduler.step()

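
The new comment marks the per-iteration branch of learning-rate scheduling. A runnable sketch of that cadence with plain Paddle (the by_epoch flag is an assumption standing in for ppsci's scheduler attribute):

    import paddle

    model = paddle.nn.Linear(2, 1)
    scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=2)
    optimizer = paddle.optimizer.SGD(
        learning_rate=scheduler, parameters=model.parameters()
    )
    by_epoch = False  # assumed flag; ppsci attaches it to its schedulers

    for iteration in range(4):
        loss = model(paddle.rand([4, 2])).mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        if not by_epoch:
            scheduler.step()  # update learning rate by step
        print(iteration, scheduler.get_lr())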

ppsci/utils/misc.py

Lines changed: 75 additions & 7 deletions
@@ -14,6 +14,9 @@

 import collections
 import random
+from typing import Dict
+from typing import List
+from typing import Tuple

 import numpy as np
 import paddle
@@ -94,7 +97,16 @@ def __str__(self):
         return "".join([str((k, v)) for k, v in self.items()])


-def convert_to_dict(array, keys):
+def convert_to_dict(array: np.ndarray, keys: Tuple[str, ...]) -> Dict[str, np.ndarray]:
+    """Split given array into single-channel arrays at axis -1 in order of given keys.
+
+    Args:
+        array (np.ndarray): Array to be split.
+        keys (Tuple[str, ...]): Keys used in split.
+
+    Returns:
+        Dict[str, np.ndarray]: Dict of split arrays.
+    """
     if array.shape[-1] != len(keys):
         raise ValueError(
             f"dim of array({array.shape[-1]}) must equal to " f"len(keys)({len(keys)})"
@@ -104,7 +116,9 @@ def convert_to_dict(array, keys):
     return {key: split_array[i] for i, key in enumerate(keys)}


-def all_gather(tensor, concat=True, axis=0):
+def all_gather(
+    tensor: paddle.Tensor, concat: bool = True, axis: int = 0
+) -> List[paddle.Tensor]:
     """Gather tensor from all devices, concatenate them along given axis if specified.

     Args:
@@ -122,29 +136,78 @@ def all_gather(tensor, concat=True, axis=0):
     return result


-def convert_to_array(dict, keys):
+def convert_to_array(dict: Dict[str, np.ndarray], keys: Tuple[str, ...]) -> np.ndarray:
+    """Concatenate arrays at axis -1 in order of given keys.
+
+    Args:
+        dict (Dict[str, np.ndarray]): Dict containing arrays.
+        keys (Tuple[str, ...]): Keys used in concatenation.
+
+    Returns:
+        np.ndarray: Concatenated array.
+    """
     return np.concatenate([dict[key] for key in keys], axis=-1)


-def concat_dict_list(dict_list):
+def concat_dict_list(
+    dict_list: Tuple[Dict[str, np.ndarray], ...]
+) -> Dict[str, np.ndarray]:
+    """Concatenate arrays in tuple of dicts at axis 0.
+
+    Args:
+        dict_list (Tuple[Dict[str, np.ndarray], ...]): Tuple of dicts.
+
+    Returns:
+        Dict[str, np.ndarray]: A dict with concatenated arrays for each key.
+    """
     ret = {}
     for key in dict_list[0].keys():
         ret[key] = np.concatenate([_dict[key] for _dict in dict_list], axis=0)
     return ret


-def stack_dict_list(dict_list):
+def stack_dict_list(
+    dict_list: Tuple[Dict[str, np.ndarray], ...]
+) -> Dict[str, np.ndarray]:
+    """Stack arrays in tuple of dicts at axis 0.
+
+    Args:
+        dict_list (Tuple[Dict[str, np.ndarray], ...]): Tuple of dicts.
+
+    Returns:
+        Dict[str, np.ndarray]: A dict with stacked arrays for each key.
+    """
     ret = {}
     for key in dict_list[0].keys():
         ret[key] = np.stack([_dict[key] for _dict in dict_list], axis=0)
     return ret


-def typename(object):
+def typename(object: object) -> str:
+    """Return type name of given object.
+
+    Args:
+        object (object): Python object instantiated from a class.
+
+    Returns:
+        str: Class name of given object.
+    """
     return object.__class__.__name__


-def combine_array_with_time(x, t):
+def combine_array_with_time(x: np.ndarray, t: Tuple[int, ...]) -> np.ndarray:
+    """Combine given data x with time sequence t.
+    Given x with shape (N, D) and t with shape (T, ), this function repeats
+    t_i for N times, concatenates it with data x for each t_i in t, and
+    returns the stacked result, which is of shape (NxT, D+1).
+
+    Args:
+        x (np.ndarray): Points data with shape (N, D).
+        t (Tuple[int, ...]): Time sequence with shape (T, ).
+
+    Returns:
+        np.ndarray: Combined data with shape (NxT, D+1).
+    """
     nx = len(x)
     tx = []
     for ti in t:
@@ -158,6 +221,11 @@ def combine_array_with_time(x, t):


 def set_random_seed(seed: int):
+    """Set numpy, random and paddle random seed to given seed.
+
+    Args:
+        seed (int): Random seed.
+    """
     paddle.seed(seed)
     np.random.seed(seed)
     random.seed(seed)
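
The newly documented helpers are easiest to understand as a round trip. A short usage sketch against the annotated signatures above (assuming ppsci is importable):

    import numpy as np

    from ppsci.utils import misc

    # convert_to_dict / convert_to_array invert each other along axis -1.
    array = np.random.rand(100, 3)
    d = misc.convert_to_dict(array, ("x", "y", "z"))  # three (100, 1) arrays
    restored = misc.convert_to_array(d, ("x", "y", "z"))
    print(np.allclose(array, restored))  # True

    # combine_array_with_time repeats each t_i over all N points.
    xt = misc.combine_array_with_time(np.random.rand(5, 2), (0, 1, 2))
    print(xt.shape)  # (15, 3): (N*T, D+1)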

ppsci/visualize/visualizer.py

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@
 from typing import Tuple

 import numpy as np
-import paddle

 from ppsci.visualize import base
 from ppsci.visualize import plot
