42
42
* Shard(axis),指将张量沿 axis 维度做切分后,放到不同的计算设备上。
43
43
* Partial,指每个计算设备只拥有部分值,需要通过指定的规约操作才能恢复成全量数据。
44
44
45
- ![ 三种分布式状态] ( images/placements.png )
45
+ <p align =" center " >
46
+ <img src="images/auto_parallel/mesh.png" width="40%"/>
47
+ </p >
46
48
47
49
在如下的示例中,我们希望在 6 个计算设备上,创建一个形状为(4, 3)的分布式张量,其中沿着计算设备的 x 维,切分张量的 0 维;沿着计算设备的 y 维上,切分张量的 1 维。最终,每个计算设备实际拥有大小为(2, 1)的实际张量,如图所示。
48
50
@@ -60,7 +62,9 @@ dense_tensor = paddle.to_tensor([[1,2,3],
60
62
placements = [dist.Shard(0 ), dist.Shard(1 )]
61
63
dist_tensor = dist.shard_tensor(dense_tensor, mesh, placements)
62
64
```
63
- ![ 切分状态] ( images/shard.png )
65
+ <p align =" center " >
66
+ <img src="images/auto_parallel/shard.png" width="40%"/>
67
+ </p >
64
68
65
69
## 2.3 张量重切分
66
70
@@ -83,7 +87,9 @@ placements1 = [dist.Shard(0)]
83
87
dist_tensor = dist.shard_tensor(dense_tensor, mesh0, placements0)
84
88
dist_tensor_after_reshard = dist.reshard(dist_tensor, mesh1, placements1)
85
89
```
86
- ![ 切分状态] ( images/reshard.png )
90
+ <p align =" center " >
91
+ <img src="images/auto_parallel/reshard.png" width="40%"/>
92
+ </p >
87
93
88
94
# 三、原理简介
89
95
@@ -107,28 +113,28 @@ dist_tensorB = dist.shard_tensor(dense_tensorB, mesh, placementsB)
107
113
dist_tensorC = Matmul(dist_tensorA, dist_tensorB)
108
114
dist_tensorD = relu(dist_tensorC)
109
115
```
110
- < div style = " text-align : center ; " >
111
- <img src = " images/underlying1.png " alt = " 用户标记 " style = " width : 45 % ; height : auto ; center ; " >
112
- <!--  -- >
113
- </div >
116
+
117
+ <p align = " center " >
118
+ <img src=" images/auto_parallel/shard_anonation .png" width="40%"/ >
119
+ </p >
114
120
115
121
接下来就会进入自动并行的第一个核心逻辑 ** 切分推导** 。
116
122
当前用户标记的输入切分状态是无法被 Matmul 算子实际计算的(TensorA 的第 0 维和 TensorB 的第 1 维不匹配)。
117
123
这时候自动并行框架会使用当前算子的切分推导规则(e.g. MatmulSPMD Rule),根据输入 tensors 的切分状态,推导出一套合法且性能较优的 输入-输出 张量的切分状态。
118
124
在上述输入的切分状态下,框架会推导出会将 TensorA 的切分状态推导成按列切分,TensorB 保持切分状态不变,Matmul 的计算结果 TensorC 的切分状态是 Partial。
119
125
因为后续的 Relu 算子是非线性的,输入不能是 Partial 状态,所以框架会根据 ReluSPMD Rule 将 TensorC 输入 Relu 前的的分布式状态推导成 Replicated。
120
- <div style = " text-align : center ; " >
121
- <img src =" images/underlying2 .png " alt = " 切分推导 " style = " width : 45 % ; height : auto ; center ; " >
122
- </div >
126
+ <p align = " center " >
127
+ <img src="images/auto_parallel/shard_propogation .png" width="40%"/ >
128
+ </p >
123
129
124
130
接下来就会进入自动并行的第二个核心逻辑 ** 切分转换** 。
125
131
框架会根据 tensor 当前的切分状态(src_placement),和切分推导规则推导出的算子计算需要的切分状态(dst_placement),添加对应的通信/张量维度变换算子。
126
132
根据上图的切分推导,在计算 Matmul 添加 split 算子,在计算 Relue 添加 Allreduce,将输入 tensor 转换成需要的切分状态进行实际计算。
127
133
128
- <div style = " text-align : center ; " >
129
- <img src =" images/underlying3 .png " alt = " 切分转换 " style = " width : 45 % ; height : auto ; center ; " >
130
- </div >
131
- <!--  -->
134
+ <p align = " center " >
135
+ <img src="images/auto_parallel/shard_convertion .png" width="40%"/ >
136
+ </p >
137
+ <!--  -->
132
138
133
139
134
140
# 四、使用示例
@@ -145,6 +151,7 @@ dist_tensorD = relu(dist_tensorC)
145
151
import paddle
146
152
import paddle.distributed as dist
147
153
from paddle.io import BatchSampler, DataLoader, Dataset
154
+ import numpy as np
148
155
149
156
mesh = dist.ProcessMesh([0 , 1 , 2 , 3 ], dim_names = [' x' ])
150
157
@@ -267,6 +274,7 @@ class MlpModel(paddle.nn.Layer):
267
274
import paddle
268
275
import paddle.distributed as dist
269
276
from paddle.io import BatchSampler, DataLoader, Dataset
277
+ import numpy as np
270
278
271
279
mesh0 = dist.ProcessMesh([[0 , 1 ], [2 , 3 ]], dim_names = [' x' , ' y' ]) # 创建进程网格
272
280
mesh1 = dist.ProcessMesh([[4 , 5 ], [6 , 7 ]], dim_names = [' x' , ' y' ]) # 创建进程网格
@@ -332,7 +340,9 @@ for step, inputs in enumerate(dataloader):
332
340
333
341
自动并行的 API 在设计之初,就以实现统一的用户标记接口和逻辑为目标,保证动静半框架保证在相同的用户标记下,动静态图分布式执行逻辑一致。这样用户在全流程过程中只需要标记一套动态图组网,即可以实现动态图下的分布式训练 Debug 和 静态图下的分布式推理等逻辑。整个动转静训练的逻辑如下:
334
342
335
- ![ 切分状态] ( images/dynamic_static_unified_auto_parallel.png )
343
+ <p align =" center " >
344
+ <img src="images/auto_parallel/dynamic-static-unified.png" width="40%"/>
345
+ </p >
336
346
337
347
``` python
338
348
...
0 commit comments