
Commit eb294d9

Merge pull request #288 from zhiminzhang0830/fourcast
add FourCastNet
2 parents 8d0ca82 + 474acc6 · commit eb294d9


42 files changed: +2824, -34 lines

docs/zh/api/arch.md
Lines changed: 2 additions & 0 deletions
@@ -10,5 +10,7 @@
- CylinderEmbedding
- PhysformerGPT2
- ModelList
+- AFNONet
+- PrecipNet
show_root_heading: false
heading_level: 3
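Both new models live under ppsci.arch. A minimal, hypothetical construction sketch follows; the positional input/output keys match the FourCastNet pipeline in this commit, while relying on constructor defaults for image size and channel counts is an assumption:

import ppsci

# hypothetical minimal sketch: build the AFNO backbone added by this commit,
# leaving image size and channel counts at their defaults (an assumption)
model = ppsci.arch.AFNONet(("input",), ("output",))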

docs/zh/api/data/dataset.md
Lines changed: 2 additions & 0 deletions
@@ -11,4 +11,6 @@
- RosslerDataset
- CSVDataset
- MatDataset
+- ERA5Dataset
+- ERA5SampledDataset
show_root_heading: false
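The new ERA5 datasets are built through the config-driven factory, exactly as the sampling script later in this commit does. A minimal sketch, with a placeholder path for wherever the ERA5 HDF5 files live:

import ppsci

# mirrors the dataset_cfg in examples/fourcastnet/sample_data.py below;
# the file path is a placeholder
dataset_cfg = {
    "name": "ERA5Dataset",
    "file_path": "./datasets/era5/train",
    "input_keys": ("input",),
    "label_keys": ("output",),
    "vars_channel": list(range(20)),
}
dataset = ppsci.data.dataset.build_dataset(dataset_cfg)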

docs/zh/api/data/process/transform.md
Lines changed: 4 additions & 0 deletions
@@ -6,4 +6,8 @@
members:
- Translate
- Scale
+- Normalize
+- Log1p
+- CropData
+- SqueezeData
show_root_heading: false
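Three of these transforms are chained by the FourCastNet data pipeline. A config-style sketch mirroring the sampling script below; the mean/std arrays here are dummy stand-ins for the real ERA5 channel statistics, and the (20, 1, 1) shape is an assumption:

import numpy as np

# dummy per-channel statistics standing in for the real ERA5 means/stds
data_mean = np.zeros((20, 1, 1), dtype="float32")
data_std = np.ones((20, 1, 1), dtype="float32")

transforms = [
    {"SqueezeData": {}},                                  # drop singleton leading axes
    {"CropData": {"xmin": (0, 0), "xmax": (720, 1440)}},  # crop to the 720 x 1440 grid
    {"Normalize": {"mean": data_mean, "std": data_std}},  # per-channel standardization
]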

docs/zh/api/loss.md
Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
- Loss
- L1Loss
- L2Loss
+- L2RelLoss
- MSELoss
- MSELossWithL2Decay
- IntegralLoss
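For reference, the relative L2 loss added here is conventionally defined as below; this is a reference definition, not a quote of the implementation, so consult ppsci.loss for the exact reduction behavior:

\mathrm{L2RelLoss}(\hat{y}, y) = \frac{\lVert \hat{y} - y \rVert_2}{\lVert y \rVert_2}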

docs/zh/api/metric.md
Lines changed: 2 additions & 0 deletions
@@ -8,5 +8,7 @@
- MSE
- RMSE
- L2Rel
+- LatitudeWeightedACC
+- LatitudeWeightedRMSE
show_root_heading: false
heading_level: 3
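The latitude weighting behind these two metrics, following the FourCastNet paper (a reference definition, not a quote of the implementation), scales each grid row by the cosine of its latitude \phi_j:

w_j = \frac{\cos \phi_j}{\frac{1}{N_{\mathrm{lat}}} \sum_{k=1}^{N_{\mathrm{lat}}} \cos \phi_k}, \qquad
\mathrm{LatitudeWeightedRMSE} = \sqrt{\frac{1}{N_{\mathrm{lat}} N_{\mathrm{lon}}} \sum_{j,i} w_j \, (\hat{x}_{j,i} - x_{j,i})^2}

LatitudeWeightedACC applies the same weights w_j to the anomaly correlation between prediction and ground truth.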

docs/zh/api/visualize.md
Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
- Visualizer2D
- Visualizer2DPlot
- Visualizer3D
+- VisualizerWeather
- save_vtu_from_dict
- save_plot_from_1d_dict
- save_plot_from_3d_dict

examples/fourcastnet/sample_data.py
Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
import shutil
from multiprocessing import Pool
from typing import Any
from typing import Dict
from typing import Tuple

import h5py
from paddle import io
from tqdm import tqdm

import examples.fourcastnet.utils as fourcast_utils
import ppsci
from ppsci.utils import logger


def sample_func(
    dataset_cfg: Dict[str, Any], save_path: str, batch_idxs: Tuple[int, ...]
):
    """Write one HDF5 file per sample index, grouping inputs, labels and weights."""
    dataset = ppsci.data.dataset.build_dataset(dataset_cfg)
    for idx in tqdm(batch_idxs):
        input_dict, label_dict, weight_dict = dataset[idx]
        with h5py.File(f"{save_path}/{idx:0>8d}.h5", "w") as fdest:
            for key, value in input_dict.items():
                fdest.create_dataset(f"input_dict/{key}", data=value, dtype="f")
            for key, value in label_dict.items():
                fdest.create_dataset(f"label_dict/{key}", data=value, dtype="f")
            if weight_dict is not None:
                for key, value in weight_dict.items():
                    fdest.create_dataset(f"weight_dict/{key}", data=value, dtype="f")


def sample_data_epoch(epoch: int):
    # initialize logger
    logger.init_logger("ppsci")
    # set dataset path and save path
    TRAIN_FILE_PATH = "./datasets/era5/train"
    PRECIP_FILE_PATH = None
    DATA_MEAN_PATH = "./datasets/era5/stat/global_means.npy"
    DATA_STD_PATH = "./datasets/era5/stat/global_stds.npy"
    TMP_SAVE_PATH = "./datasets/era5/train_split_rank0/epoch_tmp"
    save_path = f"./datasets/era5/train_split_rank0/epoch_{epoch}"
    # set hyper-parameters
    input_keys = ("input",)
    output_keys = ("output",)
    IMG_H, IMG_W = 720, 1440
    # FourCastNet uses 20 atmospheric variables; their indices in the dataset run from 0 to 19.
    # The variable names are 'u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000', 'z1000',
    # 'u850', 'v850', 'z850', 'u500', 'v500', 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv'.
    # Detailed information about each variable is available at
    # https://cds.climate.copernicus.eu/cdsapp#!/search?text=era5&type=dataset
    VARS_CHANNEL = list(range(20))
    NUM_TRAINER = 1
    RANK = 0
    PROCESSES = 16

    if len(glob.glob(TMP_SAVE_PATH + "/*.h5")):
        raise FileExistsError(
            f"TMP_SAVE_PATH({TMP_SAVE_PATH}) is not an empty folder, please specify an empty folder."
        )
    if len(glob.glob(save_path + "/*.h5")):
        raise FileExistsError(
            f"save_path({save_path}) is not an empty folder, please specify an empty folder."
        )
    os.makedirs(TMP_SAVE_PATH, exist_ok=True)

    data_mean, data_std = fourcast_utils.get_mean_std(
        DATA_MEAN_PATH, DATA_STD_PATH, VARS_CHANNEL
    )
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (IMG_H, IMG_W)}},
        {"Normalize": {"mean": data_mean, "std": data_std}},
    ]
    dataset_cfg = {
        "name": "ERA5Dataset",
        "file_path": TRAIN_FILE_PATH,
        "input_keys": input_keys,
        "label_keys": output_keys,
        "precip_file_path": PRECIP_FILE_PATH,
        "vars_channel": VARS_CHANNEL,
        "transforms": transforms,
    }
    dataset = ppsci.data.dataset.build_dataset(dataset_cfg)

    # replicate the training sampler so the sampled shard matches what rank 0
    # would see during distributed training at this epoch
    batch_sampler = io.DistributedBatchSampler(
        dataset=dataset,
        batch_size=1,
        shuffle=False,
        num_replicas=NUM_TRAINER,
        rank=RANK,
    )
    batch_sampler.set_epoch(epoch)
    batch_idxs = []
    for data in tqdm(batch_sampler):
        batch_idxs += data

    # split the indices into roughly equal chunks and sample them in parallel
    pool = Pool(processes=PROCESSES)
    for st in range(0, len(batch_idxs), len(batch_idxs) // (PROCESSES - 1)):
        end = st + len(batch_idxs) // (PROCESSES - 1)
        result = pool.apply_async(
            sample_func, (dataset_cfg, TMP_SAVE_PATH, batch_idxs[st:end])
        )
    pool.close()
    pool.join()
    if result.successful():
        logger.info("successful")
    shutil.move(TMP_SAVE_PATH, save_path)
    logger.info(f"move {TMP_SAVE_PATH} to {save_path}")


def main():
    EPOCH = 0
    sample_data_epoch(EPOCH)

    # to sample every 5 epochs, use the following instead:
    # EPOCHS = 150
    # for epoch in range(0, EPOCHS, 5):
    #     sample_data_epoch(epoch)


if __name__ == "__main__":
    main()
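To run the sampler once for epoch 0, assuming the ERA5 HDF5 files and statistics have been downloaded under ./datasets/era5 as the paths above expect, a minimal driver (equivalent to executing the script directly from the repository root):

# hypothetical driver, equivalent to running the script's main()
from examples.fourcastnet.sample_data import sample_data_epoch

sample_data_epoch(0)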
