@@ -62,38 +62,38 @@ def get_vis_datas(
     dist.init_parallel_env()

     # set dataset path
-    train_file_path = "./datasets/era5/train"
-    valid_file_path = "./datasets/era5/test"
-    test_file_path = "./datasets/era5/out_of_sample/2018.h5"
-    data_mean_path = "./datasets/era5/stat/global_means.npy"
-    data_std_path = "./datasets/era5/stat/global_stds.npy"
-    data_time_mean_path = "./datasets/era5/stat/time_means.npy"
+    TRAIN_FILE_PATH = "./datasets/era5/train"
+    VALID_FILE_PATH = "./datasets/era5/test"
+    TEST_FILE_PATH = "./datasets/era5/out_of_sample/2018.h5"
+    DATA_MEAN_PATH = "./datasets/era5/stat/global_means.npy"
+    DATA_STD_PATH = "./datasets/era5/stat/global_stds.npy"
+    DATA_TIME_MEAN_PATH = "./datasets/era5/stat/time_means.npy"

     # set training hyper-parameters
     num_timestamps = 2
     input_keys = ("input",)
     output_keys = tuple([f"output_{i}" for i in range(num_timestamps)])
-    img_h, img_w = 720, 1440
-    epochs = 50 if not args.epochs else args.epochs
+    IMG_H, IMG_W = 720, 1440
+    EPOCHS = 50 if not args.epochs else args.epochs
     # FourCastNet uses 20 atmospheric variables; their indices in the dataset are 0 to 19.
     # The variable names are 'u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000', 'z000',
     # 'u850', 'v850', 'z850', 'u500', 'v500', 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv'.
     # You can obtain detailed information about each variable from
     # https://cds.climate.copernicus.eu/cdsapp#!/search?text=era5&type=dataset
-    vars_channel = list(range(20))
+    VARS_CHANNEL = list(range(20))
     # set output directory
-    output_dir = (
+    OUTPUT_DIR = (
         "./output/fourcastnet/finetune" if not args.output_dir else args.output_dir
     )
-    pretrained_model_path = "./output/fourcastnet/pretrain/checkpoints/latest"
+    PRETRAINED_MODEL_PATH = "./output/fourcastnet/pretrain/checkpoints/latest"
     # initialize logger
-    logger.init_logger("ppsci", f"{output_dir}/train.log", "info")
+    logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")

     data_mean, data_std = fourcast_utils.get_mean_std(
-        data_mean_path, data_std_path, vars_channel
+        DATA_MEAN_PATH, DATA_STD_PATH, VARS_CHANNEL
     )
     data_time_mean = fourcast_utils.get_time_mean(
-        data_time_mean_path, img_h, img_w, vars_channel
+        DATA_TIME_MEAN_PATH, IMG_H, IMG_W, VARS_CHANNEL
     )
     data_time_mean_normalize = np.expand_dims(
         (data_time_mean[0] - data_mean) / data_std, 0
@@ -102,17 +102,17 @@ def get_vis_datas(
     # set train transforms
     transforms = [
         {"SqueezeData": {}},
-        {"CropData": {"xmin": (0, 0), "xmax": (img_h, img_w)}},
+        {"CropData": {"xmin": (0, 0), "xmax": (IMG_H, IMG_W)}},
         {"Normalize": {"mean": data_mean, "std": data_std}},
     ]
     # set train dataloader config
     train_dataloader_cfg = {
         "dataset": {
             "name": "ERA5Dataset",
-            "file_path": train_file_path,
+            "file_path": TRAIN_FILE_PATH,
             "input_keys": input_keys,
             "label_keys": output_keys,
-            "vars_channel": vars_channel,
+            "vars_channel": VARS_CHANNEL,
             "num_label_timestamps": num_timestamps,
             "transforms": transforms,
         },
@@ -133,16 +133,16 @@ def get_vis_datas(
     constraint = {sup_constraint.name: sup_constraint}

     # set iters_per_epoch by dataloader length
-    iters_per_epoch = len(sup_constraint.data_loader)
+    ITERS_PER_EPOCH = len(sup_constraint.data_loader)

     # set eval dataloader config
     eval_dataloader_cfg = {
         "dataset": {
             "name": "ERA5Dataset",
-            "file_path": valid_file_path,
+            "file_path": VALID_FILE_PATH,
             "input_keys": input_keys,
             "label_keys": output_keys,
-            "vars_channel": vars_channel,
+            "vars_channel": VARS_CHANNEL,
             "transforms": transforms,
             "num_label_timestamps": num_timestamps,
             "training": False,
@@ -159,13 +159,13 @@ def get_vis_datas(
     metric = {
         "MAE": ppsci.metric.MAE(keep_batch=True),
         "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
-            num_lat=img_h,
+            num_lat=IMG_H,
             std=data_std,
             keep_batch=True,
             variable_dict={"u10": 0, "v10": 1},
         ),
         "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
-            num_lat=img_h,
+            num_lat=IMG_H,
             mean=data_time_mean_normalize,
             keep_batch=True,
             variable_dict={"u10": 0, "v10": 1},
@@ -186,8 +186,8 @@ def get_vis_datas(

     # init optimizer and lr scheduler
     lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(
-        epochs,
-        iters_per_epoch,
+        EPOCHS,
+        ITERS_PER_EPOCH,
         5e-4,
         by_epoch=True,
     )()
@@ -197,15 +197,15 @@ def get_vis_datas(
     solver = ppsci.solver.Solver(
         model,
         constraint,
-        output_dir,
+        OUTPUT_DIR,
         optimizer,
         lr_scheduler,
-        epochs,
-        iters_per_epoch,
+        EPOCHS,
+        ITERS_PER_EPOCH,
         eval_during_train=True,
         log_freq=1,
         validator=validator,
-        pretrained_model_path=pretrained_model_path,
+        pretrained_model_path=PRETRAINED_MODEL_PATH,
         compute_metric_by_batch=True,
         eval_with_no_grad=True,
     )
@@ -224,7 +224,7 @@ def get_vis_datas(
     # update eval dataloader config
     eval_dataloader_cfg["dataset"].update(
         {
-            "file_path": test_file_path,
+            "file_path": TEST_FILE_PATH,
             "label_keys": output_keys,
             "num_label_timestamps": num_timestamps,
             "stride": 8,
@@ -241,13 +241,13 @@ def get_vis_datas(
     validator = {sup_validator.name: sup_validator}

     # set visualizer datas
-    date_strings = ("2018-09-08 00:00:00",)
+    DATE_STRINGS = ("2018-09-08 00:00:00",)
     vis_datas = get_vis_datas(
-        test_file_path,
-        date_strings,
+        TEST_FILE_PATH,
+        DATE_STRINGS,
         num_timestamps,
-        vars_channel,
-        img_h,
+        VARS_CHANNEL,
+        IMG_H,
         data_mean,
         data_std,
     )
@@ -288,14 +288,14 @@ def output_wind_func(d, var_name, data_mean, data_std):
     }

     # directly evaluate pretrained model
-    logger.init_logger("ppsci", f"{output_dir}/eval.log", "info")
+    logger.init_logger("ppsci", f"{OUTPUT_DIR}/eval.log", "info")
     solver = ppsci.solver.Solver(
         model,
-        output_dir=output_dir,
+        output_dir=OUTPUT_DIR,
         log_freq=1,
         validator=validator,
         visualizer=visualizer,
-        pretrained_model_path=f"{output_dir}/checkpoints/latest",
+        pretrained_model_path=f"{OUTPUT_DIR}/checkpoints/latest",
         compute_metric_by_batch=True,
         eval_with_no_grad=True,
     )
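The whole patch is a naming cleanup: values that are assigned once at module level (dataset paths, grid size, epoch count, channel list) are renamed from lowercase to UPPER_CASE, the usual Python convention for constants. Below is a minimal, self-contained sketch of that convention; the names and values mirror the example above, but the crop_box helper is hypothetical and not part of the patch.

# Minimal sketch of the constant-naming convention this patch applies.
# Module-level values that are set once are written in UPPER_CASE; values
# produced at runtime keep ordinary lowercase names.

IMG_H, IMG_W = 720, 1440                   # fixed ERA5 grid size used above
VARS_CHANNEL = list(range(20))             # the 20 atmospheric channels
TRAIN_FILE_PATH = "./datasets/era5/train"  # dataset location resolved once


def crop_box(img_h: int, img_w: int) -> dict:
    """Hypothetical helper: build CropData-style bounds from the grid size."""
    return {"xmin": (0, 0), "xmax": (img_h, img_w)}


if __name__ == "__main__":
    # Constants are passed into functions like any other value.
    print(crop_box(IMG_H, IMG_W))  # {'xmin': (0, 0), 'xmax': (720, 1440)}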