Commit b0dffa1

dp2ep Expert Parallel
1 parent 7d5f3cc commit b0dffa1

File tree

20 files changed, +848 -306 lines


docs/checkpoint.md

Lines changed: 1 addition & 1 deletion
@@ -83,5 +83,5 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l
 To create a seed checkpoint, use the same model config as you use for training.
 e.g.
 ```bash
-NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1
+NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
 ```

docs/debugging.md

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ For multiple experimental runs with different parallelism configs, we need to us
 
 
 ```bash
-NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1
+NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
 ```
 
 **Note**: Using a seed checkpoint will only make sure a model has same initial weights when configs change, but the training process may not be the same even after setting the seed and the `deterministic` mode, e.g. due to tensor shape change, data precision change, usage of randomness in model code, etc.

scripts/estimate/estimation.py

Lines changed: 6 additions & 2 deletions
@@ -46,6 +46,7 @@ def estimate_memory(job_config: JobConfig):
         cp=parallelism_config.context_parallel_degree,
         tp=parallelism_config.tensor_parallel_degree,
         pp=parallelism_config.pipeline_parallel_degree,
+        ep=parallelism_config.expert_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not parallelism_config.disable_loss_parallel,
     )
@@ -56,8 +57,9 @@ def estimate_memory(job_config: JobConfig):
         or parallel_dims.tp_enabled
         or parallel_dims.pp_enabled
         or parallel_dims.cp_enabled
+        or parallel_dims.ep_enabled
     ):
-        logger.warning("DDP, TP, PP, CP are not supported yet.")
+        logger.warning("DDP, TP, PP, CP, EP are not supported yet.")
         return
     if not parallel_dims.dp_shard_enabled:
         logger.warning("FSDP or HSDP is not enabled. Skipping memory estimation.")
@@ -115,7 +117,9 @@ def estimate_memory(job_config: JobConfig):
 
     # build optimizer after applying parallelisms to the model
     ft_manager = init_ft_manager(job_config)
-    optimizers = build_optimizers([model], job_config, ft_manager)
+    optimizers = build_optimizers(
+        [model], job_config, parallel_dims, world_mesh, ft_manager
+    )
     lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
     # Post optimizer step model converters hook.
     # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2

scripts/generate/test_generate.py

Lines changed: 1 addition & 0 deletions
@@ -125,6 +125,7 @@ def test_generate(
         cp=1,
         tp=world_size,
         pp=1,
+        ep=1,
         world_size=world_size,
         enable_loss_parallel=False,
     )

tests/unit_tests/test_model_converter.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ def build_parallel_dims(job_config, world_size):
         cp=parallelism_config.context_parallel_degree,
         tp=parallelism_config.tensor_parallel_degree,
         pp=parallelism_config.pipeline_parallel_degree,
+        ep=parallelism_config.expert_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not parallelism_config.disable_loss_parallel,
     )

torchtitan/components/ft.py

Lines changed: 0 additions & 37 deletions
@@ -7,7 +7,6 @@
 import copy
 import importlib
 from contextlib import nullcontext
-from dataclasses import dataclass
 from typing import ContextManager, Optional, TYPE_CHECKING, Union
 
 import torch
@@ -18,7 +17,6 @@
 from torch.distributed.distributed_c10d import ReduceOp
 from torch.distributed.tensor import DTensor
 from torchtitan.config_manager import JobConfig
-from torchtitan.distributed import ParallelDims
 
 if importlib.util.find_spec("torchft") is not None:
     import torchft as ft
@@ -106,41 +104,6 @@ def init_ft_manager(job: JobConfig) -> FTManager:
     )
 
 
-@dataclass
-class FTParallelDims(ParallelDims):
-    ft_manager: FTManager
-
-    def build_mesh(self, device_type: str) -> DeviceMesh:
-        def func(
-            device_type: str, mesh_shape: list[int], mesh_dim_names: list[str]
-        ) -> DeviceMesh:
-            from torchft.process_group import ft_init_device_mesh
-
-            return ft_init_device_mesh(
-                device_type=device_type,
-                mesh_shape=mesh_shape,
-                mesh_dim_names=mesh_dim_names,
-                replicate_dim=mesh_dim_names.index("dp_replicate"),
-                manager=self.ft_manager.manager,
-            )
-
-        dims = []
-        names = []
-        for d, name in zip(
-            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
-            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
-        ):
-            if d > 1 or name == "dp_replicate":
-                dims.append(d)
-                names.append(name)
-
-        return self._build_mesh(device_type, dims, names, func)
-
-    @property
-    def dp_replicate_enabled(self):
-        return True
-
-
 def ft_dist_reduce(
     x: torch.Tensor, reduceOp: str, mesh: DeviceMesh
 ) -> tuple[torch.Tensor, str, DeviceMesh]:

torchtitan/components/optimizer.py

Lines changed: 24 additions & 10 deletions
@@ -15,10 +15,12 @@
     StateDictOptions,
 )
 from torch.distributed.checkpoint.stateful import Stateful
+from torch.distributed.device_mesh import DeviceMesh
 from torch.optim import Optimizer
 
 from torchtitan.components.ft import FTManager, has_torchft
 from torchtitan.config_manager import JobConfig
+from torchtitan.distributed import ParallelDims
 
 __all__ = [
     "OptimizersContainer",
@@ -241,6 +243,8 @@ def zero_grad(self, *args, **kwargs) -> None:
 def build_optimizers(
     model_parts: list[nn.Module],
     job_config: JobConfig,
+    parallel_dims: ParallelDims,
+    world_mesh: DeviceMesh,
     ft_manager: FTManager,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.
@@ -259,12 +263,23 @@ def build_optimizers(
     Args:
         model_parts (List[nn.Module]): List of model parts to be optimized.
         job_config (JobConfig): Job config containing the optimizer name and parameters.
+        parallel_dims (ParallelDims): Parallel dimensions for the model.
     """
     optim_in_bwd = job_config.optimizer.early_step_in_backward
-    if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1:
-        raise NotImplementedError(
-            "Optimizers in backward is not supported with pipeline parallelism."
-        )
+    if optim_in_bwd:
+        if parallel_dims.ep_enabled:
+            raise NotImplementedError(
+                "Optimizers in backward is not supported with Expert Parallel."
+            )
+        if parallel_dims.pp_enabled:
+            raise NotImplementedError(
+                "Optimizers in backward is not supported with Pipeline Parallel."
+            )
+        if ft_manager.enabled:
+            raise NotImplementedError(
+                "TorchFT is not supported with optimizers in backward."
+            )
+
     name = job_config.optimizer.name
     lr = job_config.optimizer.lr
     beta1 = job_config.optimizer.beta1
@@ -295,19 +310,18 @@
         raise NotImplementedError(f"Optimizer {name} not added.")
     optimizer_cls = optimizer_classes[name]
 
-    if optim_in_bwd and ft_manager.enabled:
-        raise ValueError("TorchFT is not supported with optimizers in backward.")
-    elif optim_in_bwd:
+    if optim_in_bwd:
         return OptimizersInBackwardContainer(
             model_parts, optimizer_cls, optimizer_kwargs
         )
-    elif ft_manager.enabled:
+
+    if ft_manager.enabled:
         return FTOptimizersContainer(
             model_parts,
             optimizer_cls,
             optimizer_kwargs,
             ft_manager.manager,
             use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None,
         )
-    else:
-        return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
+
+    return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
torchtitan/config_manager.py

Lines changed: 8 additions & 0 deletions
@@ -363,6 +363,14 @@ class Parallelism:
     The default value is 'allgather'.
     """
 
+    expert_parallel_degree: int = 1
+    """
+    Expert parallelism degree. 1 means disabled.
+    Currently, only "dp2ep" is supported, with the following constraints:
+    context_parallel_degree <= expert_parallel_degree <= data_parallel_shard_degree * context_parallel_degree
+    Note that this is still an experimental feature.
+    """
+
 
 @dataclass
 class Checkpoint:
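
The constraint in the docstring above pairs with the divisibility assertion added to `ParallelDims._validate` below. Here is a small standalone sketch of that check; the degree values used are illustrative, not taken from the commit.

```python
# Standalone sketch of the dp2ep degree constraint:
#   context_parallel_degree <= expert_parallel_degree
#       <= data_parallel_shard_degree * context_parallel_degree,
# with ep divisible by cp and dp_shard * cp divisible by ep
# (the assertion added in parallel_dims.py below).
def ep_degree_is_valid(dp_shard: int, cp: int, ep: int) -> bool:
    if ep == 1:
        return True  # EP disabled
    return (
        cp <= ep <= dp_shard * cp
        and ep % cp == 0
        and (dp_shard * cp) % ep == 0
    )

# Example: with dp_shard=8 and cp=2, valid EP degrees are 2, 4, 8, and 16.
assert ep_degree_is_valid(dp_shard=8, cp=2, ep=4)
assert not ep_degree_is_valid(dp_shard=8, cp=2, ep=3)   # not a multiple of cp
assert not ep_degree_is_valid(dp_shard=8, cp=2, ep=32)  # exceeds dp_shard * cp
```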

torchtitan/distributed/parallel_dims.py

Lines changed: 91 additions & 13 deletions
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from collections.abc import Callable
 from dataclasses import dataclass
 from functools import cached_property
 
@@ -23,21 +22,23 @@ class ParallelDims:
     cp: int
     tp: int
     pp: int
+    ep: int
     world_size: int
     enable_loss_parallel: bool
 
     def __post_init__(self):
         self._validate()
 
     def _validate(self):
-        dp_replicate, dp_shard, cp, tp, pp = (
+        dp_replicate, dp_shard, cp, tp, pp, ep = (
             self.dp_replicate,
             self.dp_shard,
             self.cp,
             self.tp,
             self.pp,
+            self.ep,
         )
-        for d in (dp_replicate, cp, tp, pp):
+        for d in (dp_replicate, cp, tp, pp, ep):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
 
         assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
@@ -50,7 +51,84 @@ def _validate(self):
             f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
         )
 
+        if ep > 1:
+            # EP would borrow all cp and some dp_shard degree
+            assert ep % cp == 0 and (dp_shard * cp) % ep == 0
+
     def build_mesh(self, device_type: str) -> DeviceMesh:
+        # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel
+        # is not very clean, due to the limited support from DeviceMesh
+        # for creating two staggered meshes. Will improve.
+        if self.ep > 1:
+            return self._build_mesh_with_ep(device_type)
+        else:
+            return self._build_mesh_without_ep(device_type)
+
+    def _build_mesh_with_ep(self, device_type: str) -> DeviceMesh:
+        # With ep, dp_shard and ep are derived submeshes:
+        # dp_shard = dp_shard_mod_ep * dp_shard_in_ep
+        # ep = dp_shard_in_ep * cp
+        dp_shard_mod_ep = self.dp_shard * self.cp // self.ep
+        dp_shard_in_ep = self.ep // self.cp
+
+        dims = []
+        names = []
+        for d, name in zip(
+            [
+                self.pp,
+                self.dp_replicate,
+                dp_shard_mod_ep,
+                dp_shard_in_ep,
+                self.cp,
+                self.tp,
+            ],
+            ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"],
+        ):
+            # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping
+            # helps the MoE layers do mixed precision training
+            if d > 1 or name == "dp_shard_mod_ep":
+                dims.append(d)
+                names.append(name)
+
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+
+        # Create all the submesh here to ensure all required process groups are
+        # initialized:
+        # Mesh for data loading (no communication on this mesh)
+        dp_mesh_dim_names = []
+        # Mesh for param sharding
+        dp_shard_cp_mesh_dim_names = []
+        # Mesh for loss all-reduce
+        dp_cp_mesh_dim_names = []
+        # Mesh for ep
+        ep_mesh_dim_names = []
+
+        if self.dp_replicate_enabled:
+            dp_mesh_dim_names.append("dp_replicate")
+            dp_cp_mesh_dim_names.append("dp_replicate")
+        # dp_shard_mod_ep is always needed, even if it's 1
+        dp_mesh_dim_names.append("dp_shard_mod_ep")
+        dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep")
+        dp_cp_mesh_dim_names.append("dp_shard_mod_ep")
+        if "dp_shard_in_ep" in names:
+            dp_mesh_dim_names.append("dp_shard_in_ep")
+            dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep")
+            dp_cp_mesh_dim_names.append("dp_shard_in_ep")
+            ep_mesh_dim_names.append("dp_shard_in_ep")
+        if self.cp_enabled:
+            dp_shard_cp_mesh_dim_names.append("cp")
+            dp_cp_mesh_dim_names.append("cp")
+            ep_mesh_dim_names.append("cp")
+
+        mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
+        mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp")
+        mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
+        mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")
+
+        return mesh
+
+    def _build_mesh_without_ep(self, device_type: str) -> DeviceMesh:
         dims = []
         names = []
         for d, name in zip(
@@ -61,17 +139,8 @@ def build_mesh(self, device_type: str) -> DeviceMesh:
                 dims.append(d)
                 names.append(name)
 
-        return self._build_mesh(device_type, dims, names, init_device_mesh)
-
-    def _build_mesh(
-        self,
-        device_type: str,
-        dims: list[int],
-        names: list[str],
-        init_device_mesh_fn: Callable,
-    ) -> DeviceMesh:
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
-        mesh = init_device_mesh_fn(device_type, dims, mesh_dim_names=names)
+        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
 
         # Create all the submesh here to ensure all required process groups are
         # initialized:
@@ -143,3 +212,12 @@ def loss_parallel_enabled(self):
     @cached_property
     def non_data_parallel_size(self):
         return self.cp * self.tp * self.pp
+
+    @property
+    def ep_enabled(self):
+        return self.ep > 1
+
+    @property
+    def dense_params_mesh_ndim(self):
+        # Note: EP params mesh ndim is 1 more due to the 'ep' mesh
+        return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled
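
To make the derived-submesh arithmetic in `_build_mesh_with_ep` concrete, here is a small worked example under a hypothetical 64-GPU configuration; the degrees chosen are illustrative, not taken from the commit.

```python
# Hypothetical degrees: 64 GPUs, FSDP over 16 ranks, CP=2, TP=2, EP=8.
dp_replicate, dp_shard, cp, tp, pp, ep = 1, 16, 2, 2, 1, 8
assert dp_replicate * dp_shard * cp * tp * pp == 64  # world size
assert ep % cp == 0 and (dp_shard * cp) % ep == 0    # dp2ep constraint

# Derived submesh sizes, per _build_mesh_with_ep:
#   dp_shard = dp_shard_mod_ep * dp_shard_in_ep
#   ep       = dp_shard_in_ep * cp
dp_shard_mod_ep = dp_shard * cp // ep  # 16 * 2 // 8 = 4
dp_shard_in_ep = ep // cp              # 8 // 2 = 4
assert dp_shard_mod_ep * dp_shard_in_ep == dp_shard
assert dp_shard_in_ep * cp == ep

# Mesh dims actually materialized (size-1 dims are dropped, except
# dp_shard_mod_ep, which is kept so FSDP can wrap the MoE layers):
#   ["dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"] -> [4, 4, 2, 2]
# Flattened meshes: dp = 4*4 = 16, dp_shard_cp = 4*4*2 = 32,
#                   dp_cp = 4*4*2 = 32, ep = 4*2 = 8
```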
