@@ -23,21 +23,23 @@ class ParallelDims:
     cp: int
     tp: int
     pp: int
+    ep: int
     world_size: int
     enable_loss_parallel: bool

     def __post_init__(self):
         self._validate()

     def _validate(self):
-        dp_replicate, dp_shard, cp, tp, pp = (
+        dp_replicate, dp_shard, cp, tp, pp, ep = (
             self.dp_replicate,
             self.dp_shard,
             self.cp,
             self.tp,
             self.pp,
+            self.ep,
         )
-        for d in (dp_replicate, cp, tp, pp):
+        for d in (dp_replicate, cp, tp, pp, ep):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

         assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
@@ -50,7 +52,78 @@ def _validate(self):
                 f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
             )

+        if ep > 1:
+            # EP would borrow all cp and some dp_shard degree
+            assert ep % cp == 0 and (dp_shard * cp) % ep == 0
+
+    def _build_mesh_with_ep(self, device_type):
+        # With ep, dp_shard and ep are derived submeshes:
+        # dp_shard = dp_shard_mod_ep * dp_shard_in_ep
+        # ep = dp_shard_in_ep * cp
+        dp_shard_mod_ep = self.dp_shard * self.cp // self.ep
+        dp_shard_in_ep = self.ep // self.cp
+
+        dims = []
+        names = []
+        for d, name in zip(
+            [
+                self.pp,
+                self.dp_replicate,
+                dp_shard_mod_ep,
+                dp_shard_in_ep,
+                self.cp,
+                self.tp,
+            ],
+            ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"],
+        ):
+            # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping
+            # helps the MoE layers do mixed precision training
+            if d > 1 or name == "dp_shard_mod_ep":
+                dims.append(d)
+                names.append(name)
+
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+
+        # Create all the submesh here to ensure all required process groups are
+        # initialized:
+        # Mesh for data loading (no communication on this mesh)
+        dp_mesh_dim_names = []
+        # Mesh for param sharding
+        dp_shard_cp_mesh_dim_names = []
+        # Mesh for loss all-reduce
+        dp_cp_mesh_dim_names = []
+        # Mesh for ep
+        ep_mesh_dim_names = []
+
+        if self.dp_replicate_enabled:
+            dp_mesh_dim_names.append("dp_replicate")
+            dp_cp_mesh_dim_names.append("dp_replicate")
+        # dp_shard_mod_ep is always needed, even if it's 1
+        dp_mesh_dim_names.append("dp_shard_mod_ep")
+        dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep")
+        dp_cp_mesh_dim_names.append("dp_shard_mod_ep")
+        if "dp_shard_in_ep" in names:
+            dp_mesh_dim_names.append("dp_shard_in_ep")
+            dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep")
+            dp_cp_mesh_dim_names.append("dp_shard_in_ep")
+            ep_mesh_dim_names.append("dp_shard_in_ep")
+        if self.cp_enabled:
+            dp_shard_cp_mesh_dim_names.append("cp")
+            dp_cp_mesh_dim_names.append("cp")
+            ep_mesh_dim_names.append("cp")
+
+        mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
+        mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp")
+        mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
+        mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")
+
+        return mesh
+
     def build_mesh(self, device_type: str) -> DeviceMesh:
+        if self.ep > 1:
+            return self._build_mesh_with_ep(device_type)
+
         dims = []
         names = []
         for d, name in zip(
@@ -143,3 +216,7 @@ def loss_parallel_enabled(self):
     @cached_property
     def non_data_parallel_size(self):
         return self.cp * self.tp * self.pp
+
+    @property
+    def ep_enabled(self):
+        return self.ep > 1
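
As a quick sanity check of the derived submesh arithmetic (a standalone sketch, not part of this diff), the snippet below mirrors the dimension-selection logic of _build_mesh_with_ep for a hypothetical configuration of pp=2, dp_replicate=1, dp_shard=8, cp=2, tp=2, ep=4 on 64 ranks, using plain Python and no call to init_device_mesh:

# Hypothetical config; the concrete numbers are illustrative only.
pp, dp_replicate, dp_shard, cp, tp, ep = 2, 1, 8, 2, 2, 4
world_size = pp * dp_replicate * dp_shard * cp * tp  # 64 ranks

# Validation from _validate(): EP borrows all of cp and part of dp_shard.
assert ep % cp == 0 and (dp_shard * cp) % ep == 0

# Derived submeshes: dp_shard = dp_shard_mod_ep * dp_shard_in_ep,
#                    ep       = dp_shard_in_ep * cp
dp_shard_mod_ep = dp_shard * cp // ep  # 4
dp_shard_in_ep = ep // cp              # 2
assert dp_shard_mod_ep * dp_shard_in_ep == dp_shard
assert dp_shard_in_ep * cp == ep

# Same dim-selection rule as _build_mesh_with_ep:
# keep a dim only if it is > 1, but always keep dp_shard_mod_ep.
dims, names = [], []
for d, name in zip(
    [pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp],
    ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"],
):
    if d > 1 or name == "dp_shard_mod_ep":
        dims.append(d)
        names.append(name)

print(dims, names)
# [2, 4, 2, 2, 2] ['pp', 'dp_shard_mod_ep', 'dp_shard_in_ep', 'cp', 'tp']

The product of the selected dims is still 64, and flattening ("dp_shard_in_ep", "cp") yields the 4-way ep mesh, matching ep = dp_shard_in_ep * cp.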