Skip to content

Commit a4e842a

Browse files
update fourcastnet doc according to comments
1 parent 6787288 commit a4e842a

File tree

4 files changed

+28
-24
lines changed

4 files changed

+28
-24
lines changed

docs/zh/examples/fourcastnet.md

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
## 1. 背景简介
44

5-
在天气预报任务中,有基于物理信息驱动和数据驱动两种方法实现天气预报。基于物理信息驱动的方法,例如 IFS 模型,往往依赖物理方程,通过建模大气变量之间的物理关系实现天气预报;在 IFS 这个模型当中,使用了分布在 50 多个垂直高度上的总共 150 多个大气变量,去实现大气变量的预测工作。基于数据驱动的方法不依赖物理方程,而需要大量的训练数据,一般将神经网络看作一个黑盒结构,训练网络学习输入数据与输出数据之间的函数函数关系,实现给定输入条件下对于输出数据的预测。FourCastNet 算法是基于数据驱动的方法,相比于 IFS 模型,它仅仅使用了 5 个垂直高度上的 20 个大气变量进行模型的训练,使用的大气变量的个数要少很多,而且推理速度更快
5+
在天气预报任务中,有基于物理信息驱动和数据驱动两种方法实现天气预报。基于物理信息驱动的方法,往往依赖物理方程,通过建模大气变量之间的物理关系实现天气预报。例如在 IFS 模型中,使用了分布在 50 多个垂直高度上共 150 多个大气变量实现天气的预测。基于数据驱动的方法不依赖物理方程,但是需要大量的训练数据,一般将神经网络看作一个黑盒结构,训练网络学习输入数据与输出数据之间的函数关系,实现给定输入条件下对于输出数据的预测。FourCastNet 算法是基于数据驱动方法的,相比于 IFS 模型,它仅仅使用了 5 个垂直高度上共 20 个大气变量,具有大气变量输入个数少,推理理速度快的特点
66

77
## 2. 模型原理
88

