[DSV3] Apply TP on DSV3 #1341

Merged · 12 commits · Jul 2, 2025
2 changes: 1 addition & 1 deletion docs/checkpoint.md
@@ -83,5 +83,5 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l
To create a seed checkpoint, use the same model config as you use for training.
e.g.
```bash
NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1
NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
```
8 changes: 6 additions & 2 deletions scripts/estimate/estimation.py
@@ -46,6 +46,7 @@ def estimate_memory(job_config: JobConfig):
cp=parallelism_config.context_parallel_degree,
tp=parallelism_config.tensor_parallel_degree,
pp=parallelism_config.pipeline_parallel_degree,
ep=parallelism_config.expert_parallel_degree,
world_size=world_size,
enable_loss_parallel=not parallelism_config.disable_loss_parallel,
)
@@ -56,8 +57,9 @@ def estimate_memory(job_config: JobConfig):
or parallel_dims.tp_enabled
or parallel_dims.pp_enabled
or parallel_dims.cp_enabled
or parallel_dims.ep_enabled
):
logger.warning("DDP, TP, PP, CP are not supported yet.")
logger.warning("DDP, TP, PP, CP, EP are not supported yet.")
return
if not parallel_dims.dp_shard_enabled:
logger.warning("FSDP or HSDP is not enabled. Skipping memory estimation.")
@@ -115,7 +117,9 @@ def estimate_memory(job_config: JobConfig):

# build optimizer after applying parallelisms to the model
ft_manager = init_ft_manager(job_config)
optimizers = build_optimizers([model], job_config, ft_manager)
optimizers = build_optimizers(
[model], job_config, parallel_dims, world_mesh, ft_manager
)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
# Post optimizer step model converters hook.
# e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
1 change: 1 addition & 0 deletions scripts/generate/test_generate.py
@@ -125,6 +125,7 @@ def test_generate(
cp=1,
tp=world_size,
pp=1,
ep=1,
world_size=world_size,
enable_loss_parallel=False,
)
1 change: 1 addition & 0 deletions tests/unit_tests/test_model_converter.py
@@ -21,6 +21,7 @@ def build_parallel_dims(job_config, world_size):
cp=parallelism_config.context_parallel_degree,
tp=parallelism_config.tensor_parallel_degree,
pp=parallelism_config.pipeline_parallel_degree,
ep=parallelism_config.expert_parallel_degree,
world_size=world_size,
enable_loss_parallel=not parallelism_config.disable_loss_parallel,
)
37 changes: 0 additions & 37 deletions torchtitan/components/ft.py
@@ -7,7 +7,6 @@
import copy
import importlib
from contextlib import nullcontext
from dataclasses import dataclass
from typing import ContextManager, Optional, TYPE_CHECKING, Union

import torch
@@ -18,7 +17,6 @@
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.tensor import DTensor
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims

if importlib.util.find_spec("torchft") is not None:
import torchft as ft
@@ -106,41 +104,6 @@ def init_ft_manager(job: JobConfig) -> FTManager:
)


@dataclass
class FTParallelDims(ParallelDims):
ft_manager: FTManager

def build_mesh(self, device_type: str) -> DeviceMesh:
def func(
device_type: str, mesh_shape: list[int], mesh_dim_names: list[str]
) -> DeviceMesh:
from torchft.process_group import ft_init_device_mesh

return ft_init_device_mesh(
device_type=device_type,
mesh_shape=mesh_shape,
mesh_dim_names=mesh_dim_names,
replicate_dim=mesh_dim_names.index("dp_replicate"),
manager=self.ft_manager.manager,
)

dims = []
names = []
for d, name in zip(
[self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
["pp", "dp_replicate", "dp_shard", "cp", "tp"],
):
if d > 1 or name == "dp_replicate":
dims.append(d)
names.append(name)

return self._build_mesh(device_type, dims, names, func)

@property
def dp_replicate_enabled(self):
return True


def ft_dist_reduce(
x: torch.Tensor, reduceOp: str, mesh: DeviceMesh
) -> tuple[torch.Tensor, str, DeviceMesh]:
123 changes: 113 additions & 10 deletions torchtitan/components/optimizer.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import functools
from itertools import chain
from typing import Any, Generic, Iterator, TypeVar

import torch
@@ -15,10 +16,13 @@
StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor
from torch.optim import Optimizer

from torchtitan.components.ft import FTManager, has_torchft
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims

__all__ = [
"OptimizersContainer",
@@ -238,9 +242,85 @@ def zero_grad(self, *args, **kwargs) -> None:
super().zero_grad(*args, **kwargs)


class ExpertParallelOptimizersContainer(OptimizersContainer):
"""
This class is created to support a fused optimizer implementation for Expert Parallel.
Since in EP not all parameters are sharded on the same DeviceMesh, the base
OptimizersContainer cannot perform fused optimizer steps on all DTensor parameters.
In this class, we create two optimizers for each model part, one for EP params and the
other for non-EP params. Parameters in the same optimizer are always on the same
DeviceMesh, so that the fused optimizer step can be performed.
"""

def __init__(
self,
model_parts: list[nn.Module],
optimizer_cls: type[T],
optimizer_kwargs: dict[str, Any],
dense_params_mesh_ndim: int,
) -> None:
ep_params, non_ep_params = [], []
self.ep_optimizers = []
self.non_ep_optimizers = []

self.model_parts = model_parts
# This is still needed to
# 1. reuse other OptimizersContainer's methods other than state dict save / load
# 2. define LR schedulers
self.optimizers = []

for model in self.model_parts:
for p in model.parameters():
if not p.requires_grad:
continue
assert isinstance(p, DTensor)
if p.device_mesh.ndim == dense_params_mesh_ndim:
non_ep_params.append(p)
else:
ep_params.append(p)

ep_optimizer = optimizer_cls(ep_params, **optimizer_kwargs)
non_ep_optimizers = optimizer_cls(non_ep_params, **optimizer_kwargs)
self.ep_optimizers.append(ep_optimizer)
self.non_ep_optimizers.append(non_ep_optimizers)
self.optimizers.append(ep_optimizer)
self.optimizers.append(non_ep_optimizers)

# NOTE: each model part has two optimizers, one for ep params
# and the other for non-ep params
self._validate_length(len(self.model_parts) * 2)
self._post_init(ep_params, optimizer_kwargs)
self._post_init(non_ep_params, optimizer_kwargs)

def state_dict(self) -> dict[str, Any]:
func = functools.partial(
get_optimizer_state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
return {
k: v
for sd in chain(
map(func, self.model_parts, self.ep_optimizers),
map(func, self.model_parts, self.non_ep_optimizers),
)
for k, v in sd.items()
}

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
func = functools.partial(
set_optimizer_state_dict,
optim_state_dict=state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
list(map(func, self.model_parts, self.ep_optimizers))
list(map(func, self.model_parts, self.non_ep_optimizers))


def build_optimizers(
model_parts: list[nn.Module],
job_config: JobConfig,
parallel_dims: ParallelDims,
world_mesh: DeviceMesh,
ft_manager: FTManager,
) -> OptimizersContainer:
"""Create a OptimizersContainer for the given model parts and job config.
@@ -259,12 +339,23 @@ def build_optimizers(
Args:
model_parts (List[nn.Module]): List of model parts to be optimized.
job_config (JobConfig): Job config containing the optimizer name and parameters.
parallel_dims (ParallelDims): Parallel dimensions for the model.
"""
optim_in_bwd = job_config.optimizer.early_step_in_backward
if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1:
raise NotImplementedError(
"Optimizers in backward is not supported with pipeline parallelism."
)
if optim_in_bwd:
if parallel_dims.ep_enabled:
raise NotImplementedError(
"Optimizers in backward is not supported with Expert Parallel."
)
if parallel_dims.pp_enabled:
raise NotImplementedError(
"Optimizers in backward is not supported with Pipeline Parallel."
)
if ft_manager.enabled:
raise NotImplementedError(
"TorchFT is not supported with optimizers in backward."
)

name = job_config.optimizer.name
lr = job_config.optimizer.lr
beta1 = job_config.optimizer.beta1
@@ -295,19 +386,31 @@ def build_optimizers(
raise NotImplementedError(f"Optimizer {name} not added.")
optimizer_cls = optimizer_classes[name]

if optim_in_bwd and ft_manager.enabled:
raise ValueError("TorchFT is not supported with optimizers in backward.")
elif optim_in_bwd:
if optim_in_bwd:
return OptimizersInBackwardContainer(
model_parts, optimizer_cls, optimizer_kwargs
)
elif ft_manager.enabled:

if ft_manager.enabled:
return FTOptimizersContainer(
model_parts,
optimizer_cls,
optimizer_kwargs,
ft_manager.manager,
use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None,
)
else:
return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)

if parallel_dims.ep_enabled and fused:
if ft_manager.enabled:
raise NotImplementedError(
"Expert Parallel with fused optimizer implementation "
"is not supported with TorchFT yet."
)
return ExpertParallelOptimizersContainer(
model_parts,
optimizer_cls,
optimizer_kwargs,
parallel_dims.dense_params_mesh_ndim,
)

return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
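
To make the grouping logic easier to follow, here is a small, runnable toy sketch (not part of the diff) of the idea behind `ExpertParallelOptimizersContainer`: parameters are split into EP and non-EP groups so that each optimizer only holds parameters living on a single mesh. The name-based predicate and module names below are purely illustrative stand-ins for the real `p.device_mesh.ndim == dense_params_mesh_ndim` check.

```python
# Toy illustration only: simulate the EP / non-EP split with a name-based tag
# instead of DTensor device meshes, then step one optimizer per group, which is
# what keeps each (potentially fused) optimizer's parameters homogeneous.
import torch
import torch.nn as nn

model = nn.ModuleDict(
    {
        "dense": nn.Linear(8, 8),    # stands in for non-EP (dense) weights
        "experts": nn.Linear(8, 8),  # stands in for EP-sharded expert weights
    }
)

def is_ep_param(name: str) -> bool:
    # Real check: p.device_mesh.ndim != dense_params_mesh_ndim (see class above).
    return name.startswith("experts")

ep_params = [p for n, p in model.named_parameters() if is_ep_param(n)]
non_ep_params = [p for n, p in model.named_parameters() if not is_ep_param(n)]

# One optimizer per group; in the real container both use the same optimizer_cls
# and optimizer_kwargs, and the split is what makes a fused step safe per group.
ep_opt = torch.optim.AdamW(ep_params, lr=3e-4)
non_ep_opt = torch.optim.AdamW(non_ep_params, lr=3e-4)

x = torch.randn(2, 8)
loss = (model["dense"](x) + model["experts"](x)).sum()
loss.backward()
for opt in (ep_opt, non_ep_opt):  # the container steps its optimizers the same way
    opt.step()
    opt.zero_grad()
```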
1 change: 1 addition & 0 deletions torchtitan/components/tokenizer.py
@@ -11,6 +11,7 @@
from typing import Any, Optional

from tokenizers import AddedToken, Tokenizer as HfTokenizer

from typing_extensions import override


8 changes: 8 additions & 0 deletions torchtitan/config_manager.py
@@ -363,6 +363,14 @@ class Parallelism:
The default value is 'allgather'.
"""

expert_parallel_degree: int = 1
"""
Expert parallelism degree. 1 means disabled.
Currently, only "dp2ep" is supported, with the following constraints:
context_parallel_degree <= expert_parallel_degree <= data_parallel_shard_degree * context_parallel_degree
Note that this is still an experimental feature.
"""


@dataclass
class Checkpoint:
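
For context on how the new option is meant to be used, here is a hypothetical launch command in the style of the existing `run_train.sh` examples; the GPU count and degrees are illustrative and chosen to satisfy the documented constraint `cp <= ep <= dp_shard * cp` (1 <= 4 <= 8).

```bash
# Illustrative only (not from the PR): dp2ep expert parallelism on 8 GPUs,
# reusing part of the FSDP shard dimension for the expert dimension.
NGPU=8 CONFIG=<path_to_model_config> ./run_train.sh \
  --parallelism.data_parallel_shard_degree 8 \
  --parallelism.context_parallel_degree 1 \
  --parallelism.expert_parallel_degree 4
```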