Commit 547ecae

[WIP] expert parallel dp2ep
1 parent f4048f8 commit 547ecae

File tree

11 files changed: 602 additions, 197 deletions


run_train.sh

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ set -ex
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
 NGPU=${NGPU:-"8"}
 export LOG_RANK=${LOG_RANK:-0}
-CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
+CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/llama4/train_configs/debug_model.toml"}

 overrides=""
 if [ $# -ne 0 ]; then

torchtitan/components/checkpoint.py

Lines changed: 13 additions & 12 deletions
@@ -41,6 +41,10 @@
 LR_SCHEDULER = "lr_scheduler"
 DATALOADER = "dataloader"
 TRAIN_STATE = "train_state"
+# For now, we will manually pop the freqs_cis buffer, as we made this permanent
+# temporarily and we don't want to include it in the exported state_dict.
+# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
+excluded_parameters_for_model_only = {"freqs_cis"}


 class AsyncMode(str, enum.Enum):
@@ -53,7 +57,10 @@ class ModelWrapper(Stateful):
     def __init__(self, model: nn.Module | list[nn.Module]) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
         self.cache_state_dict = {
-            k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
+            k: v
+            for sd in map(get_model_state_dict, self.model)
+            for k, v in sd.items()
+            if k not in excluded_parameters_for_model_only
         }

     def state_dict(self) -> dict[str, Any]:
@@ -69,7 +76,10 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         # `set_model_state_dict()` does change the keys of the input state_dict,
         # we will need to reinitialize the cache_state_dict.
         self.cache_state_dict = {
-            k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
+            k: v
+            for sd in map(get_model_state_dict, self.model)
+            for k, v in sd.items()
+            if k not in excluded_parameters_for_model_only
         }


@@ -81,12 +91,6 @@ class SaveDone:
     pass


-# For now, we will manually pop the freqs_cis buffer, as we made this permanent
-# temporarily and we don't want to include it in the exported state_dict.
-# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-excluded_parameters_for_model_only = {"freqs_cis"}
-
-
 @torch.no_grad()
 def save_with_gc(state, checkpoint_id):
     dcp.save(state, checkpoint_id=checkpoint_id)
@@ -568,10 +572,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """
         # For the first step, we will only load the model weights.
         if model_only:
-            sd = self.states[MODEL].state_dict()
-            for k in excluded_parameters_for_model_only:
-                sd.pop(k, None)
-            return sd
+            return {MODEL: self.states[MODEL]}

         for exclude_key in self.exclude_from_loading:
             if exclude_key not in self.states:
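
For reference, the `excluded_parameters_for_model_only` filter above is a plain key-level filter applied while merging the per-model-part state dicts. A minimal standalone sketch (plain dicts stand in for `get_model_state_dict` output; the `merged_state_dict` helper is illustrative, not a torchtitan API):

```python
# Minimal sketch of the exclusion filter used in ModelWrapper above.
# Plain dicts stand in for the per-model-part state dicts; the helper
# name is illustrative only.
excluded_parameters_for_model_only = {"freqs_cis"}

def merged_state_dict(model_part_state_dicts):
    # Merge all model-part state dicts, dropping excluded keys such as
    # the freqs_cis buffer that should not be checkpointed.
    return {
        k: v
        for sd in model_part_state_dicts
        for k, v in sd.items()
        if k not in excluded_parameters_for_model_only
    }

part0 = {"tok_embeddings.weight": 0, "freqs_cis": 1}
part1 = {"layers.0.attention.wq.weight": 2}
print(merged_state_dict([part0, part1]))
# {'tok_embeddings.weight': 0, 'layers.0.attention.wq.weight': 2}
```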

torchtitan/config_manager.py

Lines changed: 8 additions & 0 deletions
@@ -363,6 +363,14 @@ class Parallelism:
     The default value is 'allgather'.
     """

+    expert_parallel_degree: int = 1
+    """
+    Expert parallelism degree. 1 means disabled.
+    Currently, only "dp2ep" is supported, with the following constraints:
+    context_parallel_degree <= expert_parallel_degree <= data_parallel_shard_degree * context_parallel_degree
+    Note that this is still an experimental feature.
+    """
+

 @dataclass
 class Checkpoint:
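
As a quick illustration of the documented constraint, here is a standalone sketch (the `check_dp2ep_constraint` helper is hypothetical, not part of torchtitan):

```python
def check_dp2ep_constraint(
    context_parallel_degree: int,
    expert_parallel_degree: int,
    data_parallel_shard_degree: int,
) -> bool:
    # Constraint quoted from the docstring above:
    # context_parallel_degree <= expert_parallel_degree
    #     <= data_parallel_shard_degree * context_parallel_degree
    cp = context_parallel_degree
    ep = expert_parallel_degree
    dp_shard = data_parallel_shard_degree
    return cp <= ep <= dp_shard * cp

# With dp_shard=4 and cp=2, ep can range from 2 up to 8
# (subject to the divisibility checks added to ParallelDims below).
print(check_dp2ep_constraint(2, 8, 4))   # True
print(check_dp2ep_constraint(2, 16, 4))  # False
```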

torchtitan/distributed/parallel_dims.py

Lines changed: 79 additions & 2 deletions
@@ -23,21 +23,23 @@ class ParallelDims:
     cp: int
     tp: int
     pp: int
+    ep: int
     world_size: int
     enable_loss_parallel: bool

     def __post_init__(self):
         self._validate()

     def _validate(self):
-        dp_replicate, dp_shard, cp, tp, pp = (
+        dp_replicate, dp_shard, cp, tp, pp, ep = (
             self.dp_replicate,
             self.dp_shard,
             self.cp,
             self.tp,
             self.pp,
+            self.ep,
         )
-        for d in (dp_replicate, cp, tp, pp):
+        for d in (dp_replicate, cp, tp, pp, ep):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

         assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
@@ -50,7 +52,78 @@ def _validate(self):
             f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
         )

+        if ep > 1:
+            # EP would borrow all cp and some dp_shard degree
+            assert ep % cp == 0 and (dp_shard * cp) % ep == 0
+
+    def _build_mesh_with_ep(self, device_type):
+        # With ep, dp_shard and ep are derived submeshes:
+        # dp_shard = dp_shard_mod_ep * dp_shard_in_ep
+        # ep = dp_shard_in_ep * cp
+        dp_shard_mod_ep = self.dp_shard * self.cp // self.ep
+        dp_shard_in_ep = self.ep // self.cp
+
+        dims = []
+        names = []
+        for d, name in zip(
+            [
+                self.pp,
+                self.dp_replicate,
+                dp_shard_mod_ep,
+                dp_shard_in_ep,
+                self.cp,
+                self.tp,
+            ],
+            ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"],
+        ):
+            # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping
+            # helps the MoE layers do mixed precision training
+            if d > 1 or name == "dp_shard_mod_ep":
+                dims.append(d)
+                names.append(name)
+
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+
+        # Create all the submesh here to ensure all required process groups are
+        # initialized:
+        # Mesh for data loading (no communication on this mesh)
+        dp_mesh_dim_names = []
+        # Mesh for param sharding
+        dp_shard_cp_mesh_dim_names = []
+        # Mesh for loss all-reduce
+        dp_cp_mesh_dim_names = []
+        # Mesh for ep
+        ep_mesh_dim_names = []
+
+        if self.dp_replicate_enabled:
+            dp_mesh_dim_names.append("dp_replicate")
+            dp_cp_mesh_dim_names.append("dp_replicate")
+        # dp_shard_mod_ep is always needed, even if it's 1
+        dp_mesh_dim_names.append("dp_shard_mod_ep")
+        dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep")
+        dp_cp_mesh_dim_names.append("dp_shard_mod_ep")
+        if "dp_shard_in_ep" in names:
+            dp_mesh_dim_names.append("dp_shard_in_ep")
+            dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep")
+            dp_cp_mesh_dim_names.append("dp_shard_in_ep")
+            ep_mesh_dim_names.append("dp_shard_in_ep")
+        if self.cp_enabled:
+            dp_shard_cp_mesh_dim_names.append("cp")
+            dp_cp_mesh_dim_names.append("cp")
+            ep_mesh_dim_names.append("cp")
+
+        mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
+        mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp")
+        mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
+        mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")
+
+        return mesh
+
     def build_mesh(self, device_type: str) -> DeviceMesh:
+        if self.ep > 1:
+            return self._build_mesh_with_ep(device_type)
+
         dims = []
         names = []
         for d, name in zip(
@@ -143,3 +216,7 @@ def loss_parallel_enabled(self):
     @cached_property
     def non_data_parallel_size(self):
         return self.cp * self.tp * self.pp
+
+    @property
+    def ep_enabled(self):
+        return self.ep > 1
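
To make the derived mesh sizes concrete, here is a standalone sketch of the arithmetic from `_validate` and `_build_mesh_with_ep` for one assumed configuration (illustrative numbers; no device mesh is created):

```python
# Assumed degrees: 16 GPUs, FSDP sharding of 8, CP of 2, EP of 4.
pp, dp_replicate, dp_shard, cp, tp, ep = 1, 1, 8, 2, 1, 4
world_size = pp * dp_replicate * dp_shard * cp * tp
assert world_size == 16

# _validate(): EP borrows all of cp and part of dp_shard.
assert ep % cp == 0 and (dp_shard * cp) % ep == 0

# _build_mesh_with_ep(): dp_shard and ep become derived submeshes.
dp_shard_mod_ep = dp_shard * cp // ep  # 8 * 2 // 4 = 4
dp_shard_in_ep = ep // cp              # 4 // 2 = 2

# Mesh dims actually materialized (dp_shard_mod_ep is kept even at size 1).
dims, names = [], []
for d, name in zip(
    [pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp],
    ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"],
):
    if d > 1 or name == "dp_shard_mod_ep":
        dims.append(d)
        names.append(name)

print(names, dims)
# ['dp_shard_mod_ep', 'dp_shard_in_ep', 'cp'] [4, 2, 2]
# Flattened groups: dp = 4*2 = 8, dp_shard_cp = 4*2*2 = 16, ep = 2*2 = 4.
```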
