dp2ep Expert Parallel #1324

Merged: 1 commit, merged on Jul 8, 2025
2 changes: 1 addition & 1 deletion docs/checkpoint.md
@@ -83,5 +83,5 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l
To create a seed checkpoint, use the same model config as you use for training.
e.g.
```bash
NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1
NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
```
2 changes: 1 addition & 1 deletion docs/debugging.md
@@ -100,7 +100,7 @@ For multiple experimental runs with different parallelism configs, we need to us


```bash
NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1
NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
```

**Note**: Using a seed checkpoint will only make sure a model has same initial weights when configs change, but the training process may not be the same even after setting the seed and the `deterministic` mode, e.g. due to tensor shape change, data precision change, usage of randomness in model code, etc.
8 changes: 6 additions & 2 deletions scripts/estimate/estimation.py
@@ -46,6 +46,7 @@ def estimate_memory(job_config: JobConfig):
cp=parallelism_config.context_parallel_degree,
tp=parallelism_config.tensor_parallel_degree,
pp=parallelism_config.pipeline_parallel_degree,
ep=parallelism_config.expert_parallel_degree,
world_size=world_size,
enable_loss_parallel=not parallelism_config.disable_loss_parallel,
)
@@ -56,8 +57,9 @@ def estimate_memory(job_config: JobConfig):
or parallel_dims.tp_enabled
or parallel_dims.pp_enabled
or parallel_dims.cp_enabled
or parallel_dims.ep_enabled
):
logger.warning("DDP, TP, PP, CP are not supported yet.")
logger.warning("DDP, TP, PP, CP, EP are not supported yet.")
return
if not parallel_dims.dp_shard_enabled:
logger.warning("FSDP or HSDP is not enabled. Skipping memory estimation.")
@@ -115,7 +117,9 @@ def estimate_memory(job_config: JobConfig):

# build optimizer after applying parallelisms to the model
ft_manager = init_ft_manager(job_config)
optimizers = build_optimizers([model], job_config, ft_manager)
optimizers = build_optimizers(
[model], job_config, parallel_dims, world_mesh, ft_manager
)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
# Post optimizer step model converters hook.
# e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
1 change: 1 addition & 0 deletions scripts/generate/test_generate.py
@@ -125,6 +125,7 @@ def test_generate(
cp=1,
tp=world_size,
pp=1,
ep=1,
world_size=world_size,
enable_loss_parallel=False,
)
1 change: 1 addition & 0 deletions tests/unit_tests/test_model_converter.py
@@ -21,6 +21,7 @@ def build_parallel_dims(job_config, world_size):
cp=parallelism_config.context_parallel_degree,
tp=parallelism_config.tensor_parallel_degree,
pp=parallelism_config.pipeline_parallel_degree,
ep=parallelism_config.expert_parallel_degree,
world_size=world_size,
enable_loss_parallel=not parallelism_config.disable_loss_parallel,
)
37 changes: 0 additions & 37 deletions torchtitan/components/ft.py
@@ -7,7 +7,6 @@
import copy
import importlib
from contextlib import nullcontext
from dataclasses import dataclass
from typing import ContextManager, Optional, TYPE_CHECKING, Union

import torch
@@ -18,7 +17,6 @@
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.tensor import DTensor
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims

if importlib.util.find_spec("torchft") is not None:
import torchft as ft
@@ -106,41 +104,6 @@ def init_ft_manager(job: JobConfig) -> FTManager:
)


@dataclass
class FTParallelDims(ParallelDims):
ft_manager: FTManager

def build_mesh(self, device_type: str) -> DeviceMesh:
def func(
device_type: str, mesh_shape: list[int], mesh_dim_names: list[str]
) -> DeviceMesh:
from torchft.process_group import ft_init_device_mesh

return ft_init_device_mesh(
device_type=device_type,
mesh_shape=mesh_shape,
mesh_dim_names=mesh_dim_names,
replicate_dim=mesh_dim_names.index("dp_replicate"),
manager=self.ft_manager.manager,
)

dims = []
names = []
for d, name in zip(
[self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
["pp", "dp_replicate", "dp_shard", "cp", "tp"],
):
if d > 1 or name == "dp_replicate":
dims.append(d)
names.append(name)

return self._build_mesh(device_type, dims, names, func)

@property
def dp_replicate_enabled(self):
return True


def ft_dist_reduce(
x: torch.Tensor, reduceOp: str, mesh: DeviceMesh
) -> tuple[torch.Tensor, str, DeviceMesh]:
34 changes: 24 additions & 10 deletions torchtitan/components/optimizer.py
@@ -15,10 +15,12 @@
StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.device_mesh import DeviceMesh
from torch.optim import Optimizer

from torchtitan.components.ft import FTManager, has_torchft
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims

__all__ = [
"OptimizersContainer",
@@ -241,6 +243,8 @@ def zero_grad(self, *args, **kwargs) -> None:
def build_optimizers(
model_parts: list[nn.Module],
job_config: JobConfig,
parallel_dims: ParallelDims,
world_mesh: DeviceMesh,
ft_manager: FTManager,
) -> OptimizersContainer:
"""Create a OptimizersContainer for the given model parts and job config.
@@ -259,12 +263,23 @@ def build_optimizers(
Args:
model_parts (List[nn.Module]): List of model parts to be optimized.
job_config (JobConfig): Job config containing the optimizer name and parameters.
parallel_dims (ParallelDims): Parallel dimensions for the model.
"""
optim_in_bwd = job_config.optimizer.early_step_in_backward
if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1:
raise NotImplementedError(
"Optimizers in backward is not supported with pipeline parallelism."
)
if optim_in_bwd:
if parallel_dims.ep_enabled:
raise NotImplementedError(
"Optimizers in backward is not supported with Expert Parallel."
)
if parallel_dims.pp_enabled:
raise NotImplementedError(
"Optimizers in backward is not supported with Pipeline Parallel."
)
if ft_manager.enabled:
raise NotImplementedError(
"TorchFT is not supported with optimizers in backward."
)

name = job_config.optimizer.name
lr = job_config.optimizer.lr
beta1 = job_config.optimizer.beta1
@@ -295,19 +310,18 @@ def build_optimizers(
raise NotImplementedError(f"Optimizer {name} not added.")
optimizer_cls = optimizer_classes[name]

if optim_in_bwd and ft_manager.enabled:
raise ValueError("TorchFT is not supported with optimizers in backward.")
elif optim_in_bwd:
if optim_in_bwd:
return OptimizersInBackwardContainer(
model_parts, optimizer_cls, optimizer_kwargs
)
elif ft_manager.enabled:

if ft_manager.enabled:
return FTOptimizersContainer(
model_parts,
optimizer_cls,
optimizer_kwargs,
ft_manager.manager,
use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None,
)
else:
return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)

return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
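The `build_optimizers` signature above now takes `parallel_dims` and `world_mesh`, which is why call sites such as `scripts/estimate/estimation.py` are updated in this PR. Below is a minimal sketch of the new call pattern, assuming torchtitan with this PR is installed; the wrapper `setup_optimizers`, the hard-coded degrees, and the device type are illustrative assumptions, not code from the PR:

```python
import torch.nn as nn

from torchtitan.components.ft import init_ft_manager
from torchtitan.components.optimizer import build_optimizers
from torchtitan.distributed import ParallelDims


def setup_optimizers(model: nn.Module, job_config, world_size: int, device_type: str = "cuda"):
    # Degrees here are placeholders; a real script reads them from
    # job_config.parallelism, as estimation.py does.
    parallel_dims = ParallelDims(
        dp_replicate=1,
        dp_shard=-1,  # -1 lets ParallelDims derive the FSDP degree from world_size
        cp=1,
        tp=1,
        pp=1,
        ep=1,
        world_size=world_size,
        enable_loss_parallel=not job_config.parallelism.disable_loss_parallel,
    )
    world_mesh = parallel_dims.build_mesh(device_type)
    ft_manager = init_ft_manager(job_config)

    # New signature: parallel_dims and world_mesh are passed alongside ft_manager.
    optimizers = build_optimizers(
        [model], job_config, parallel_dims, world_mesh, ft_manager
    )
    return optimizers
```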
8 changes: 8 additions & 0 deletions torchtitan/config_manager.py
@@ -363,6 +363,14 @@ class Parallelism:
The default value is 'allgather'.
"""

expert_parallel_degree: int = 1
"""
Expert parallelism degree. 1 means disabled.
Currently, only "dp2ep" is supported, with the following constraints:
context_parallel_degree <= expert_parallel_degree <= data_parallel_shard_degree * context_parallel_degree
Note that this is still an experimental feature.
"""


@dataclass
class Checkpoint:
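For reference, the constraint stated in the `expert_parallel_degree` docstring above can be illustrated with a short standalone sketch; the helper name `check_ep_constraints` and the example degrees are hypothetical and not part of this PR:

```python
def check_ep_constraints(dp_shard: int, cp: int, ep: int) -> None:
    """Mirror the dp2ep constraint: cp <= ep <= dp_shard * cp, with clean divisibility."""
    if ep <= 1:
        return  # EP disabled, nothing to check
    # EP borrows all of cp and part of dp_shard, so ep must be a multiple of cp
    # and dp_shard * cp must be a multiple of ep.
    assert ep % cp == 0, f"ep ({ep}) must be a multiple of cp ({cp})"
    assert (dp_shard * cp) % ep == 0, (
        f"dp_shard ({dp_shard}) * cp ({cp}) must be divisible by ep ({ep})"
    )


# Example: dp_shard=4, cp=2, ep=4 satisfies 2 <= 4 <= 8.
check_ep_constraints(dp_shard=4, cp=2, ep=4)
```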
104 changes: 91 additions & 13 deletions torchtitan/distributed/parallel_dims.py
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property

@@ -23,21 +22,23 @@ class ParallelDims:
cp: int
tp: int
pp: int
ep: int
world_size: int
enable_loss_parallel: bool

def __post_init__(self):
self._validate()

def _validate(self):
dp_replicate, dp_shard, cp, tp, pp = (
dp_replicate, dp_shard, cp, tp, pp, ep = (
self.dp_replicate,
self.dp_shard,
self.cp,
self.tp,
self.pp,
self.ep,
)
for d in (dp_replicate, cp, tp, pp):
for d in (dp_replicate, cp, tp, pp, ep):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
@@ -50,7 +51,84 @@ def _validate(self):
f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
)

if ep > 1:
# EP would borrow all cp and some dp_shard degree
assert ep % cp == 0 and (dp_shard * cp) % ep == 0

def build_mesh(self, device_type: str) -> DeviceMesh:
# TODO: Current implementation of ParallelDims for dp2ep Expert Parallel
# is not very clean, due to the limited support from DeviceMesh
# for creating two staggered meshes. Will improve.
if self.ep > 1:
return self._build_mesh_with_ep(device_type)
else:
return self._build_mesh_without_ep(device_type)

def _build_mesh_with_ep(self, device_type: str) -> DeviceMesh:
# With ep, dp_shard and ep are derived submeshes:
# dp_shard = dp_shard_mod_ep * dp_shard_in_ep
# ep = dp_shard_in_ep * cp
dp_shard_mod_ep = self.dp_shard * self.cp // self.ep
dp_shard_in_ep = self.ep // self.cp

dims = []
names = []
for d, name in zip(
[
self.pp,
self.dp_replicate,
dp_shard_mod_ep,
dp_shard_in_ep,
self.cp,
self.tp,
],
["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"],
):
# dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping
# helps the MoE layers do mixed precision training
if d > 1 or name == "dp_shard_mod_ep":
dims.append(d)
names.append(name)

logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)

# Create all the submesh here to ensure all required process groups are
# initialized:
# Mesh for data loading (no communication on this mesh)
dp_mesh_dim_names = []
# Mesh for param sharding
dp_shard_cp_mesh_dim_names = []
# Mesh for loss all-reduce
dp_cp_mesh_dim_names = []
# Mesh for ep
ep_mesh_dim_names = []

if self.dp_replicate_enabled:
dp_mesh_dim_names.append("dp_replicate")
dp_cp_mesh_dim_names.append("dp_replicate")
# dp_shard_mod_ep is always needed, even if it's 1
dp_mesh_dim_names.append("dp_shard_mod_ep")
dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep")
dp_cp_mesh_dim_names.append("dp_shard_mod_ep")
if "dp_shard_in_ep" in names:
dp_mesh_dim_names.append("dp_shard_in_ep")
dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep")
dp_cp_mesh_dim_names.append("dp_shard_in_ep")
ep_mesh_dim_names.append("dp_shard_in_ep")
if self.cp_enabled:
dp_shard_cp_mesh_dim_names.append("cp")
dp_cp_mesh_dim_names.append("cp")
ep_mesh_dim_names.append("cp")

mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp")
mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")

return mesh

def _build_mesh_without_ep(self, device_type: str) -> DeviceMesh:
dims = []
names = []
for d, name in zip(
@@ -61,17 +139,8 @@ def build_mesh(self, device_type: str) -> DeviceMesh:
dims.append(d)
names.append(name)

return self._build_mesh(device_type, dims, names, init_device_mesh)

def _build_mesh(
self,
device_type: str,
dims: list[int],
names: list[str],
init_device_mesh_fn: Callable,
) -> DeviceMesh:
logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
mesh = init_device_mesh_fn(device_type, dims, mesh_dim_names=names)
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)

# Create all the submesh here to ensure all required process groups are
# initialized:
@@ -143,3 +212,12 @@ def loss_parallel_enabled(self):
@cached_property
def non_data_parallel_size(self):
return self.cp * self.tp * self.pp

@property
def ep_enabled(self):
return self.ep > 1

@property
def dense_params_mesh_ndim(self):
# Note: EP params mesh ndim is 1 more due to the 'ep' mesh
return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled
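To make the submesh arithmetic in `_build_mesh_with_ep` concrete, here is a hedged worked example; the 16-GPU configuration (dp_shard=8, cp=2, ep=4) is hypothetical, and only the dataclass validation and the derived factors are exercised, with no process groups initialized:

```python
from torchtitan.distributed import ParallelDims

# Hypothetical 16-GPU run: FSDP degree 8, context parallel 2, expert parallel 4.
pd = ParallelDims(
    dp_replicate=1, dp_shard=8, cp=2, tp=1, pp=1, ep=4,
    world_size=16, enable_loss_parallel=False,
)
assert pd.ep_enabled

# The same factorization used by _build_mesh_with_ep:
#   dp_shard * cp = dp_shard_mod_ep * dp_shard_in_ep * cp, with ep = dp_shard_in_ep * cp
dp_shard_mod_ep = pd.dp_shard * pd.cp // pd.ep  # 8 * 2 // 4 = 4  (FSDP-only slice)
dp_shard_in_ep = pd.ep // pd.cp                 # 4 // 2 = 2      (slice borrowed by EP)

assert dp_shard_mod_ep * dp_shard_in_ep == pd.dp_shard
assert dp_shard_in_ep * pd.cp == pd.ep

# With pp = dp_replicate = tp = 1, build_mesh would create mesh dims
#   ["dp_shard_mod_ep", "dp_shard_in_ep", "cp"] -> [4, 2, 2]
# and flatten submeshes of sizes: dp = 8, dp_shard_cp = 16, dp_cp = 16, ep = 4.
print(dp_shard_mod_ep, dp_shard_in_ep)
```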