Commit 792f7a8

[WIP] expert parallel dp2ep
1 parent dc7fd23

File tree

16 files changed, +676 -196 lines

run_train.sh

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ set -ex
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
 NGPU=${NGPU:-"8"}
 export LOG_RANK=${LOG_RANK:-0}
-CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
+CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/llama4/train_configs/debug_model.toml"}

 overrides=""
 if [ $# -ne 0 ]; then

scripts/estimate/estimation.py

Lines changed: 3 additions & 1 deletion

@@ -46,6 +46,7 @@ def estimate_memory(job_config: JobConfig):
         cp=parallelism_config.context_parallel_degree,
         tp=parallelism_config.tensor_parallel_degree,
         pp=parallelism_config.pipeline_parallel_degree,
+        ep=parallelism_config.expert_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not parallelism_config.disable_loss_parallel,
     )
@@ -56,8 +57,9 @@ def estimate_memory(job_config: JobConfig):
         or parallel_dims.tp_enabled
         or parallel_dims.pp_enabled
         or parallel_dims.cp_enabled
+        or parallel_dims.ep_enabled
     ):
-        logger.warning("DDP, TP, PP, CP are not supported yet.")
+        logger.warning("DDP, TP, PP, CP, EP are not supported yet.")
         return
     if not parallel_dims.dp_shard_enabled:
         logger.warning("FSDP or HSDP is not enabled. Skipping memory estimation.")

scripts/generate/test_generate.py

Lines changed: 1 addition & 0 deletions

@@ -125,6 +125,7 @@ def test_generate(
         cp=1,
         tp=world_size,
         pp=1,
+        ep=1,
         world_size=world_size,
         enable_loss_parallel=False,
     )

tests/unit_tests/test_model_converter.py

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@ def build_parallel_dims(job_config, world_size):
         cp=parallelism_config.context_parallel_degree,
         tp=parallelism_config.tensor_parallel_degree,
         pp=parallelism_config.pipeline_parallel_degree,
+        ep=parallelism_config.expert_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not parallelism_config.disable_loss_parallel,
     )

torchtitan/components/ft.py

Lines changed: 1 addition & 11 deletions

@@ -124,17 +124,7 @@ def func(
                 manager=self.ft_manager.manager,
             )

-        dims = []
-        names = []
-        for d, name in zip(
-            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
-            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
-        ):
-            if d > 1 or name == "dp_replicate":
-                dims.append(d)
-                names.append(name)
-
-        return self._build_mesh(device_type, dims, names, func)
+        return self._build_mesh(device_type, func)

     @property
     def dp_replicate_enabled(self):

torchtitan/config_manager.py

Lines changed: 8 additions & 0 deletions

@@ -363,6 +363,14 @@ class Parallelism:
     The default value is 'allgather'.
     """

+    expert_parallel_degree: int = 1
+    """
+    Expert parallelism degree. 1 means disabled.
+    Currently, only "dp2ep" is supported, with the following constraints:
+    context_parallel_degree <= expert_parallel_degree <= data_parallel_shard_degree * context_parallel_degree
+    Note that this is still an experimental feature.
+    """
+

 @dataclass
 class Checkpoint:
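
As a quick illustration of the documented dp2ep constraint (not part of this commit), the sketch below enumerates the expert_parallel_degree values allowed for one assumed sharding configuration; the helper name valid_ep_degrees and the sample degrees (dp_shard=8, cp=2) are made up for illustration.

# Illustrative sketch (not torchtitan code): enumerate expert_parallel_degree
# values satisfying cp <= ep <= dp_shard * cp with the divisibility rules
# checked in ParallelDims._validate.
def valid_ep_degrees(dp_shard: int, cp: int) -> list[int]:
    upper = dp_shard * cp
    return [
        ep
        for ep in range(cp, upper + 1)
        # EP borrows all of cp plus part of dp_shard, so both must divide evenly
        if ep % cp == 0 and upper % ep == 0
    ]

print(valid_ep_degrees(dp_shard=8, cp=2))  # [2, 4, 8, 16]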

torchtitan/distributed/parallel_dims.py

Lines changed: 101 additions & 9 deletions

@@ -23,21 +23,23 @@ class ParallelDims:
     cp: int
     tp: int
     pp: int
+    ep: int
     world_size: int
     enable_loss_parallel: bool

     def __post_init__(self):
         self._validate()

     def _validate(self):
-        dp_replicate, dp_shard, cp, tp, pp = (
+        dp_replicate, dp_shard, cp, tp, pp, ep = (
             self.dp_replicate,
             self.dp_shard,
             self.cp,
             self.tp,
             self.pp,
+            self.ep,
         )
-        for d in (dp_replicate, cp, tp, pp):
+        for d in (dp_replicate, cp, tp, pp, ep):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

         assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
@@ -50,26 +52,107 @@ def _validate(self):
             f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
         )

+        if ep > 1:
+            # EP would borrow all cp and some dp_shard degree
+            assert ep % cp == 0 and (dp_shard * cp) % ep == 0
+
     def build_mesh(self, device_type: str) -> DeviceMesh:
+        return self._build_mesh(device_type, init_device_mesh)
+
+    def _build_mesh(
+        self, device_type: str, init_device_mesh_fn: Callable
+    ) -> DeviceMesh:
+        # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel
+        # is not very clean, due to the limited support from DeviceMesh
+        # for creating two staggered meshes. Will improve.
+        if self.ep > 1:
+            return self._build_mesh_with_ep(device_type, init_device_mesh_fn)
+        else:
+            return self._build_mesh_without_ep(device_type, init_device_mesh_fn)
+
+    def _build_mesh_with_ep(
+        self,
+        device_type: str,
+        init_device_mesh_fn: Callable,
+    ) -> DeviceMesh:
+        # With ep, dp_shard and ep are derived submeshes:
+        # dp_shard = dp_shard_mod_ep * dp_shard_in_ep
+        # ep = dp_shard_in_ep * cp
+        dp_shard_mod_ep = self.dp_shard * self.cp // self.ep
+        dp_shard_in_ep = self.ep // self.cp
+
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
-            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
+            [
+                self.pp,
+                self.dp_replicate,
+                dp_shard_mod_ep,
+                dp_shard_in_ep,
+                self.cp,
+                self.tp,
+            ],
+            ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"],
         ):
-            if d > 1:
+            # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping
+            # helps the MoE layers do mixed precision training
+            if d > 1 or name == "dp_shard_mod_ep":
                 dims.append(d)
                 names.append(name)

-        return self._build_mesh(device_type, dims, names, init_device_mesh)
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+        mesh = init_device_mesh_fn(device_type, dims, mesh_dim_names=names)

-    def _build_mesh(
+        # Create all the submesh here to ensure all required process groups are
+        # initialized:
+        # Mesh for data loading (no communication on this mesh)
+        dp_mesh_dim_names = []
+        # Mesh for param sharding
+        dp_shard_cp_mesh_dim_names = []
+        # Mesh for loss all-reduce
+        dp_cp_mesh_dim_names = []
+        # Mesh for ep
+        ep_mesh_dim_names = []
+
+        if self.dp_replicate_enabled:
+            dp_mesh_dim_names.append("dp_replicate")
+            dp_cp_mesh_dim_names.append("dp_replicate")
+        # dp_shard_mod_ep is always needed, even if it's 1
+        dp_mesh_dim_names.append("dp_shard_mod_ep")
+        dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep")
+        dp_cp_mesh_dim_names.append("dp_shard_mod_ep")
+        if "dp_shard_in_ep" in names:
+            dp_mesh_dim_names.append("dp_shard_in_ep")
+            dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep")
+            dp_cp_mesh_dim_names.append("dp_shard_in_ep")
+            ep_mesh_dim_names.append("dp_shard_in_ep")
+        if self.cp_enabled:
+            dp_shard_cp_mesh_dim_names.append("cp")
+            dp_cp_mesh_dim_names.append("cp")
+            ep_mesh_dim_names.append("cp")
+
+        mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
+        mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp")
+        mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
+        mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")
+
+        return mesh
+
+    def _build_mesh_without_ep(
         self,
         device_type: str,
-        dims: list[int],
-        names: list[str],
         init_device_mesh_fn: Callable,
     ) -> DeviceMesh:
+        dims = []
+        names = []
+        for d, name in zip(
+            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
+            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
+        ):
+            if d > 1:
+                dims.append(d)
+                names.append(name)
+
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         mesh = init_device_mesh_fn(device_type, dims, mesh_dim_names=names)

@@ -143,3 +226,12 @@ def loss_parallel_enabled(self):
     @cached_property
     def non_data_parallel_size(self):
         return self.cp * self.tp * self.pp
+
+    @property
+    def ep_enabled(self):
+        return self.ep > 1
+
+    @property
+    def dense_params_mesh_ndim(self):
+        # Note: EP params mesh ndim is 1 more due to the 'ep' mesh
+        return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled
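
For intuition, the following standalone sketch (not part of this commit; the degrees are assumed values, e.g. a 32-GPU run with dp_shard=8, cp=2, tp=2, ep=4) shows how _build_mesh_with_ep factors dp_shard and cp into dp_shard_mod_ep and dp_shard_in_ep, and which sizes the flattened dp, dp_shard_cp, dp_cp, and ep submeshes end up with.

# Standalone sketch with assumed degrees; mirrors the arithmetic in
# _build_mesh_with_ep without constructing a real DeviceMesh.
dp_replicate, dp_shard, cp, tp, pp, ep = 1, 8, 2, 2, 1, 4

# dp_shard = dp_shard_mod_ep * dp_shard_in_ep, and ep = dp_shard_in_ep * cp
dp_shard_mod_ep = dp_shard * cp // ep  # 8 * 2 // 4 = 4
dp_shard_in_ep = ep // cp              # 4 // 2 = 2
assert dp_shard_mod_ep * dp_shard_in_ep == dp_shard
assert dp_shard_in_ep * cp == ep

# Sizes of the flattened submeshes created at the end of _build_mesh_with_ep
dp = dp_replicate * dp_shard_mod_ep * dp_shard_in_ep          # data loading: 8
dp_shard_cp = dp_shard_mod_ep * dp_shard_in_ep * cp           # param sharding: 16
dp_cp = dp_replicate * dp_shard_mod_ep * dp_shard_in_ep * cp  # loss all-reduce: 16
ep_size = dp_shard_in_ep * cp                                 # expert parallel: 4
print(dp, dp_shard_cp, dp_cp, ep_size)  # 8 16 16 4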

torchtitan/distributed/utils.py

Lines changed: 67 additions & 1 deletion

@@ -307,6 +307,7 @@ def clip_grad_norm_(
     error_if_nonfinite: bool = False,
     foreach: bool | None = None,
     pp_mesh: DeviceMesh | None = None,
+    parallel_dims: ParallelDims | None = None,
 ) -> torch.Tensor:
     """
     Clip the gradient norm of an iterable of parameters.
@@ -329,11 +330,23 @@ def clip_grad_norm_(
             fall back to the slow implementation for other device types.
             Default: ``None``
         pp_mesh: pipeline parallel device mesh. If not None, will reduce gradient norm across PP stages.
+        parallel_dims: ParallelDims object which contains Expert Parallel related info.

     Returns:
         Total norm of the parameter gradients (viewed as a single vector).

     """
+    if parallel_dims and parallel_dims.ep_enabled:
+        return _clip_grad_norm_with_ep(
+            parameters,
+            max_norm,
+            norm_type,
+            error_if_nonfinite,
+            foreach,
+            pp_mesh,
+            parallel_dims,
+        )
+
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
     else:
@@ -353,7 +366,6 @@ def clip_grad_norm_(
     if isinstance(total_norm, DTensor):
         # Will reach here if any non-PP parallelism is used.
         # If only using PP, total_norm will be a local tensor.
-
         total_norm = total_norm.full_tensor()

     if pp_mesh is not None:
@@ -366,3 +378,57 @@ def clip_grad_norm_(

     torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
     return total_norm
+
+
+@torch.no_grad()
+def _clip_grad_norm_with_ep(
+    parameters: torch.Tensor | Iterable[torch.Tensor],
+    max_norm: float,
+    norm_type: float,
+    error_if_nonfinite: bool,
+    foreach: bool | None,
+    pp_mesh: DeviceMesh | None,
+    parallel_dims: ParallelDims,
+) -> torch.Tensor:
+    assert parallel_dims.ep_enabled
+
+    ep_params = []
+    non_ep_params = []
+    ep_grads = []
+    non_ep_grads = []
+
+    for p in parameters:
+        if p.grad is None:
+            continue
+        assert isinstance(p.grad, DTensor)
+        if p.device_mesh.ndim == parallel_dims.dense_params_mesh_ndim:
+            non_ep_params.append(p)
+            non_ep_grads.append(p.grad)
+        else:
+            ep_params.append(p)
+            ep_grads.append(p.grad)
+    ep_grads_total_norm = torch.nn.utils.get_total_norm(
+        ep_grads, norm_type, error_if_nonfinite, foreach
+    ).full_tensor()
+    non_ep_grads_total_norm = torch.nn.utils.get_total_norm(
+        non_ep_grads, norm_type, error_if_nonfinite, foreach
+    ).full_tensor()
+
+    if math.isinf(norm_type):
+        total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
+    else:
+        total_norm = (
+            ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
+        )
+        total_norm **= 1.0 / norm_type
+
+    if pp_mesh is not None:
+        if math.isinf(norm_type):
+            dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
+        else:
+            total_norm **= norm_type
+            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
+            total_norm **= 1.0 / norm_type
+
+    torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach)
+    torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)
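
Because EP parameters live on a mesh with an extra "ep" dimension, their gradient norm is computed separately from the dense parameters' norm and then merged before a single clipping decision. The sketch below (not part of this commit) reproduces that merge on plain local tensors, with no DTensor or process groups involved; it assumes a PyTorch build that provides torch.nn.utils.get_total_norm, which the diff itself relies on.

# Sketch of the two-group norm merge on plain tensors (no DTensor involved).
import math

import torch

def combine_group_norms(ep_norm, non_ep_norm, norm_type: float) -> torch.Tensor:
    if math.isinf(norm_type):
        # inf-norm: the global max is the max of the two group maxima
        return torch.maximum(ep_norm, non_ep_norm)
    # p-norm: total = (||g_ep||_p ** p + ||g_dense||_p ** p) ** (1/p)
    return (ep_norm**norm_type + non_ep_norm**norm_type) ** (1.0 / norm_type)

ep_grads = [torch.full((4,), 3.0)]      # L2 norm = 6
non_ep_grads = [torch.full((4,), 4.0)]  # L2 norm = 8
ep_norm = torch.nn.utils.get_total_norm(ep_grads, norm_type=2.0)
non_ep_norm = torch.nn.utils.get_total_norm(non_ep_grads, norm_type=2.0)
print(combine_group_norms(ep_norm, non_ep_norm, 2.0))  # tensor(10.)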

torchtitan/experiments/llama4/README.md

Lines changed: 1 addition & 1 deletion

@@ -6,6 +6,7 @@ https://github.com/pytorch/torchtitan/issues/1118
 #### Available features
 - Llama 4 model (text-only), including a token-choice MoE architecture with efficient bfloat16 Grouped MM kernels and auxiliary-loss-free load balancing
 - FSDP, TP, PP, CP support
+- Expert Parallel support
 - DCP checkpoint conversion scripts

 #### Download Llama 4 tokenizer
@@ -20,7 +21,6 @@ python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E
 - multimodal support
 - Parallelism
   - Context Parallel support for FlexAttention and multimodal inputs
-  - Expert Parallel support
 - torch.compile
   - for MoE layers
 - Quantization