9-
本章节仅对 FourCastNet 的模型原理进行简单的介绍,详细的理论推导请阅读[论文](https://arxiv.org/abs/2202.11214)
9+
本章节仅对 FourCastNet 的模型原理进行简单地介绍,详细的理论推导请阅读 [FourCastNet: A Global Data-driven High-resolution Weather Model using Adaptive Fourier Neural Operators](https://arxiv.org/abs/2202.11214)
1010

11-
FourCastNet 的网络模型使用了 AFNO 网络,这是一个应用在图像分割领域的一个网络。这个网络基于 FNO 和 VIT 这两个方法,借鉴了 FNO 中使用傅立叶变换完成不同 token 信息交互的方法,缓解了 VIT 中 self-attention 计算量大的问题,这个问题在高分辨率输入数据的情况下尤为明显。关于 [AFNO](https://openreview.net/pdf?id=EXHG-A3jlM)[FNO](https://arxiv.org/abs/2010.08895)[VIT](https://arxiv.org/pdf/2010.11929.pdf) 的相关原理也请阅读对应论文。
11+
FourCastNet 的网络模型使用了 AFNO 网络,该网络此前常用于图像分割任务。这个网络通过 FNO 弥补了 ViT 网络的缺点,使用傅立叶变换完成不同 token 信息交互,显著减少了高分辨率下 ViT 中 self-attention 的计算量。关于 [AFNO](https://openreview.net/pdf?id=EXHG-A3jlM)[FNO](https://arxiv.org/abs/2010.08895)[VIT](https://arxiv.org/pdf/2010.11929.pdf) 的相关原理也请阅读对应论文。
1212

1313
模型的总体结构如图所示:
1414

@@ -30,14 +30,14 @@ FourCastNet论文中训练了风速模型和降雨量模型,接下来将介绍
3030
<figcaption>风速模型预训练</figcaption>
3131
</figure>
3232

33-
模型训练的第二个阶段是模型微调,这个阶段的训练主要是为了提高模型在中长期天气预报的精度。具体的,当模型输入 $k$ 时刻的数据,预测了 $k+1$ 时刻的数据后,再使用预测的得到的 $k+1$ 时刻的数据作为模型输入预测 $k+2$ 时刻的数据,这样把预测得到的两个时刻的数据都用真值进行约束,提高了模型长时预测的能力
33+
模型训练的第二个阶段是模型微调,这个阶段的训练主要是为了提高模型在中长期天气预报的精度。具体地,当模型输入 $k$ 时刻的数据,预测了 $k+1$ 时刻的数据后,再将其重新作为输入预测 $k+2$ 时刻的数据,以连续预测两个时刻的训练方式,提高模型长时预测能力
3434

3535
<figure markdown>
3636
![fourcastnet-finetuning](../../images/fourcastnet/finetuning.png){ loading=lazy style="margin:0 auto;height:40%;width:40%"}
3737
<figcaption>风速模型微调</figcaption>
3838
</figure>
3939

40-
在推理阶段,给定 $k$ 时刻的数据,可以通过不断迭代,得到 $k+1$、$k+1$、$k+3$ 等时刻的预测结果。
40+
在推理阶段,给定 $k$ 时刻的数据,可以通过不断迭代,得到 $k+1$、$k+2$、$k+3$ 等时刻的预测结果。
4141

4242
<figure markdown>
4343
![fourcastnet-inference](../../images/fourcastnet/wind_inference.png){ loading=lazy style="margin:0 auto;height:40%;width:40%"}
@@ -46,16 +46,16 @@ FourCastNet论文中训练了风速模型和降雨量模型,接下来将介绍
4646

4747
### 2.2 降雨量模型的训练、推理过程
4848

49-
降雨量模型的训练依赖于风速模型,如下图所示,使用 $k$ 时刻的大气变量数据 $X(k)$ 输入训练好的风速模型,得到风速模型预测的 $k+1$ 时刻的大气变量数据 $X(k+1)$。降雨量模型以 $X(k+1)$ 为输入,输出为预测的 $k+1$ 时刻的降雨量数据 $p(k+1)$。模型训练时 $p(k+1)$ 与真值数据 $p_{true}(k+1)$ 计算 L2 损失函数约束网络训练。
49+
降雨量模型的训练依赖于风速模型,如下图所示,使用 $k$ 时刻的大气变量数据 $X(k)$ 输入训练好的风速模型,得到预测的 $k+1$ 时刻的大气变量数据 $X(k+1)$。降雨量模型以 $X(k+1)$ 为输入,输出为 $k+1$ 时刻的降雨量预测结果 $p(k+1)$。模型训练时 $p(k+1)$ 与真值数据 $p_{true}(k+1)$ 计算 L2 损失函数约束网络训练。
5050

5151
<figure markdown>
5252
![precip-training](../../images/fourcastnet/precip_training.png){ loading=lazy style="margin:0 auto;height:40%;width:40%"}
5353
<figcaption>降雨量模型训练</figcaption>
5454
</figure>
5555

56-
需要注意的是在降雨量模型的训练过程中,风速模型的参数处于冻结状态不参与梯度传播和反向计算
56+
需要注意的是在降雨量模型的训练过程中,风速模型的参数处于冻结状态,不参与优化器参数更新过程
5757

58-
在推理阶段,给定 $k$ 时刻的数据,可以通过不断迭代,利用风速模型得到 $k+1$、$k+1$、$k+3$ 等时刻的大气变量预测数据,然后使用该数据作为降雨量模型的输入预测对应时刻的降雨量数据
58+
在推理阶段,给定 $k$ 时刻的数据,可以通过不断迭代,利用风速模型得到 $k+1$、$k+2$、$k+3$ 等时刻的大气变量预测结果,作为降雨量模型的输入,预测对应时刻的降雨量
5959

6060
<figure markdown>
6161
![precip-inference](../../images/fourcastnet/precip_inference.png){ loading=lazy style="margin:0 auto;height:40%;width:40%"}
@@ -64,11 +64,15 @@ FourCastNet论文中训练了风速模型和降雨量模型,接下来将介绍
6464

6565
## 3. 风速模型实现
6666

67-
接下来开始讲解如何基于 PaddleScience 代码,实现 FourCastNet 风速模型的训练与推理。由于完整复现需要 5+TB 的存储空间和 64 卡的训练资源,需要的资源非常大,因此如果仅仅是学习 FourCastNet 的算法原理,非常建议对训练数据集进行缩减(例如仅使用近五年的训练数据)以减小学习成本。接下来首先会对使用的数据集进行介绍,然后对该方法两个训练步骤(模型预训练、模型微调)的监督约束构建、模型构建等进行阐述。关于该案例中的其余细节请参考 [API文档](../api/arch.md)
67+
接下来开始讲解如何基于 PaddleScience 代码,实现 FourCastNet 风速模型的训练与推理。关于该案例中的其余细节请参考 [API文档](../api/arch.md)
68+
69+
???+ Info
70+
71+
由于完整复现需要 5+TB 的存储空间和 64 卡的训练资源,因此如果仅仅是为了学习 FourCastNet 的算法原理,建议对一小部分训练数据集进行训练,以减小学习成本。
6872

6973
### 3.1 数据集介绍
7074

71-
数据集采用了 [FourCastNet](https://github.com/NVlabs/FourCastNet) 中处理好的 ERA5 数据集。该数据集的分辨率大小为 0.25,每个变量的数据大小为 720*1440,其中单个数据点代表的实际距离为 30km 左右。FourCastNet 使用了 1979-2018 年的数据,根据年份划分为了训练集、验证集、测试集,划分结果如下:
75+
数据集采用了 [FourCastNet](https://github.com/NVlabs/FourCastNet) 中处理好的 ERA5 数据集。该数据集的分辨率大小为 0.25 度,每个变量的数据尺寸为 $720 \times 1440$,其中单个数据点代表的实际距离为 30km 左右。FourCastNet 使用了 1979-2018 年的数据,根据年份划分为了训练集、验证集、测试集,划分结果如下:
7276

7377
|数据集 |年份 |
7478
|:----:|:---------:|
@@ -81,11 +85,11 @@ FourCastNet论文中训练了风速模型和降雨量模型,接下来将介绍
8185
模型训练使用了分布在 5 个压力层上的 20 个大气变量,如下表所示,
8286

8387
<figure markdown>
84-
![fourcastnet-vars](../../images/fourcastnet/era5-vars.png){ loading=lazy }
88+
![fourcastnet-vars](../../images/fourcastnet/era5-vars.png){ loading=lazy style="margin:0 auto;height:60%;width:60%"}
8589
<figcaption>20 个大气变量</figcaption>
8690
</figure>
8791

88-
其中 $T$、$U$、$V$ 、$Z$、$RH$ 分别代表指定垂直高度上的温度、纬向风速、经向风速、位势和相对湿度;$U_{10}$/$V_{10}$、$T_{2m}$则代表距离地面 10 米的纬向风速/经向风速和距离地面 2 米的温度。$sp$ 代表地面气压,$mslp$ 代表平均海平面气压。$TCWV$ 代表整层气柱水汽总量。
92+
其中 $T$、$U$、$V$ 、$Z$、$RH$ 分别代表指定垂直高度上的温度、纬向风速、经向风速、位势和相对湿度;$U_{10}$$V_{10}$、$T_{2m}$ 则代表距离地面 10 米的纬向风速经向风速和距离地面 2 米的温度。$sp$ 代表地面气压,$mslp$ 代表平均海平面气压。$TCWV$ 代表整层气柱水汽总量。
8993

9094
对每天 24 个小时的数据间隔 6 小时采样,得到 0.00h/6.00h/12.00h/18.00h 时刻全球 20 个大气变量的数据,使用这样的数据进行模型的训练与推理。即输入0.00h 时刻的 20 个大气变量的数据,模型输出预测得到的 6.00h 时刻的 20 个大气变量的数据。
9195

@@ -112,15 +116,15 @@ examples/fourcastnet/train_pretrain.py:55:69
112116
数据预处理部分总共包含 3 个预处理方法,分别是:
113117

114118
1. `SqueezeData`: 对训练数据的维度进行压缩,如果输入数据的维度为 4,则将第 0 维和第 1 维的数据压缩到一起,最终将输入数据的维度变换为 3。
115-
2. `CropData`: 从训练数据中裁剪指定位置的数据。因为 ERA5 数据集中的原始数据大小为 721*1440,本案例根据原始论文设置,将训练数据裁剪为 720*1440。
119+
2. `CropData`: 从训练数据中裁剪指定位置的数据。因为 ERA5 数据集中的原始数据尺寸为 $721 \times 1440$,本案例根据原始论文设置,将训练数据裁剪为 $720 \times 1440$
116120
3. `Normalize`: 根据训练数据集上的均值、方差对数据进行归一化处理。
117121

118-
由于完整复现 FourCastNet 需要 5+TB 的存储空间和 64 卡的 GPU 资源,需要的存储资源比较多,因此有以下两种训练方式(实验证明两种训练方式的损失函数收敛曲线基本一致,当存储资源比较有限时,可以使用方式 b)。
122+
由于完整复现 FourCastNet 需要 5TB+ 的存储空间和 64 卡的 GPU 资源,需要的存储资源比较多,因此有以下两种训练方式(实验证明两种训练方式的损失函数收敛曲线基本一致,当存储资源比较有限时,可以使用方式 b)。
119123

120-
方式 a: 当存储资源充足时,可以将全部训练数据存储到每台训练机器上(每台机器需要 5+TB 的存储空间),启动训练程序进行训练,此时训练过程中数据的加载是使用全局 shuffle 的方式进行,如下图所示,每个 batch 中的训练数据数据是随机从全量数据集中抽取样本组成的
124+
方式 a: 当存储资源充足时,可以不对数据进行划分,每个节点都有一份完整5TB+的训练数据,然后直接启动训练程序进行训练,此时每个节点上的数据随机抽取自完整训练数据。本方式的训练数据的加载是使用全局 shuffle 的方式进行,如下图所示。
121125

122126
<figure markdown>
123-
![fourcastnet-vars](../../images/fourcastnet/fourcastnet_global_shuffle.png){ loading=lazy }
127+
![fourcastnet-vars](../../images/fourcastnet/fourcastnet_global_shuffle.png){ loading=lazy style="margin:0 auto;height:60%;width:60%"}
124128
<figcaption>全局 shuffle</figcaption>
125129
</figure>
126130

@@ -134,10 +138,10 @@ examples/fourcastnet/train_pretrain.py:74:90
134138

135139
其中,"dataset" 字段定义了使用的 `Dataset` 类名为 `ERA5Dataset`,"sampler" 字段定义了使用的 `Sampler` 类名为 `BatchSampler`,设置的 `batch_size` 为 1,`num_works` 为 8。
136140

137-
方式 b:在存储资源有限时,需要将数据集平均分配存储到每个训练机器上,本案例提供了随机采样数据的程序可以使用 ppsci/fourcastnet/sample_data.py,可以根据需要进行修改。本案例默认使用方式 a, 因此使用方式 b 进行模型训练时需要手动将 `USE_SAMPLED_DATA` 设置为 `True`方式 b 训练过程中数据的加载是使用局部 shuffle 的方式进行,如下图所示,将训练数据平均分配到了 8 台机器上,训练时从每台机器上随机抽取 1/8 个 batch 的数据组成 1 个完整的 batch 进行训练。在 8 机条件下,每台机器需要约 1.2TB 的存储空间,相比于方式 a,方式 b 大大减小了对存储空间的依赖。
141+
方式 b:在存储资源有限时,需要将数据集均匀切分至每个节点上,本案例提供了随机采样数据的程序,可以执行 `ppsci/fourcastnet/sample_data.py`,可以根据需要进行修改。本案例默认使用方式 a, 因此使用方式 b 进行模型训练时需要手动将 `USE_SAMPLED_DATA` 设置为 `True`本方式的训练数据的加载是使用局部 shuffle 的方式进行,如下图所示,首先将训练数据平均切分至 8 个节点上,训练时每个节点的数据随机抽取自被切分到的数据上,在这一情况下,每个节点需要约 1.2TB 的存储空间,相比于方式 a,方式 b 大大减小了对存储空间的依赖。
138142

139143
<figure markdown>
140-
![fourcastnet-vars](../../images/fourcastnet/fourcastnet_local_shuffle.png){ loading=lazy }
144+
![fourcastnet-vars](../../images/fourcastnet/fourcastnet_local_shuffle.png){ loading=lazy style="margin:0 auto;height:60%;width:60%"}
141145
<figcaption>局部 shuffle</figcaption>
142146
</figure>
143147

@@ -324,7 +328,7 @@ examples/fourcastnet/train_precip.py:98:115
324328
数据预处理部分总共包含 4 个预处理方法,分别是:
325329

326330
1. `SqueezeData`: 对训练数据的维度进行压缩,如果输入数据的维度为 4,则将第 0 维和第 1 维的数据压缩到一起,最终将输入数据的维度变换为 3。
327-
2. `CropData`: 从训练数据中裁剪指定位置的数据。因为 ERA5 数据集中的原始数据大小为 721*1440,本案例根据原始论文设置,将训练数据裁剪为 720*1440。
331+
2. `CropData`: 从训练数据中裁剪指定位置的数据。因为 ERA5 数据集中的原始数据尺寸为 $721 \times 1440$,本案例根据原始论文设置,将训练数据尺寸裁剪为 $720 \times 1440$
328332
3. `Normalize`: 根据训练数据集上的均值、方差对数据进行归一化处理,这里通过 `apply_keys` 字段设置了该预处理方法仅仅应用到输入数据上。
329333
4. `Log1p`: 将数据映射到对数空间,这里通过 `apply_keys` 字段设置了该预处理方法仅仅应用到真值数据上。
330334

@@ -366,7 +370,7 @@ examples/fourcastnet/train_precip.py:187:190
366370

367371
### 4.3 学习率与优化器构建
368372

369-
本案例中使用的学习率方法为 `Cosine`,学习率大小设置为 5e-4。优化器使用 `Adam`,用 PaddleScience 代码表示如下:
373+
本案例中使用的学习率方法为 `Cosine`,学习率大小设置为 2.5e-4。优化器使用 `Adam`,用 PaddleScience 代码表示如下:
370374

371375
``` py linenums="192" title="examples/fourcastnet/train_precip.py"
372376
--8<--

examples/fourcastnet/train_finetune.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def get_vis_datas(
7272
# set training hyper-parameters
7373
NUM_TIMESTAMPS = 2
7474
input_keys = ("input",)
75-
output_keys = tuple([f"output_{i}" for i in range(NUM_TIMESTAMPS)])
75+
output_keys = tuple(f"output_{i}" for i in range(NUM_TIMESTAMPS))
7676
IMG_H, IMG_W = 720, 1440
7777
EPOCHS = 50 if not args.epochs else args.epochs
7878
# FourCastNet use 20 atmospheric variable,their index in the dataset is from 0 to 19.
@@ -215,7 +215,7 @@ def get_vis_datas(
215215

216216
# set testing hyper-parameters
217217
NUM_TIMESTAMPS = 32
218-
output_keys = tuple([f"output_{i}" for i in range(NUM_TIMESTAMPS)])
218+
output_keys = tuple(f"output_{i}" for i in range(NUM_TIMESTAMPS))
219219

220220
# set model for testing
221221
model = ppsci.arch.AFNONet(input_keys, output_keys, num_timestamps=NUM_TIMESTAMPS)

examples/fourcastnet/train_precip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def get_vis_datas(
219219

220220
# set testing hyper-parameters
221221
NUM_TIMESTAMPS = 6
222-
output_keys = tuple([f"output_{i}" for i in range(NUM_TIMESTAMPS)])
222+
output_keys = tuple(f"output_{i}" for i in range(NUM_TIMESTAMPS))
223223

224224
# set model for testing
225225
model = ppsci.arch.PrecipNet(

mkdocs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ nav:
4141
- Cylinder2D_unsteady: zh/examples/cylinder2d_unsteady.md
4242
- Laplace2D: zh/examples/laplace2d.md
4343
- 数据驱动:
44+
- FourCastNet: zh/examples/fourcastnet.md
4445
- Lorenz_transform_physx: zh/examples/lorenz.md
4546
- Rossler_transform_physx: zh/examples/rossler.md
4647
- Cylinder2D_unsteady_transform_physx: zh/examples/cylinder2d_unsteady_transformer_physx.md
47-
- FourCastNet: zh/examples/fourcastnet.md
4848
- API文档:
4949
- ppsci.arch: zh/api/arch.md
5050
- ppsci.autodiff: zh/api/autodiff.md

0 commit comments

Comments
 (0)