Skip to content

Commit 6c63cbe

Browse files
refine code according to comments
1 parent a0f91b1 commit 6c63cbe

File tree

15 files changed

+130
-142
lines changed

15 files changed

+130
-142
lines changed

examples/fourcastnet/sample_data.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
from typing import Tuple
2222

2323
import h5py
24-
import numpy as np
25-
from paddle.io import DistributedBatchSampler
24+
from paddle import io
2625
from tqdm import tqdm
2726

2827
import examples.fourcastnet.utils as fourcast_utils
@@ -35,7 +34,7 @@ def sample_func(
3534
):
3635
dataset = ppsci.data.dataset.build_dataset(dataset_cfg)
3736
for idx in tqdm(batch_idxs):
38-
input_dict, label_dict, weight_dict = dataset.getitem(idx)
37+
input_dict, label_dict, weight_dict = dataset[idx]
3938
fdest = h5py.File(f"{save_path}/{idx:0>8d}.h5", "w")
4039
for key, value in input_dict.items():
4140
fdest.create_dataset(f"input_dict/{key}", data=value, dtype="f")
@@ -99,7 +98,7 @@ def sample_data_epoch(epoch: int):
9998
}
10099
dataset = ppsci.data.dataset.build_dataset(dataset_cfg)
101100

102-
batch_sampler = DistributedBatchSampler(
101+
batch_sampler = io.DistributedBatchSampler(
103102
dataset=dataset,
104103
batch_size=1,
105104
shuffle=False,

examples/fourcastnet/train_finetune.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from functools import partial
15+
import functools
1616
from typing import Tuple
1717

1818
import h5py
@@ -70,9 +70,9 @@ def get_vis_datas(
7070
DATA_TIME_MEAN_PATH = "./datasets/era5/stat/time_means.npy"
7171

7272
# set training hyper-parameters
73-
num_timestamps = 2
73+
NUM_TIMESTAMPS = 2
7474
input_keys = ("input",)
75-
output_keys = tuple([f"output_{i}" for i in range(num_timestamps)])
75+
output_keys = tuple([f"output_{i}" for i in range(NUM_TIMESTAMPS)])
7676
IMG_H, IMG_W = 720, 1440
7777
EPOCHS = 50 if not args.epochs else args.epochs
7878
# FourCastNet use 20 atmospheric variable,their index in the dataset is from 0 to 19.
@@ -113,7 +113,7 @@ def get_vis_datas(
113113
"input_keys": input_keys,
114114
"label_keys": output_keys,
115115
"vars_channel": VARS_CHANNEL,
116-
"num_label_timestamps": num_timestamps,
116+
"num_label_timestamps": NUM_TIMESTAMPS,
117117
"transforms": transforms,
118118
},
119119
"sampler": {
@@ -144,7 +144,7 @@ def get_vis_datas(
144144
"label_keys": output_keys,
145145
"vars_channel": VARS_CHANNEL,
146146
"transforms": transforms,
147-
"num_label_timestamps": num_timestamps,
147+
"num_label_timestamps": NUM_TIMESTAMPS,
148148
"training": False,
149149
},
150150
"sampler": {
@@ -182,7 +182,7 @@ def get_vis_datas(
182182
validator = {sup_validator.name: sup_validator}
183183

184184
# set model
185-
model = ppsci.arch.AFNONet(input_keys, output_keys, num_timestamps=num_timestamps)
185+
model = ppsci.arch.AFNONet(input_keys, output_keys, num_timestamps=NUM_TIMESTAMPS)
186186

187187
# init optimizer and lr scheduler
188188
lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(
@@ -203,7 +203,6 @@ def get_vis_datas(
203203
EPOCHS,
204204
ITERS_PER_EPOCH,
205205
eval_during_train=True,
206-
log_freq=1,
207206
validator=validator,
208207
pretrained_model_path=PRETRAINED_MODEL_PATH,
209208
compute_metric_by_batch=True,
@@ -215,18 +214,18 @@ def get_vis_datas(
215214
solver.eval()
216215

217216
# set testing hyper-parameters
218-
num_timestamps = 32
219-
output_keys = tuple([f"output_{i}" for i in range(num_timestamps)])
217+
NUM_TIMESTAMPS = 32
218+
output_keys = tuple([f"output_{i}" for i in range(NUM_TIMESTAMPS)])
220219

221220
# set model for testing
222-
model = ppsci.arch.AFNONet(input_keys, output_keys, num_timestamps=num_timestamps)
221+
model = ppsci.arch.AFNONet(input_keys, output_keys, num_timestamps=NUM_TIMESTAMPS)
223222

224223
# update eval dataloader config
225224
eval_dataloader_cfg["dataset"].update(
226225
{
227226
"file_path": TEST_FILE_PATH,
228227
"label_keys": output_keys,
229-
"num_label_timestamps": num_timestamps,
228+
"num_label_timestamps": NUM_TIMESTAMPS,
230229
"stride": 8,
231230
}
232231
)
@@ -245,7 +244,7 @@ def get_vis_datas(
245244
vis_datas = get_vis_datas(
246245
TEST_FILE_PATH,
247246
DATE_STRINGS,
248-
num_timestamps,
247+
NUM_TIMESTAMPS,
249248
VARS_CHANNEL,
250249
IMG_H,
251250
data_mean,
@@ -257,16 +256,16 @@ def output_wind_func(d, var_name, data_mean, data_std):
257256
wind_data = []
258257
for i in range(output.shape[0]):
259258
wind_data.append((output[i][0] ** 2 + output[i][1] ** 2) ** 0.5)
260-
return paddle.to_tensor(wind_data)
259+
return paddle.to_tensor(wind_data, paddle.get_default_dtype())
261260

262261
vis_output_expr = {}
263-
for i in range(num_timestamps):
262+
for i in range(NUM_TIMESTAMPS):
264263
hour = (i + 1) * 6
265-
vis_output_expr[f"output_{hour}h"] = partial(
264+
vis_output_expr[f"output_{hour}h"] = functools.partial(
266265
output_wind_func,
267266
var_name=f"output_{i}",
268-
data_mean=paddle.to_tensor(data_mean),
269-
data_std=paddle.to_tensor(data_std),
267+
data_mean=paddle.to_tensor(data_mean, paddle.get_default_dtype()),
268+
data_std=paddle.to_tensor(data_std, paddle.get_default_dtype()),
270269
)
271270
vis_output_expr[f"target_{hour}h"] = lambda d, hour=hour: d[f"target_{hour}h"]
272271
# set visualizer
@@ -282,7 +281,7 @@ def output_wind_func(d, var_name, data_mean, data_std):
282281
vmax=25,
283282
colorbar_label="m\s",
284283
batch_size=1,
285-
num_timestamps=num_timestamps,
284+
num_timestamps=NUM_TIMESTAMPS,
286285
prefix="wind",
287286
)
288287
}
@@ -292,7 +291,6 @@ def output_wind_func(d, var_name, data_mean, data_std):
292291
solver = ppsci.solver.Solver(
293292
model,
294293
output_dir=OUTPUT_DIR,
295-
log_freq=1,
296294
validator=validator,
297295
visualizer=visualizer,
298296
pretrained_model_path=f"{OUTPUT_DIR}/checkpoints/latest",

examples/fourcastnet/train_precip.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from functools import partial
15+
import functools
1616
from typing import Tuple
1717

1818
import h5py
@@ -208,7 +208,6 @@ def get_vis_datas(
208208
EPOCHS,
209209
ITERS_PER_EPOCH,
210210
eval_during_train=True,
211-
log_freq=1,
212211
validator=validator,
213212
compute_metric_by_batch=True,
214213
eval_with_no_grad=True,
@@ -219,12 +218,12 @@ def get_vis_datas(
219218
solver.eval()
220219

221220
# set testing hyper-parameters
222-
num_timestamps = 6
223-
output_keys = tuple([f"output_{i}" for i in range(num_timestamps)])
221+
NUM_TIMESTAMPS = 6
222+
output_keys = tuple([f"output_{i}" for i in range(NUM_TIMESTAMPS)])
224223

225224
# set model for testing
226225
model = ppsci.arch.PrecipNet(
227-
input_keys, output_keys, num_timestamps=num_timestamps, wind_model=wind_model
226+
input_keys, output_keys, num_timestamps=NUM_TIMESTAMPS, wind_model=wind_model
228227
)
229228

230229
# update eval dataloader config
@@ -233,7 +232,7 @@ def get_vis_datas(
233232
"file_path": WIND_TEST_FILE_PATH,
234233
"label_keys": output_keys,
235234
"precip_file_path": TEST_FILE_PATH,
236-
"num_label_timestamps": num_timestamps,
235+
"num_label_timestamps": NUM_TIMESTAMPS,
237236
"stride": 8,
238237
}
239238
)
@@ -253,7 +252,7 @@ def get_vis_datas(
253252
WIND_TEST_FILE_PATH,
254253
TEST_FILE_PATH,
255254
DATE_STRINGS,
256-
num_timestamps,
255+
NUM_TIMESTAMPS,
257256
VARS_CHANNEL,
258257
IMG_H,
259258
wind_data_mean,
@@ -265,9 +264,9 @@ def output_precip_func(d, var_name):
265264
return output
266265

267266
visu_output_expr = {}
268-
for i in range(num_timestamps):
267+
for i in range(NUM_TIMESTAMPS):
269268
hour = (i + 1) * 6
270-
visu_output_expr[f"output_{hour}h"] = partial(
269+
visu_output_expr[f"output_{hour}h"] = functools.partial(
271270
output_precip_func,
272271
var_name=f"output_{i}",
273272
)
@@ -288,7 +287,7 @@ def output_precip_func(d, var_name):
288287
colorbar_label="mm",
289288
log_norm=True,
290289
batch_size=1,
291-
num_timestamps=num_timestamps,
290+
num_timestamps=NUM_TIMESTAMPS,
292291
prefix="precip",
293292
)
294293
}
@@ -298,7 +297,6 @@ def output_precip_func(d, var_name):
298297
solver = ppsci.solver.Solver(
299298
model,
300299
output_dir=OUTPUT_DIR,
301-
log_freq=1,
302300
validator=validator,
303301
visualizer=visualizer,
304302
pretrained_model_path=f"{OUTPUT_DIR}/checkpoints/latest",

examples/fourcastnet/train_pretrain.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@
182182
EPOCHS,
183183
ITERS_PER_EPOCH,
184184
eval_during_train=True,
185-
log_freq=1,
186185
validator=validator,
187186
compute_metric_by_batch=True,
188187
eval_with_no_grad=True,
@@ -198,7 +197,6 @@
198197
model,
199198
constraint,
200199
OUTPUT_DIR,
201-
log_freq=1,
202200
validator=validator,
203201
pretrained_model_path=f"{OUTPUT_DIR}/checkpoints/latest",
204202
compute_metric_by_batch=True,

ppsci/arch/afno.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -175,45 +175,31 @@ def __init__(
175175
self.hidden_size_factor = hidden_size_factor
176176
self.scale = scale
177177

178-
w1 = self.scale * paddle.randn(
179-
(
178+
self.w1 = self.create_parameter(
179+
shape=(
180180
2,
181181
self.num_blocks,
182182
self.block_size,
183183
self.block_size * self.hidden_size_factor,
184-
)
185-
)
186-
self.w1 = self.create_parameter(
187-
shape=w1.shape,
188-
dtype=w1.dtype,
189-
default_initializer=nn.initializer.Assign(w1),
190-
)
191-
b1 = self.scale * paddle.randn(
192-
(2, self.num_blocks, self.block_size * self.hidden_size_factor)
184+
),
185+
default_initializer=nn.initializer.Normal(std=self.scale),
193186
)
194187
self.b1 = self.create_parameter(
195-
shape=b1.shape,
196-
dtype=b1.dtype,
197-
default_initializer=nn.initializer.Assign(b1),
188+
shape=(2, self.num_blocks, self.block_size * self.hidden_size_factor),
189+
default_initializer=nn.initializer.Normal(std=self.scale),
198190
)
199-
w2 = self.scale * paddle.randn(
200-
(
191+
self.w2 = self.create_parameter(
192+
shape=(
201193
2,
202194
self.num_blocks,
203195
self.block_size * self.hidden_size_factor,
204196
self.block_size,
205-
)
206-
)
207-
self.w2 = self.create_parameter(
208-
shape=w2.shape,
209-
dtype=w2.dtype,
210-
default_initializer=paddle.nn.initializer.Assign(w2),
197+
),
198+
default_initializer=nn.initializer.Normal(std=self.scale),
211199
)
212-
b2 = self.scale * paddle.randn((2, self.num_blocks, self.block_size))
213200
self.b2 = self.create_parameter(
214-
shape=b2.shape,
215-
dtype=b2.dtype,
216-
default_initializer=paddle.nn.initializer.Assign(b2),
201+
shape=(2, self.num_blocks, self.block_size),
202+
default_initializer=nn.initializer.Normal(std=self.scale),
217203
)
218204

219205
def forward(self, x):

ppsci/data/dataset/era5_dataset.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def __init__(
8383

8484
self.files = self.read_data(file_path)
8585
self.n_years = len(self.files)
86-
self.n_samples_per_year = self.files[0].shape[0]
87-
self.n_samples_total = self.n_years * self.n_samples_per_year
86+
self.num_samples_per_year = self.files[0].shape[0]
87+
self.num_samples = self.n_years * self.num_samples_per_year
8888
if self.precip_file_path is not None:
8989
self.precip_files = self.read_data(precip_file_path, "tp")
9090

@@ -98,17 +98,17 @@ def read_data(self, path: str, var="fields"):
9898
return files
9999

100100
def __len__(self):
101-
return self.n_samples_total // self.stride
101+
return self.num_samples // self.stride
102102

103103
def __getitem__(self, global_idx):
104104
global_idx *= self.stride
105-
year_idx = global_idx // self.n_samples_per_year
106-
local_idx = global_idx % self.n_samples_per_year
107-
step = 0 if local_idx >= self.n_samples_per_year - 1 else 1
105+
year_idx = global_idx // self.num_samples_per_year
106+
local_idx = global_idx % self.num_samples_per_year
107+
step = 0 if local_idx >= self.num_samples_per_year - 1 else 1
108108

109109
if self.num_label_timestamps > 1:
110-
if local_idx >= self.n_samples_per_year - self.num_label_timestamps:
111-
local_idx = self.n_samples_per_year - self.num_label_timestamps - 1
110+
if local_idx >= self.num_samples_per_year - self.num_label_timestamps:
111+
local_idx = self.num_samples_per_year - self.num_label_timestamps - 1
112112

113113
input_file = self.files[year_idx]
114114
label_file = (
@@ -118,7 +118,7 @@ def __getitem__(self, global_idx):
118118
)
119119
if self.precip_file_path is not None and year_idx == 0 and self.training:
120120
# first year has 2 missing samples in precip (they are first two time points)
121-
lim = self.n_samples_per_year - 2
121+
lim = self.num_samples_per_year - 2
122122
local_idx = local_idx % lim
123123
step = 0 if local_idx >= lim - 1 else 1
124124
input_idx = local_idx + 2
@@ -154,9 +154,6 @@ def __getitem__(self, global_idx):
154154

155155
return input_item, label_item, weight_item
156156

157-
def getitem(self, global_idx):
158-
return self.__getitem__(global_idx)
159-
160157

161158
class ERA5SampledDataset(io.Dataset):
162159
"""Class for ERA5 sampled dataset.
@@ -198,7 +195,7 @@ def __init__(
198195
self.transforms = transforms
199196

200197
self.files = self.read_data(file_path)
201-
self.n_samples_total = len(self.files)
198+
self.num_samples = len(self.files)
202199

203200
def read_data(self, path: str):
204201
paths = glob.glob(path + "/*.h5")
@@ -210,7 +207,7 @@ def read_data(self, path: str):
210207
return files
211208

212209
def __len__(self):
213-
return self.n_samples_total
210+
return self.num_samples
214211

215212
def __getitem__(self, global_idx):
216213
_file = self.files[global_idx]

0 commit comments

Comments
 (0)