Commit 567881d

1. Add PrecipNet;
2. Add train_precip.py for FourCastNet
1 parent 7baf04a commit 567881d

3 files changed (+438, −0)


docs/zh/api/arch.md

Lines changed: 1 addition & 0 deletions
@@ -11,5 +11,6 @@
     - PhysformerGPT2
     - ModelList
     - AFNONet
+    - PrecipNet
         show_root_heading: false
         heading_level: 3

examples/fourcastnet/train_precip.py

Lines changed: 311 additions & 0 deletions
@@ -0,0 +1,311 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Tuple

import h5py
import numpy as np
import paddle
import paddle.distributed as dist

import examples.fourcastnet.utils as fourcast_utils
import ppsci
from ppsci.utils import config
from ppsci.utils import logger

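
# get_vis_datas builds the visualization batch: for each start date it reads
# the selected wind channels at the initial condition plus the following
# `num_timestamps` total-precipitation ("tp") frames, normalizes the wind
# input with the training mean/std, and keys the raw precipitation targets as
# "target_{6,12,...}h". The ERA5 samples here are 6-hourly, hence ic = hours / 6.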
def get_vis_datas(
    wind_file_path: str,
    file_path: str,
    date_strings: Tuple[str, ...],
    num_timestamps: int,
    vars_channel: Tuple[int, ...],
    img_h: int,
    data_mean: np.ndarray,
    data_std: np.ndarray,
):
    __wind_file = h5py.File(wind_file_path, "r")["fields"]
    _file = h5py.File(file_path, "r")["tp"]
    wind_data = []
    data = []
    for date_str in date_strings:
        hours_since_jan_01_epoch = fourcast_utils.date_to_hours(date_str)
        ic = int(hours_since_jan_01_epoch / 6)
        wind_data.append(__wind_file[ic, vars_channel, 0:img_h])
        data.append(_file[ic + 1 : ic + num_timestamps + 1, 0:img_h])
    wind_data = np.asarray(wind_data)
    data = np.asarray(data)

    vis_datas = {"input": (wind_data - data_mean) / data_std}
    for t in range(num_timestamps):
        hour = (t + 1) * 6
        data_t = data[:, t]
        vis_datas[f"target_{hour}h"] = np.asarray(data_t)
    return vis_datas


if __name__ == "__main__":
    args = config.parse_args()
    # set random seed for reproducibility
    ppsci.utils.set_random_seed(1024)
    # initialize distributed environment
    dist.init_parallel_env()

    # set wind dataset path
    wind_train_file_path = "./datasets/era5/train"
    wind_valid_file_path = "./datasets/era5/test"
    wind_test_file_path = "./datasets/era5/out_of_sample/2018.h5"
    wind_mean_path = "./datasets/era5/stat/global_means.npy"
    wind_std_path = "./datasets/era5/stat/global_stds.npy"
    wind_time_mean_path = "./datasets/era5/stat/time_means.npy"
    # set precipitation dataset path
    train_file_path = "./datasets/era5/precip/train"
    valid_file_path = "./datasets/era5/precip/test"
    test_file_path = "./datasets/era5/precip/out_of_sample/2018.h5"
    time_mean_path = "./datasets/era5/stat/precip/time_means.npy"

    # set training hyper-parameters
    input_keys = ("input",)
    output_keys = ("output",)
    img_h, img_w = 720, 1440
    epochs = 25 if not args.epochs else args.epochs
    # FourCastNet uses 20 atmospheric variables; their indices in the dataset run
    # from 0 to 19. The variable names are 'u10', 'v10', 't2m', 'sp', 'msl',
    # 't850', 'u1000', 'v1000', 'z1000', 'u850', 'v850', 'z850', 'u500', 'v500',
    # 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv'. Detailed information about
    # each variable is available at
    # https://cds.climate.copernicus.eu/cdsapp#!/search?text=era5&type=dataset
    vars_channel = list(range(20))
    # set output directory
    output_dir = (
        "./output/fourcastnet/precip" if not args.output_dir else args.output_dir
    )
    wind_model_path = "./output/fourcastnet/finetune/checkpoints/latest"
    # checkpoint reloaded by the final evaluation stage below
    pretrained_model_path = f"{output_dir}/checkpoints/latest"
    # initialize logger
    logger.init_logger("ppsci", f"{output_dir}/train.log", "info")

    wind_data_mean, wind_data_std = fourcast_utils.get_mean_std(
        wind_mean_path, wind_std_path, vars_channel
    )
    data_time_mean = fourcast_utils.get_time_mean(time_mean_path, img_h, img_w)
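
    # The Log1p transform below log-compresses the heavy-tailed precipitation
    # labels; judging from the expm1 / 1e-2 inversion in output_precip_func
    # further down, it presumably stores log1p(tp / scale) with scale = 1e-5.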
    # set train transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (img_h, img_w)}},
        {
            "Normalize": {
                "mean": wind_data_mean,
                "std": wind_data_std,
                "apply_keys": ("input",),
            }
        },
        {"Log1p": {"scale": 1e-5, "apply_keys": ("label",)}},
    ]
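
    # ERA5Dataset pairs each wind sample ("input") with a precipitation field
    # read from precip_file_path ("label"); consistent with get_vis_datas
    # above, the label is presumably the tp frame one 6-hour step ahead.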
    # set train dataloader config
    train_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": wind_train_file_path,
            "input_keys": input_keys,
            "label_keys": output_keys,
            "vars_channel": vars_channel,
            "precip_file_path": train_file_path,
            "transforms": transforms,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": True,
            "shuffle": True,
        },
        "batch_size": 1,
        "num_workers": 8,
    }
    # set constraint
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        name="Sup",
    )
    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    iters_per_epoch = len(sup_constraint.data_loader)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": wind_valid_file_path,
            "input_keys": input_keys,
            "label_keys": output_keys,
            "vars_channel": vars_channel,
            "precip_file_path": valid_file_path,
            "transforms": transforms,
            "training": False,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": 1,
    }
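
    # unlog=True makes the latitude-weighted metrics invert the Log1p label
    # transform before comparing, so RMSE/ACC are computed on physical
    # precipitation values; keep_batch=True reports per-sample results.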
    # set metric
    metric = {
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
            num_lat=img_h, keep_batch=True, unlog=True
        ),
        "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
            num_lat=img_h, mean=data_time_mean, keep_batch=True, unlog=True
        ),
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric=metric,
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}
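
    # PrecipNet composes the pretrained AFNONet wind model with a precipitation
    # head: the wind model advances the atmospheric state one step and the head
    # maps that state to total precipitation. In the FourCastNet paper the wind
    # backbone stays frozen while the precipitation head is trained.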
    # set model
    wind_model = ppsci.arch.AFNONet(input_keys, output_keys)
    ppsci.utils.save_load.load_pretrain(wind_model, path=wind_model_path)
    model = ppsci.arch.PrecipNet(input_keys, output_keys, wind_model=wind_model)
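
    # cosine-anneal the learning rate from the initial 5e-4 over the whole run,
    # stepping the schedule once per epoch (by_epoch=True)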
    # init optimizer and lr scheduler
    lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(
        epochs,
        iters_per_epoch,
        5e-4,
        by_epoch=True,
    )()
    optimizer = ppsci.optimizer.Adam(lr_scheduler)((model,))
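
    # eval_during_train runs the validator periodically during training;
    # compute_metric_by_batch and eval_with_no_grad keep evaluation memory
    # bounded on the 720 x 1440 grid (metrics accumulated batch by batch,
    # no autograd graph built during eval)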
    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        output_dir,
        optimizer,
        lr_scheduler,
        epochs,
        iters_per_epoch,
        eval_during_train=True,
        log_freq=1,
        validator=validator,
        compute_metric_by_batch=True,
        eval_with_no_grad=True,
    )
    # train model
    solver.train()
    # evaluate after training finishes
    solver.eval()
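
    # Testing rolls the model out autoregressively: 6 steps of 6 hours each,
    # i.e. a 36-hour precipitation forecast with one output key per step.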
    # set testing hyper-parameters
    num_timestamps = 6
    output_keys = tuple(f"output_{i}" for i in range(num_timestamps))

    # set model for testing
    model = ppsci.arch.PrecipNet(
        input_keys, output_keys, num_timestamps=num_timestamps, wind_model=wind_model
    )
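
    # stride=8 spaces the test initial conditions 8 samples apart (48 hours at
    # the 6-hourly cadence), so successive rollouts start from well-separated
    # states.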
    # update eval dataloader config
    eval_dataloader_cfg["dataset"].update(
        {
            "file_path": wind_test_file_path,
            "label_keys": output_keys,
            "precip_file_path": test_file_path,
            "num_label_timestamps": num_timestamps,
            "stride": 8,
        }
    )

    # set validator for testing
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric=metric,
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set visualizer data
    date_strings = ("2018-04-04 00:00:00",)
    vis_datas = get_vis_datas(
        wind_test_file_path,
        test_file_path,
        date_strings,
        num_timestamps,
        vars_channel,
        img_h,
        wind_data_mean,
        wind_data_std,
    )
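
    # expm1 inverts the Log1p label transform; with tp stored in metres and
    # scale = 1e-5, the extra 1e-2 factor yields millimetres (1e5 * 1e-2 = 1e3),
    # matching the targets below, which convert metres to mm via * 1000.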
    def output_precip_func(d, var_name):
        output = 1e-2 * paddle.expm1(d[var_name][0])
        return output

    visu_output_expr = {}
    for i in range(num_timestamps):
        hour = (i + 1) * 6
        visu_output_expr[f"output_{hour}h"] = partial(
            output_precip_func,
            var_name=f"output_{i}",
        )
        visu_output_expr[f"target_{hour}h"] = (
            lambda d, hour=hour: d[f"target_{hour}h"] * 1000
        )
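
    # The tick labels map the 1440 x 720 grid to longitude (360 to 0 degrees in
    # 30-degree steps) and latitude (90 to -90); log_norm renders the wide
    # dynamic range 0.001-130 mm on a logarithmic color scale.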
    # set visualizer
    visualizer = {
        "visualize_precip": ppsci.visualize.VisualizerWeather(
            vis_datas,
            visu_output_expr,
            xticks=np.linspace(0, 1439, 13),
            xticklabels=[str(i) for i in range(360, -1, -30)],
            yticks=np.linspace(0, 719, 7),
            yticklabels=[str(i) for i in range(90, -91, -30)],
            vmin=0.001,
            vmax=130,
            colorbar_label="mm",
            log_norm=True,
            batch_size=1,
            num_timestamps=num_timestamps,
            prefix="precip",
        )
    }
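
    # The Solver below carries no constraint or optimizer: it only reloads the
    # weights at pretrained_model_path for evaluation and visualization.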
    # directly evaluate pretrained model
    logger.init_logger("ppsci", f"{output_dir}/eval.log", "info")
    solver = ppsci.solver.Solver(
        model,
        output_dir=output_dir,
        log_freq=1,
        validator=validator,
        visualizer=visualizer,
        pretrained_model_path=pretrained_model_path,
        compute_metric_by_batch=True,
        eval_with_no_grad=True,
    )
    solver.eval()
    # visualize prediction from pretrained_model_path
    solver.visualize()
