
Commit 7baf04a

1. add train_finetune.py for fourcastnet;
2. add utils.py in fourcastnet example;
3. fix LatitudeWeightedACC and LatitudeWeightedRMSE;
4. fix save_plot_weather_from_dict;
1 parent 1c0221e commit 7baf04a
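For reference, the two metrics named in item 3 follow the standard latitude-weighted definitions from the FourCastNet paper. With per-row weights

$$w(j) = \frac{\cos(\mathrm{lat}_j)}{\frac{1}{N_{lat}}\sum_{j'=1}^{N_{lat}}\cos(\mathrm{lat}_{j'})},$$

the metrics over an $N_{lat} \times N_{lon}$ grid are

$$\mathrm{RMSE} = \sqrt{\frac{1}{N_{lat}N_{lon}}\sum_{j,k} w(j)\,\bigl(\hat{X}_{j,k} - X_{j,k}\bigr)^2}, \qquad \mathrm{ACC} = \frac{\sum_{j,k} w(j)\,\hat{X}'_{j,k}\,X'_{j,k}}{\sqrt{\sum_{j,k} w(j)\,\hat{X}'^{2}_{j,k}\;\sum_{j,k} w(j)\,X'^{2}_{j,k}}},$$

where $X' = X - C$ subtracts the climatology $C$ (the normalized time mean passed as mean to LatitudeWeightedACC in train_finetune.py below).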

File tree

9 files changed: +385 −42 lines changed


examples/fourcastnet/sample_data.py

Lines changed: 4 additions & 9 deletions

@@ -25,18 +25,11 @@
 from paddle.io import DistributedBatchSampler
 from tqdm import tqdm

+import examples.fourcastnet.utils as fourcast_utils
 import ppsci
 from ppsci.utils import logger


-def get_mean_std(mean_path: str, std_path: str, vars_channel: Tuple[int, ...]):
-    mean = np.load(mean_path).squeeze(0).astype(np.float32)
-    mean = mean[vars_channel]
-    std = np.load(std_path).squeeze(0).astype(np.float32)
-    std = std[vars_channel]
-    return mean, std
-
-
 def sample_func(
     dataset_cfg: Dict[str, Any], save_path: str, batch_idxs: Tuple[int, ...]
 ):
@@ -87,7 +80,9 @@ def sample_data_epoch(epoch: int):
     )
     os.makedirs(tmp_save_path, exist_ok=True)

-    data_mean, data_std = get_mean_std(data_mean_path, data_std_path, vars_channel)
+    data_mean, data_std = fourcast_utils.get_mean_std(
+        data_mean_path, data_std_path, vars_channel
+    )
     transforms = [
         {"SqueezeData": {}},
         {"CropData": {"xmin": (0, 0), "xmax": (img_h, img_w)}},
examples/fourcastnet/train_finetune.py

Lines changed: 304 additions & 0 deletions

@@ -0,0 +1,304 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Tuple

import h5py
import numpy as np
import paddle
import paddle.distributed as dist

import examples.fourcastnet.utils as fourcast_utils
import ppsci
from ppsci.utils import config
from ppsci.utils import logger


def get_vis_datas(
    file_path: str,
    date_strings: Tuple[str, ...],
    num_timestamps: int,
    vars_channel: Tuple[int, ...],
    img_h: int,
    data_mean: np.ndarray,
    data_std: np.ndarray,
):
    _file = h5py.File(file_path, "r")["fields"]
    data = []
    for date_str in date_strings:
        hours_since_jan_01_epoch = fourcast_utils.date_to_hours(date_str)
        ic = int(hours_since_jan_01_epoch / 6)  # ERA5 samples are 6-hourly
        data.append(_file[ic : ic + num_timestamps + 1, vars_channel, 0:img_h])
    data = np.asarray(data)

    vis_datas = {"input": (data[:, 0] - data_mean) / data_std}
    for t in range(num_timestamps):
        hour = (t + 1) * 6
        data_t = data[:, t + 1]
        wind_data = []
        for i in range(data_t.shape[0]):
            # channels 0/1 are u10/v10, so this is the 10m wind speed
            wind_data.append((data_t[i][0] ** 2 + data_t[i][1] ** 2) ** 0.5)
        vis_datas[f"target_{hour}h"] = np.asarray(wind_data)
    return vis_datas


if __name__ == "__main__":
    args = config.parse_args()
    # set random seed for reproducibility
    ppsci.utils.set_random_seed(1024)
    # initialize distributed environment
    dist.init_parallel_env()

    # set dataset path
    train_file_path = "./datasets/era5/train"
    valid_file_path = "./datasets/era5/test"
    test_file_path = "./datasets/era5/out_of_sample/2018.h5"
    data_mean_path = "./datasets/era5/stat/global_means.npy"
    data_std_path = "./datasets/era5/stat/global_stds.npy"
    data_time_mean_path = "./datasets/era5/stat/time_means.npy"

    # set training hyper-parameters
    num_timestamps = 2
    input_keys = ("input",)
    output_keys = tuple([f"output_{i}" for i in range(num_timestamps)])
    img_h, img_w = 720, 1440
    epochs = 50 if not args.epochs else args.epochs
    # FourCastNet uses 20 atmospheric variables; their indices in the dataset are 0 to 19.
    # The variable names are 'u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000',
    # 'z1000', 'u850', 'v850', 'z850', 'u500', 'v500', 'z500', 't500', 'z50', 'r500',
    # 'r850', 'tcwv'. Detailed information about each variable is available at
    # https://cds.climate.copernicus.eu/cdsapp#!/search?text=era5&type=dataset
    vars_channel = list(range(20))
    # set output directory
    output_dir = (
        "./output/fourcastnet/finetune" if not args.output_dir else args.output_dir
    )
    pretrained_model_path = "./output/fourcastnet/pretrain/checkpoints/latest"
    # initialize logger
    logger.init_logger("ppsci", f"{output_dir}/train.log", "info")

    data_mean, data_std = fourcast_utils.get_mean_std(
        data_mean_path, data_std_path, vars_channel
    )
    data_time_mean = fourcast_utils.get_time_mean(
        data_time_mean_path, img_h, img_w, vars_channel
    )
    # normalize the climatological mean for use by LatitudeWeightedACC
    data_time_mean_normalize = np.expand_dims(
        (data_time_mean[0] - data_mean) / data_std, 0
    )

    # set train transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (img_h, img_w)}},
        {"Normalize": {"mean": data_mean, "std": data_std}},
    ]
    # set train dataloader config
    train_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": train_file_path,
            "input_keys": input_keys,
            "label_keys": output_keys,
            "vars_channel": vars_channel,
            "num_label_timestamps": num_timestamps,
            "transforms": transforms,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": True,
            "shuffle": True,
        },
        "batch_size": 1,
        "num_workers": 8,
    }
    # set constraint
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        name="Sup",
    )
    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    iters_per_epoch = len(sup_constraint.data_loader)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": valid_file_path,
            "input_keys": input_keys,
            "label_keys": output_keys,
            "vars_channel": vars_channel,
            "transforms": transforms,
            "num_label_timestamps": num_timestamps,
            "training": False,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": 1,
    }

    # set metric; variable_dict maps the reported variable names to their channels
    metric = {
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
            num_lat=img_h,
            std=data_std,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
        "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
            num_lat=img_h,
            mean=data_time_mean_normalize,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric=metric,
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set model
    model = ppsci.arch.AFNONet(input_keys, output_keys, num_timestamps=num_timestamps)

    # init optimizer and lr scheduler
    lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(
        epochs,
        iters_per_epoch,
        5e-4,
        by_epoch=True,
    )()
    optimizer = ppsci.optimizer.Adam(lr_scheduler)((model,))

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        output_dir,
        optimizer,
        lr_scheduler,
        epochs,
        iters_per_epoch,
        eval_during_train=True,
        log_freq=1,
        validator=validator,
        pretrained_model_path=pretrained_model_path,
        compute_metric_by_batch=True,
        eval_with_no_grad=True,
    )
    # train model
    solver.train()
    # evaluate after training is finished
    solver.eval()

    # set testing hyper-parameters
    num_timestamps = 32
    output_keys = tuple([f"output_{i}" for i in range(num_timestamps)])

    # set model for testing
    model = ppsci.arch.AFNONet(input_keys, output_keys, num_timestamps=num_timestamps)

    # update eval dataloader config
    eval_dataloader_cfg["dataset"].update(
        {
            "file_path": test_file_path,
            "label_keys": output_keys,
            "num_label_timestamps": num_timestamps,
            "stride": 8,
        }
    )

    # set validator for testing
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric=metric,
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set visualizer data
    date_strings = ("2018-09-08 00:00:00",)
    vis_datas = get_vis_datas(
        test_file_path,
        date_strings,
        num_timestamps,
        vars_channel,
        img_h,
        data_mean,
        data_std,
    )

    def output_wind_func(d, var_name, data_mean, data_std):
        # de-normalize the prediction, then compute 10m wind speed from u10/v10
        output = (d[var_name] * data_std) + data_mean
        wind_data = []
        for i in range(output.shape[0]):
            wind_data.append((output[i][0] ** 2 + output[i][1] ** 2) ** 0.5)
        return paddle.to_tensor(wind_data)

    vis_output_expr = {}
    for i in range(num_timestamps):
        hour = (i + 1) * 6
        vis_output_expr[f"output_{hour}h"] = partial(
            output_wind_func,
            var_name=f"output_{i}",
            data_mean=paddle.to_tensor(data_mean),
            data_std=paddle.to_tensor(data_std),
        )
        # the hour=hour default argument binds the loop value into each lambda
        vis_output_expr[f"target_{hour}h"] = lambda d, hour=hour: d[f"target_{hour}h"]
    # set visualizer
    visualizer = {
        "visualize_wind": ppsci.visualize.VisualizerWeather(
            vis_datas,
            vis_output_expr,
            xticks=np.linspace(0, 1439, 13),
            xticklabels=[str(i) for i in range(360, -1, -30)],
            yticks=np.linspace(0, 719, 7),
            yticklabels=[str(i) for i in range(90, -91, -30)],
            vmin=0,
            vmax=25,
            colorbar_label="m/s",
            batch_size=1,
            num_timestamps=num_timestamps,
            prefix="wind",
        )
    }

    # directly evaluate the fine-tuned checkpoint
    logger.init_logger("ppsci", f"{output_dir}/eval.log", "info")
    solver = ppsci.solver.Solver(
        model,
        output_dir=output_dir,
        log_freq=1,
        validator=validator,
        visualizer=visualizer,
        pretrained_model_path=f"{output_dir}/checkpoints/latest",
        compute_metric_by_batch=True,
        eval_with_no_grad=True,
    )
    solver.eval()
    # visualize predictions from the loaded checkpoint
    solver.visualize()
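A plausible way to run the new script, assuming the PaddleScience repository root as the working directory (needed for the examples.fourcastnet.utils import to resolve); the multi-GPU form uses PaddlePaddle's stock distributed launcher, matching the dist.init_parallel_env() call above:

# single GPU; --epochs and --output_dir are optional (read via config.parse_args)
python examples/fourcastnet/train_finetune.py

# e.g. 4 GPUs
python -m paddle.distributed.launch --gpus="0,1,2,3" examples/fourcastnet/train_finetune.py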
