Skip to content

Commit 1c0221e

Browse files
change 'num_timestamp' to 'num_timestamps'
1 parent 6493522 commit 1c0221e

File tree

6 files changed

+71
-71
lines changed

6 files changed

+71
-71
lines changed

ppsci/arch/afno.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ class AFNONet(base.Arch):
422422
num_blocks (int, optional): Number of blocks. Defaults to 8.
423423
sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01.
424424
hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0.
425-
num_timestamp (int, optional): Number of timestamp. Defaults to 1.
425+
num_timestamps (int, optional): Number of timestamp. Defaults to 1.
426426
427427
Examples:
428428
>>> import ppsci
@@ -445,7 +445,7 @@ def __init__(
445445
num_blocks: int = 8,
446446
sparsity_threshold: float = 0.01,
447447
hard_thresholding_fraction: float = 1.0,
448-
num_timestamp: int = 1,
448+
num_timestamps: int = 1,
449449
):
450450
super().__init__()
451451
self.input_keys = input_keys
@@ -457,7 +457,7 @@ def __init__(
457457
self.out_channels = out_channels
458458
self.embed_dim = embed_dim
459459
self.num_blocks = num_blocks
460-
self.num_timestamp = num_timestamp
460+
self.num_timestamps = num_timestamps
461461
norm_layer = partial(nn.LayerNorm, epsilon=1e-6)
462462

463463
self.patch_embed = PatchEmbed(
@@ -555,7 +555,7 @@ def forward(self, x):
555555

556556
y = []
557557
input = x
558-
for i in range(self.num_timestamp):
558+
for i in range(self.num_timestamps):
559559
out = self.forward_tensor(input)
560560
y.append(out)
561561
input = out

ppsci/data/dataset/era5_dataset.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class ERA5Dataset(io.Dataset):
3434
precip_file_path (Optional[str]): Precipitation data set path. Defaults to None.
3535
weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None.
3636
vars_channel (Optional[Tuple[int, ...]]): The variable channel index in ERA5 dataset. Defaults to None.
37-
num_label_timestamp (int, optional): Number of timestamp of label. Defaults to 1.
37+
num_label_timestamps (int, optional): Number of timestamp of label. Defaults to 1.
3838
transforms (Optional[vision.Compose]): Compose object contains sample wise
3939
transform(s). Defaults to None.
4040
training (bool, optional): Whether in train mode. Defaults to True.
@@ -56,7 +56,7 @@ def __init__(
5656
precip_file_path: Optional[str] = None,
5757
weight_dict: Optional[Dict[str, float]] = None,
5858
vars_channel: Optional[Tuple[int, ...]] = None,
59-
num_label_timestamp: int = 1,
59+
num_label_timestamps: int = 1,
6060
transforms: Optional[vision.Compose] = None,
6161
training: bool = True,
6262
):
@@ -74,7 +74,7 @@ def __init__(
7474
self.vars_channel = (
7575
vars_channel if vars_channel is not None else [i for i in range(20)]
7676
)
77-
self.num_label_timestamp = num_label_timestamp
77+
self.num_label_timestamps = num_label_timestamps
7878
self.transforms = transforms
7979
self.training = training
8080

@@ -102,9 +102,9 @@ def __getitem__(self, global_idx):
102102
local_idx = global_idx % self.n_samples_per_year
103103
step = 0 if local_idx >= self.n_samples_per_year - 1 else 1
104104

105-
if self.num_label_timestamp > 1:
106-
if local_idx >= self.n_samples_per_year - self.num_label_timestamp:
107-
local_idx = self.n_samples_per_year - self.num_label_timestamp - 1
105+
if self.num_label_timestamps > 1:
106+
if local_idx >= self.n_samples_per_year - self.num_label_timestamps:
107+
local_idx = self.n_samples_per_year - self.num_label_timestamps - 1
108108

109109
input_file = self.files[year_idx]
110110
label_file = (
@@ -124,7 +124,7 @@ def __getitem__(self, global_idx):
124124

125125
input_item = {self.input_keys[0]: input_file[input_idx, self.vars_channel]}
126126
label_item = {}
127-
for i in range(self.num_label_timestamp):
127+
for i in range(self.num_label_timestamps):
128128
if self.precip_file_path is not None:
129129
label_item[self.label_keys[i]] = np.expand_dims(
130130
label_file[label_idx + i], 0

ppsci/geometry/timedomain.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ def __init__(
6161
if time_step is not None:
6262
if time_step <= 0:
6363
raise ValueError(f"time_step({time_step}) must be larger than 0.")
64-
self.num_timestamp = int(np.ceil((t1 - t0) / time_step)) + 1
64+
self.num_timestamps = int(np.ceil((t1 - t0) / time_step)) + 1
6565
elif timestamps is not None:
66-
self.num_timestamp = len(timestamps)
66+
self.num_timestamps = len(timestamps)
6767

6868
def on_initial(self, t):
6969
return np.isclose(t, self.t0).flatten()
@@ -117,7 +117,7 @@ def uniform_points(self, n, boundary=True):
117117
nx = int(np.ceil(n / nt))
118118
elif self.timedomain.timestamps is not None:
119119
# exclude start time t0
120-
nt = self.timedomain.num_timestamp - 1
120+
nt = self.timedomain.num_timestamps - 1
121121
nx = int(np.ceil(n / nt))
122122
else:
123123
nx = int(
@@ -205,7 +205,7 @@ def random_points(self, n, random="pseudo", criteria=None):
205205
tx = tx[:n]
206206
return tx
207207
elif self.timedomain.timestamps is not None:
208-
nt = self.timedomain.num_timestamp - 1
208+
nt = self.timedomain.num_timestamps - 1
209209
t = self.timedomain.timestamps[1:]
210210
nx = int(np.ceil(n / nt))
211211

@@ -402,7 +402,7 @@ def random_boundary_points(self, n, random="pseudo", criteria=None):
402402
return t_x
403403
elif self.timedomain.timestamps is not None:
404404
# exclude start time t0
405-
nt = self.timedomain.num_timestamp - 1
405+
nt = self.timedomain.num_timestamps - 1
406406
t = self.timedomain.timestamps[1:]
407407
nx = int(np.ceil(n / nt))
408408

ppsci/validate/geo_validator.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,19 +91,19 @@ def __init__(
9191
self.output_keys = list(label_dict.keys())
9292

9393
nx = dataloader_cfg["total_size"]
94-
self.num_timestamp = 1
94+
self.num_timestamps = 1
9595
# TODO(sensen): simplify code below
9696
if isinstance(geom, geometry.TimeXGeometry):
97-
if geom.timedomain.num_timestamp is not None:
97+
if geom.timedomain.num_timestamps is not None:
9898
if with_initial:
9999
# include t0
100-
self.num_timestamp = geom.timedomain.num_timestamp
100+
self.num_timestamps = geom.timedomain.num_timestamps
101101
assert (
102-
nx % self.num_timestamp == 0
103-
), f"{nx} % {self.num_timestamp} != 0"
104-
nx //= self.num_timestamp
102+
nx % self.num_timestamps == 0
103+
), f"{nx} % {self.num_timestamps} != 0"
104+
nx //= self.num_timestamps
105105
input = geom.sample_interior(
106-
nx * (geom.timedomain.num_timestamp - 1),
106+
nx * (geom.timedomain.num_timestamps - 1),
107107
random,
108108
criteria,
109109
evenly,
@@ -114,13 +114,13 @@ def __init__(
114114
}
115115
else:
116116
# exclude t0
117-
self.num_timestamp = geom.timedomain.num_timestamp - 1
117+
self.num_timestamps = geom.timedomain.num_timestamps - 1
118118
assert (
119-
nx % self.num_timestamp == 0
120-
), f"{nx} % {self.num_timestamp} != 0"
121-
nx //= self.num_timestamp
119+
nx % self.num_timestamps == 0
120+
), f"{nx} % {self.num_timestamps} != 0"
121+
nx //= self.num_timestamps
122122
input = geom.sample_interior(
123-
nx * (geom.timedomain.num_timestamp - 1),
123+
nx * (geom.timedomain.num_timestamps - 1),
124124
random,
125125
criteria,
126126
evenly,

ppsci/visualize/plot.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -69,21 +69,21 @@
6969
]
7070

7171

72-
def _save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamp=1):
72+
def _save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamps=1):
7373
"""Save plot from given 1D data.
7474
7575
Args:
7676
filename (str): Filename.
7777
coord (np.ndarray): Coordinate array.
7878
value (Dict[str, np.ndarray]): Dict of value array.
7979
value_keys (Tuple[str, ...]): Value keys.
80-
num_timestamp (int, optional): Number of timestamps coord/value contains. Defaults to 1.
80+
num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
8181
"""
82-
fig, a = plt.subplots(len(value_keys), num_timestamp, squeeze=False)
82+
fig, a = plt.subplots(len(value_keys), num_timestamps, squeeze=False)
8383
fig.subplots_adjust(hspace=0.8)
8484

85-
len_ts = len(coord) // num_timestamp
86-
for t in range(num_timestamp):
85+
len_ts = len(coord) // num_timestamps
86+
for t in range(num_timestamps):
8787
st = t * len_ts
8888
ed = (t + 1) * len_ts
8989
coord_t = coord[st:ed]
@@ -96,29 +96,29 @@ def _save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamp=1
9696
color=cnames[i],
9797
label=key,
9898
)
99-
if num_timestamp > 1:
99+
if num_timestamps > 1:
100100
a[i][t].set_title(f"{key}(t={t})")
101101
else:
102102
a[i][t].set_title(f"{key}")
103103
a[i][t].grid()
104104
a[i][t].legend()
105105

106-
if num_timestamp == 1:
106+
if num_timestamps == 1:
107107
fig.savefig(filename, dpi=300)
108108
else:
109109
fig.savefig(f"{filename}_{t}", dpi=300)
110110

111-
if num_timestamp == 1:
111+
if num_timestamps == 1:
112112
logger.info(f"1D result is saved to {filename}.png")
113113
else:
114114
logger.info(
115115
f"1D result is saved to {filename}_0.png"
116-
f" ~ {filename}_{num_timestamp - 1}.png"
116+
f" ~ {filename}_{num_timestamps - 1}.png"
117117
)
118118

119119

120120
def save_plot_from_1d_dict(
121-
filename, data_dict, coord_keys, value_keys, num_timestamp=1
121+
filename, data_dict, coord_keys, value_keys, num_timestamps=1
122122
):
123123
"""Plot dict data as file.
124124
@@ -127,7 +127,7 @@ def save_plot_from_1d_dict(
127127
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
128128
coord_keys (Tuple[str, ...]): Tuple of coord key. such as ("x", "y").
129129
value_keys (Tuple[str, ...]): Tuple of value key. such as ("u", "v").
130-
num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
130+
num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
131131
"""
132132
space_ndim = len(coord_keys) - int("t" in coord_keys)
133133
if space_ndim not in [1, 2, 3]:
@@ -149,14 +149,14 @@ def save_plot_from_1d_dict(
149149
value = [x for x in value]
150150
value = np.concatenate(value, axis=1)
151151

152-
_save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamp)
152+
_save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamps)
153153

154154

155155
def _save_plot_from_2d_array(
156156
filename: str,
157157
visu_data: Tuple[np.ndarray, ...],
158158
visu_keys: Tuple[str, ...],
159-
num_timestamp: int = 1,
159+
num_timestamps: int = 1,
160160
stride: int = 1,
161161
xticks: Optional[Tuple[float, ...]] = None,
162162
yticks: Optional[Tuple[float, ...]] = None,
@@ -167,7 +167,7 @@ def _save_plot_from_2d_array(
167167
filename (str): Filename.
168168
visu_data (Tuple[np.ndarray, ...]): Data that requires visualization.
169169
visu_keys (Tuple[str, ...]]): Keys for visualizing data. such as ("u", "v").
170-
num_timestamp (int, optional): Number of timestamps coord/value contains. Defaults to 1.
170+
num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
171171
stride (int, optional): The time stride of visualization. Defaults to 1.
172172
xticks (Optional[Tuple[float, ...]]): Tuple of xtick locations. Defaults to None.
173173
yticks (Optional[Tuple[float, ...]]): Tuple of ytick locations. Defaults to None.
@@ -179,10 +179,10 @@ def _save_plot_from_2d_array(
179179

180180
fig, ax = plt.subplots(
181181
len(visu_keys),
182-
num_timestamp,
182+
num_timestamps,
183183
squeeze=False,
184184
sharey=True,
185-
figsize=(num_timestamp, len(visu_keys)),
185+
figsize=(num_timestamps, len(visu_keys)),
186186
)
187187
fig.subplots_adjust(hspace=0.3)
188188
target_flag = any(["target" in key for key in visu_keys])
@@ -191,7 +191,7 @@ def _save_plot_from_2d_array(
191191
c_max = np.amax(data)
192192
c_min = np.amin(data)
193193

194-
for t_idx in range(num_timestamp):
194+
for t_idx in range(num_timestamps):
195195
t = t_idx * stride
196196
ax[i, t_idx].imshow(
197197
data[t, :, :],
@@ -226,7 +226,7 @@ def save_plot_from_2d_dict(
226226
filename: str,
227227
data_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
228228
visu_keys: Tuple[str, ...],
229-
num_timestamp: int = 1,
229+
num_timestamps: int = 1,
230230
stride: int = 1,
231231
xticks: Optional[Tuple[float, ...]] = None,
232232
yticks: Optional[Tuple[float, ...]] = None,
@@ -237,7 +237,7 @@ def save_plot_from_2d_dict(
237237
filename (str): Output filename.
238238
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
239239
visu_keys (Tuple[str, ...]): Keys for visualizing data. such as ("u", "v").
240-
num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
240+
num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
241241
stride (int, optional): The time stride of visualization. Defaults to 1.
242242
xticks (Optional[Tuple[float,...]]): The list of xtick locations. Defaults to None.
243243
yticks (Optional[Tuple[float,...]]): The list of ytick locations. Defaults to None.
@@ -246,7 +246,7 @@ def save_plot_from_2d_dict(
246246
if isinstance(visu_data[0], paddle.Tensor):
247247
visu_data = [x.numpy() for x in visu_data]
248248
_save_plot_from_2d_array(
249-
filename, visu_data, visu_keys, num_timestamp, stride, xticks, yticks
249+
filename, visu_data, visu_keys, num_timestamps, stride, xticks, yticks
250250
)
251251

252252

@@ -308,21 +308,21 @@ def _save_plot_from_3d_array(
308308
filename: str,
309309
visu_data: Tuple[np.ndarray, ...],
310310
visu_keys: Tuple[str, ...],
311-
num_timestamp: int = 1,
311+
num_timestamps: int = 1,
312312
):
313313
"""Save plot from given 3D data.
314314
315315
Args:
316316
filename (str): Filename.
317317
visu_data (Tuple[np.ndarray, ...]): Data that requires visualization.
318318
visu_keys (Tuple[str, ...]]): Keys for visualizing data. such as ("u", "v").
319-
num_timestamp (int, optional): Number of timestamps coord/value contains. Defaults to 1.
319+
num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
320320
"""
321321

322322
fig = plt.figure(figsize=(10, 10))
323-
len_ts = len(visu_data[0]) // num_timestamp
324-
for t in range(num_timestamp):
325-
ax = fig.add_subplot(1, num_timestamp, t + 1, projection="3d")
323+
len_ts = len(visu_data[0]) // num_timestamps
324+
for t in range(num_timestamps):
325+
ax = fig.add_subplot(1, num_timestamps, t + 1, projection="3d")
326326
st = t * len_ts
327327
ed = (t + 1) * len_ts
328328
visu_data_t = [data[st:ed] for data in visu_data]
@@ -343,40 +343,40 @@ def _save_plot_from_3d_array(
343343
loc="upper right",
344344
framealpha=0.95,
345345
)
346-
if num_timestamp == 1:
346+
if num_timestamps == 1:
347347
fig.savefig(filename, dpi=300)
348348
else:
349349
fig.savefig(f"{filename}_{t}", dpi=300)
350350

351-
if num_timestamp == 1:
351+
if num_timestamps == 1:
352352
logger.info(f"3D result is saved to {filename}.png")
353353
else:
354354
logger.info(
355355
f"3D result is saved to {filename}_0.png"
356-
f" ~ {filename}_{num_timestamp - 1}.png"
356+
f" ~ {filename}_{num_timestamps - 1}.png"
357357
)
358358

359359

360360
def save_plot_from_3d_dict(
361361
filename: str,
362362
data_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
363363
visu_keys: Tuple[str, ...],
364-
num_timestamp: int = 1,
364+
num_timestamps: int = 1,
365365
):
366366
"""Plot dict data as file.
367367
368368
Args:
369369
filename (str): Output filename.
370370
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
371371
visu_keys (Tuple[str, ...]): Keys for visualizing data. such as ("u", "v").
372-
num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
372+
num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
373373
"""
374374

375375
visu_data = [data_dict[k] for k in visu_keys]
376376
if isinstance(visu_data[0], paddle.Tensor):
377377
visu_data = [x.numpy() for x in visu_data]
378378

379-
_save_plot_from_3d_array(filename, visu_data, visu_keys, num_timestamp)
379+
_save_plot_from_3d_array(filename, visu_data, visu_keys, num_timestamps)
380380

381381

382382
def _save_plot_weather_from_array(

0 commit comments

Comments
 (0)