Commit a0f91b1

change constants in the fourcastnet examples to uppercase
1 parent 0e093c5 commit a0f91b1
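
Note: the rename follows the PEP 8 convention that constants defined at script or module scope are written in UPPER_CASE. Python does not enforce immutability, so this is purely a readability signal for readers and linters; a minimal sketch using names from this commit:

# Before: constant values named like ordinary variables.
img_h, img_w = 720, 1440

# After: UPPER_CASE marks them as constants by convention only.
IMG_H, IMG_W = 720, 1440

print(f"grid size: {IMG_H} x {IMG_W}")  # grid size: 720 x 1440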

4 files changed: +148, -149 lines


examples/fourcastnet/sample_data.py

Lines changed: 32 additions & 32 deletions
@@ -50,51 +50,51 @@ def sample_data_epoch(epoch: int):
     # initialize logger
     logger.init_logger("ppsci")
     # set dataset path and save path
-    train_file_path = "./datasets/era5/train"
-    precip_file_path = None
-    data_mean_path = "./datasets/era5/stat/global_means.npy"
-    data_std_path = "./datasets/era5/stat/global_stds.npy"
-    tmp_save_path = "./datasets/era5/train_split_rank0/epoch_tmp"
+    TRAIN_FILE_PATH = "./datasets/era5/train"
+    PRECIP_FILE_PATH = None
+    DATA_MEAN_PATH = "./datasets/era5/stat/global_means.npy"
+    DATA_STD_PATH = "./datasets/era5/stat/global_stds.npy"
+    TMP_SAVE_PATH = "./datasets/era5/train_split_rank0/epoch_tmp"
     save_path = f"./datasets/era5/train_split_rank0/epoch_{epoch}"
     # set hyper-parameters
-    input_keys = ["input"]
-    output_keys = ["output"]
-    img_h, img_w = 720, 1440
+    input_keys = ("input",)
+    output_keys = ("output",)
+    IMG_H, IMG_W = 720, 1440
     # FourCastNet use 20 atmospheric variable,their index in the dataset is from 0 to 19.
     # The variable name is 'u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000', 'z000',
     # 'u850', 'v850', 'z850', 'u500', 'v500', 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv'.
     # You can obtain detailed information about each variable from
     # https://cds.climate.copernicus.eu/cdsapp#!/search?text=era5&type=dataset
-    vars_channel = [i for i in range(20)]
-    num_trainer = 1
-    rank = 0
-    processes = 16
+    VARS_CHANNEL = list(range(20))
+    NUM_TRAINER = 1
+    RANK = 0
+    PROCESSES = 16
 
-    if len(glob.glob(tmp_save_path + "/*.h5")):
+    if len(glob.glob(TMP_SAVE_PATH + "/*.h5")):
         raise FileExistsError(
-            f"tmp_save_path({tmp_save_path}) is not an empty folder, please specify an empty folder."
+            f"TMP_SAVE_PATH({TMP_SAVE_PATH}) is not an empty folder, please specify an empty folder."
         )
     if len(glob.glob(save_path + "/*.h5")):
         raise FileExistsError(
             f"save_path({save_path}) is not an empty folder, please specify an empty folder."
         )
-    os.makedirs(tmp_save_path, exist_ok=True)
+    os.makedirs(TMP_SAVE_PATH, exist_ok=True)
 
     data_mean, data_std = fourcast_utils.get_mean_std(
-        data_mean_path, data_std_path, vars_channel
+        DATA_MEAN_PATH, DATA_STD_PATH, VARS_CHANNEL
     )
     transforms = [
         {"SqueezeData": {}},
-        {"CropData": {"xmin": (0, 0), "xmax": (img_h, img_w)}},
+        {"CropData": {"xmin": (0, 0), "xmax": (IMG_H, IMG_W)}},
         {"Normalize": {"mean": data_mean, "std": data_std}},
     ]
     dataset_cfg = {
         "name": "ERA5Dataset",
-        "file_path": train_file_path,
+        "file_path": TRAIN_FILE_PATH,
         "input_keys": input_keys,
         "label_keys": output_keys,
-        "precip_file_path": precip_file_path,
-        "vars_channel": vars_channel,
+        "PRECIP_FILE_PATH": PRECIP_FILE_PATH,
+        "vars_channel": VARS_CHANNEL,
         "transforms": transforms,
     }
     dataset = ppsci.data.dataset.build_dataset(dataset_cfg)
@@ -103,35 +103,35 @@ def sample_data_epoch(epoch: int):
         dataset=dataset,
         batch_size=1,
         shuffle=False,
-        num_replicas=num_trainer,
-        rank=rank,
+        num_replicas=NUM_TRAINER,
+        rank=RANK,
     )
     batch_sampler.set_epoch(epoch)
     batch_idxs = []
     for data in tqdm(batch_sampler):
         batch_idxs += data
 
-    pool = Pool(processes=processes)
-    for st in range(0, len(batch_idxs), len(batch_idxs) // (processes - 1)):
-        end = st + len(batch_idxs) // (processes - 1)
+    pool = Pool(processes=PROCESSES)
+    for st in range(0, len(batch_idxs), len(batch_idxs) // (PROCESSES - 1)):
+        end = st + len(batch_idxs) // (PROCESSES - 1)
         result = pool.apply_async(
-            sample_func, (dataset_cfg, tmp_save_path, batch_idxs[st:end])
+            sample_func, (dataset_cfg, TMP_SAVE_PATH, batch_idxs[st:end])
         )
     pool.close()
     pool.join()
     if result.successful():
         logger.info("successful")
-    shutil.move(tmp_save_path, save_path)
-    logger.info(f"move {tmp_save_path} to {save_path}")
+    shutil.move(TMP_SAVE_PATH, save_path)
+    logger.info(f"move {TMP_SAVE_PATH} to {save_path}")
 
 
 def main():
-    epoch = 0
-    sample_data_epoch(epoch)
+    EPOCHS = 0
+    sample_data_epoch(EPOCHS)
 
     # if you want to sample every 5 epochs, you can use the following code
-    # epoch = 150
-    # for i in range(0, epoch, 5):
+    # EPOCHS = 150
+    # for epoch in range(0, EPOCHS, 5):
     #     sample_data_epoch(epoch)
 
 
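For context, the loop above splits the shuffled batch indices into chunks of len(batch_idxs) // (PROCESSES - 1) and submits each chunk to a worker pool with apply_async. A self-contained sketch of that chunking pattern, with a hypothetical do_work standing in for sample_func:

from multiprocessing import Pool

PROCESSES = 4  # worker count, mirroring the constant above

def do_work(idxs):  # hypothetical stand-in for sample_func
    return sum(idxs)

if __name__ == "__main__":
    batch_idxs = list(range(100))
    chunk = len(batch_idxs) // (PROCESSES - 1)  # same arithmetic as the script
    pool = Pool(processes=PROCESSES)
    results = [
        pool.apply_async(do_work, (batch_idxs[st : st + chunk],))
        for st in range(0, len(batch_idxs), chunk)
    ]
    pool.close()
    pool.join()
    print([r.get() for r in results])  # one partial result per chunk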

examples/fourcastnet/train_finetune.py

Lines changed: 37 additions & 37 deletions
@@ -62,38 +62,38 @@ def get_vis_datas(
     dist.init_parallel_env()
 
     # set dataset path
-    train_file_path = "./datasets/era5/train"
-    valid_file_path = "./datasets/era5/test"
-    test_file_path = "./datasets/era5/out_of_sample/2018.h5"
-    data_mean_path = "./datasets/era5/stat/global_means.npy"
-    data_std_path = "./datasets/era5/stat/global_stds.npy"
-    data_time_mean_path = "./datasets/era5/stat/time_means.npy"
+    TRAIN_FILE_PATH = "./datasets/era5/train"
+    VALID_FILE_PATH = "./datasets/era5/test"
+    TEST_FILE_PATH = "./datasets/era5/out_of_sample/2018.h5"
+    DATA_MEAN_PATH = "./datasets/era5/stat/global_means.npy"
+    DATA_STD_PATH = "./datasets/era5/stat/global_stds.npy"
+    DATA_TIME_MEAN_PATH = "./datasets/era5/stat/time_means.npy"
 
     # set training hyper-parameters
     num_timestamps = 2
     input_keys = ("input",)
     output_keys = tuple([f"output_{i}" for i in range(num_timestamps)])
-    img_h, img_w = 720, 1440
-    epochs = 50 if not args.epochs else args.epochs
+    IMG_H, IMG_W = 720, 1440
+    EPOCHS = 50 if not args.epochs else args.epochs
     # FourCastNet use 20 atmospheric variable,their index in the dataset is from 0 to 19.
     # The variable name is 'u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000', 'z000',
     # 'u850', 'v850', 'z850', 'u500', 'v500', 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv'.
     # You can obtain detailed information about each variable from
     # https://cds.climate.copernicus.eu/cdsapp#!/search?text=era5&type=dataset
-    vars_channel = list(range(20))
+    VARS_CHANNEL = list(range(20))
     # set output directory
-    output_dir = (
+    OUTPUT_DIR = (
         "./output/fourcastnet/finetune" if not args.output_dir else args.output_dir
     )
-    pretrained_model_path = "./output/fourcastnet/pretrain/checkpoints/latest"
+    PRETRAINED_MODEL_PATH = "./output/fourcastnet/pretrain/checkpoints/latest"
     # initialize logger
-    logger.init_logger("ppsci", f"{output_dir}/train.log", "info")
+    logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")
 
     data_mean, data_std = fourcast_utils.get_mean_std(
-        data_mean_path, data_std_path, vars_channel
+        DATA_MEAN_PATH, DATA_STD_PATH, VARS_CHANNEL
     )
     data_time_mean = fourcast_utils.get_time_mean(
-        data_time_mean_path, img_h, img_w, vars_channel
+        DATA_TIME_MEAN_PATH, IMG_H, IMG_W, VARS_CHANNEL
     )
     data_time_mean_normalize = np.expand_dims(
         (data_time_mean[0] - data_mean) / data_std, 0
@@ -102,17 +102,17 @@ def get_vis_datas(
     # set train transforms
     transforms = [
         {"SqueezeData": {}},
-        {"CropData": {"xmin": (0, 0), "xmax": (img_h, img_w)}},
+        {"CropData": {"xmin": (0, 0), "xmax": (IMG_H, IMG_W)}},
         {"Normalize": {"mean": data_mean, "std": data_std}},
     ]
     # set train dataloader config
     train_dataloader_cfg = {
         "dataset": {
             "name": "ERA5Dataset",
-            "file_path": train_file_path,
+            "file_path": TRAIN_FILE_PATH,
             "input_keys": input_keys,
             "label_keys": output_keys,
-            "vars_channel": vars_channel,
+            "vars_channel": VARS_CHANNEL,
             "num_label_timestamps": num_timestamps,
             "transforms": transforms,
         },
@@ -133,16 +133,16 @@ def get_vis_datas(
     constraint = {sup_constraint.name: sup_constraint}
 
     # set iters_per_epoch by dataloader length
-    iters_per_epoch = len(sup_constraint.data_loader)
+    ITERS_PER_EPOCH = len(sup_constraint.data_loader)
 
     # set eval dataloader config
     eval_dataloader_cfg = {
         "dataset": {
             "name": "ERA5Dataset",
-            "file_path": valid_file_path,
+            "file_path": VALID_FILE_PATH,
             "input_keys": input_keys,
             "label_keys": output_keys,
-            "vars_channel": vars_channel,
+            "vars_channel": VARS_CHANNEL,
             "transforms": transforms,
             "num_label_timestamps": num_timestamps,
             "training": False,
@@ -159,13 +159,13 @@ def get_vis_datas(
     metric = {
         "MAE": ppsci.metric.MAE(keep_batch=True),
         "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
-            num_lat=img_h,
+            num_lat=IMG_H,
             std=data_std,
             keep_batch=True,
             variable_dict={"u10": 0, "v10": 1},
         ),
         "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
-            num_lat=img_h,
+            num_lat=IMG_H,
             mean=data_time_mean_normalize,
             keep_batch=True,
             variable_dict={"u10": 0, "v10": 1},
@@ -186,8 +186,8 @@ def get_vis_datas(
 
     # init optimizer and lr scheduler
     lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(
-        epochs,
-        iters_per_epoch,
+        EPOCHS,
+        ITERS_PER_EPOCH,
         5e-4,
         by_epoch=True,
     )()
@@ -197,15 +197,15 @@ def get_vis_datas(
     solver = ppsci.solver.Solver(
         model,
         constraint,
-        output_dir,
+        OUTPUT_DIR,
         optimizer,
         lr_scheduler,
-        epochs,
-        iters_per_epoch,
+        EPOCHS,
+        ITERS_PER_EPOCH,
         eval_during_train=True,
         log_freq=1,
         validator=validator,
-        pretrained_model_path=pretrained_model_path,
+        pretrained_model_path=PRETRAINED_MODEL_PATH,
         compute_metric_by_batch=True,
         eval_with_no_grad=True,
     )
@@ -224,7 +224,7 @@ def get_vis_datas(
     # update eval dataloader config
     eval_dataloader_cfg["dataset"].update(
         {
-            "file_path": test_file_path,
+            "file_path": TEST_FILE_PATH,
             "label_keys": output_keys,
             "num_label_timestamps": num_timestamps,
             "stride": 8,
@@ -241,13 +241,13 @@ def get_vis_datas(
     validator = {sup_validator.name: sup_validator}
 
     # set visualizer datas
-    date_strings = ("2018-09-08 00:00:00",)
+    DATE_STRINGS = ("2018-09-08 00:00:00",)
     vis_datas = get_vis_datas(
-        test_file_path,
-        date_strings,
+        TEST_FILE_PATH,
+        DATE_STRINGS,
         num_timestamps,
-        vars_channel,
-        img_h,
+        VARS_CHANNEL,
+        IMG_H,
         data_mean,
         data_std,
     )
@@ -288,14 +288,14 @@ def output_wind_func(d, var_name, data_mean, data_std):
     }
 
     # directly evaluate pretrained model
-    logger.init_logger("ppsci", f"{output_dir}/eval.log", "info")
+    logger.init_logger("ppsci", f"{OUTPUT_DIR}/eval.log", "info")
     solver = ppsci.solver.Solver(
         model,
-        output_dir=output_dir,
+        output_dir=OUTPUT_DIR,
         log_freq=1,
         validator=validator,
         visualizer=visualizer,
-        pretrained_model_path=f"{output_dir}/checkpoints/latest",
+        pretrained_model_path=f"{OUTPUT_DIR}/checkpoints/latest",
         compute_metric_by_batch=True,
         eval_with_no_grad=True,
     )
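One pattern worth noting in this file: several of the UPPER_CASE names stay overridable from the command line, falling back to a default only when the corresponding flag is absent. A minimal sketch of that idiom, under the assumption of argparse flags named like the ones referenced in the diff (--epochs, --output_dir):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=None)
parser.add_argument("--output_dir", type=str, default=None)
args = parser.parse_args()

# Same idiom as the diff: a constant-style name whose value can be
# overridden by a CLI flag (a None/falsy arg falls back to the default).
EPOCHS = 50 if not args.epochs else args.epochs
OUTPUT_DIR = "./output/fourcastnet/finetune" if not args.output_dir else args.output_dir

print(EPOCHS, OUTPUT_DIR)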
