From b0dffa16b278749b5ed12024c1223ad65963bf15 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 25 Jun 2025 22:51:48 -0700 Subject: [PATCH] dp2ep Expert Parallel --- docs/checkpoint.md | 2 +- docs/debugging.md | 2 +- scripts/estimate/estimation.py | 8 +- scripts/generate/test_generate.py | 1 + tests/unit_tests/test_model_converter.py | 1 + torchtitan/components/ft.py | 37 --- torchtitan/components/optimizer.py | 34 +- torchtitan/config_manager.py | 8 + torchtitan/distributed/parallel_dims.py | 104 +++++- torchtitan/distributed/utils.py | 70 +++- torchtitan/experiments/llama4/README.md | 2 +- torchtitan/experiments/llama4/__init__.py | 4 +- .../llama4/infra/expert_parallel.py | 249 ++++++++++++--- .../experiments/llama4/infra/parallelize.py | 300 ++++++++++++++---- torchtitan/experiments/llama4/model/moe.py | 249 +++++++-------- torchtitan/experiments/llama4/optimizer.py | 68 ++++ .../llama4/train_configs/debug_model.toml | 1 + torchtitan/models/llama3/infra/parallelize.py | 5 +- torchtitan/protocols/train_spec.py | 5 +- torchtitan/train.py | 4 +- 20 files changed, 848 insertions(+), 306 deletions(-) create mode 100644 torchtitan/experiments/llama4/optimizer.py diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 0dad44e67..5275db1a2 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -83,5 +83,5 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l To create a seed checkpoint, use the same model config as you use for training. e.g. ```bash -NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 +NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1 ``` diff --git a/docs/debugging.md b/docs/debugging.md index f5a619520..61795c70e 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -100,7 +100,7 @@ For multiple experimental runs with different parallelism configs, we need to us ```bash -NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 +NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1 ``` **Note**: Using a seed checkpoint will only make sure a model has same initial weights when configs change, but the training process may not be the same even after setting the seed and the `deterministic` mode, e.g. due to tensor shape change, data precision change, usage of randomness in model code, etc. 
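Note: the new `--parallelism.expert_parallel_degree` flag shown in these docs maps to an extra `ep` field on `ParallelDims`. A minimal sketch, with hypothetical degrees, of how the updated call sites below construct it:

```python
# Hypothetical degrees for illustration; the product of all degrees must still
# equal world_size, and EP borrows its ranks from dp_shard * cp.
from torchtitan.distributed import ParallelDims

parallel_dims = ParallelDims(
    dp_replicate=1,
    dp_shard=4,
    cp=1,
    tp=1,
    pp=1,
    ep=2,  # new in this patch: dp2ep Expert Parallel degree, 1 disables EP
    world_size=4,
    enable_loss_parallel=True,
)
assert parallel_dims.ep_enabled  # property added in parallel_dims.py below
# parallel_dims.build_mesh("cuda") would create the EP-aware device mesh,
# but requires an initialized process group.
```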
diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 9dac8f17b..33d4dc17f 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -46,6 +46,7 @@ def estimate_memory(job_config: JobConfig): cp=parallelism_config.context_parallel_degree, tp=parallelism_config.tensor_parallel_degree, pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, world_size=world_size, enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) @@ -56,8 +57,9 @@ def estimate_memory(job_config: JobConfig): or parallel_dims.tp_enabled or parallel_dims.pp_enabled or parallel_dims.cp_enabled + or parallel_dims.ep_enabled ): - logger.warning("DDP, TP, PP, CP are not supported yet.") + logger.warning("DDP, TP, PP, CP, EP are not supported yet.") return if not parallel_dims.dp_shard_enabled: logger.warning("FSDP or HSDP is not enabled. Skipping memory estimation.") @@ -115,7 +117,9 @@ def estimate_memory(job_config: JobConfig): # build optimizer after applying parallelisms to the model ft_manager = init_ft_manager(job_config) - optimizers = build_optimizers([model], job_config, ft_manager) + optimizers = build_optimizers( + [model], job_config, parallel_dims, world_mesh, ft_manager + ) lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 0a1649ea4..b6e2e45dc 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -125,6 +125,7 @@ def test_generate( cp=1, tp=world_size, pp=1, + ep=1, world_size=world_size, enable_loss_parallel=False, ) diff --git a/tests/unit_tests/test_model_converter.py b/tests/unit_tests/test_model_converter.py index bb5a1ecc4..704e81a91 100644 --- a/tests/unit_tests/test_model_converter.py +++ b/tests/unit_tests/test_model_converter.py @@ -21,6 +21,7 @@ def build_parallel_dims(job_config, world_size): cp=parallelism_config.context_parallel_degree, tp=parallelism_config.tensor_parallel_degree, pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, world_size=world_size, enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft.py index 11f13df11..946fc4638 100644 --- a/torchtitan/components/ft.py +++ b/torchtitan/components/ft.py @@ -7,7 +7,6 @@ import copy import importlib from contextlib import nullcontext -from dataclasses import dataclass from typing import ContextManager, Optional, TYPE_CHECKING, Union import torch @@ -18,7 +17,6 @@ from torch.distributed.distributed_c10d import ReduceOp from torch.distributed.tensor import DTensor from torchtitan.config_manager import JobConfig -from torchtitan.distributed import ParallelDims if importlib.util.find_spec("torchft") is not None: import torchft as ft @@ -106,41 +104,6 @@ def init_ft_manager(job: JobConfig) -> FTManager: ) -@dataclass -class FTParallelDims(ParallelDims): - ft_manager: FTManager - - def build_mesh(self, device_type: str) -> DeviceMesh: - def func( - device_type: str, mesh_shape: list[int], mesh_dim_names: list[str] - ) -> DeviceMesh: - from torchft.process_group import ft_init_device_mesh - - return ft_init_device_mesh( - device_type=device_type, - mesh_shape=mesh_shape, - mesh_dim_names=mesh_dim_names, - replicate_dim=mesh_dim_names.index("dp_replicate"), - 
manager=self.ft_manager.manager, - ) - - dims = [] - names = [] - for d, name in zip( - [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], - ["pp", "dp_replicate", "dp_shard", "cp", "tp"], - ): - if d > 1 or name == "dp_replicate": - dims.append(d) - names.append(name) - - return self._build_mesh(device_type, dims, names, func) - - @property - def dp_replicate_enabled(self): - return True - - def ft_dist_reduce( x: torch.Tensor, reduceOp: str, mesh: DeviceMesh ) -> tuple[torch.Tensor, str, DeviceMesh]: diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index a88f33fa1..cd3604f29 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -15,10 +15,12 @@ StateDictOptions, ) from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.device_mesh import DeviceMesh from torch.optim import Optimizer from torchtitan.components.ft import FTManager, has_torchft from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims __all__ = [ "OptimizersContainer", @@ -241,6 +243,8 @@ def zero_grad(self, *args, **kwargs) -> None: def build_optimizers( model_parts: list[nn.Module], job_config: JobConfig, + parallel_dims: ParallelDims, + world_mesh: DeviceMesh, ft_manager: FTManager, ) -> OptimizersContainer: """Create a OptimizersContainer for the given model parts and job config. @@ -259,12 +263,23 @@ def build_optimizers( Args: model_parts (List[nn.Module]): List of model parts to be optimized. job_config (JobConfig): Job config containing the optimizer name and parameters. + parallel_dims (ParallelDims): Parallel dimensions for the model. """ optim_in_bwd = job_config.optimizer.early_step_in_backward - if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1: - raise NotImplementedError( - "Optimizers in backward is not supported with pipeline parallelism." - ) + if optim_in_bwd: + if parallel_dims.ep_enabled: + raise NotImplementedError( + "Optimizers in backward is not supported with Expert Parallel." + ) + if parallel_dims.pp_enabled: + raise NotImplementedError( + "Optimizers in backward is not supported with Pipeline Parallel." + ) + if ft_manager.enabled: + raise NotImplementedError( + "TorchFT is not supported with optimizers in backward." + ) + name = job_config.optimizer.name lr = job_config.optimizer.lr beta1 = job_config.optimizer.beta1 @@ -295,13 +310,12 @@ def build_optimizers( raise NotImplementedError(f"Optimizer {name} not added.") optimizer_cls = optimizer_classes[name] - if optim_in_bwd and ft_manager.enabled: - raise ValueError("TorchFT is not supported with optimizers in backward.") - elif optim_in_bwd: + if optim_in_bwd: return OptimizersInBackwardContainer( model_parts, optimizer_cls, optimizer_kwargs ) - elif ft_manager.enabled: + + if ft_manager.enabled: return FTOptimizersContainer( model_parts, optimizer_cls, @@ -309,5 +323,5 @@ def build_optimizers( ft_manager.manager, use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None, ) - else: - return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) + + return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 3f8d25688..d40a5982f 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -363,6 +363,14 @@ class Parallelism: The default value is 'allgather'. """ + expert_parallel_degree: int = 1 + """ + Expert parallelism degree. 
1 means disabled. + Currently, only "dp2ep" is supported, with the following constraints: + context_parallel_degree <= expert_parallel_degree <= data_parallel_shard_degree * context_parallel_degree + Note that this is still an experimental feature. + """ + @dataclass class Checkpoint: diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 07a056fc4..08986b220 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from collections.abc import Callable from dataclasses import dataclass from functools import cached_property @@ -23,6 +22,7 @@ class ParallelDims: cp: int tp: int pp: int + ep: int world_size: int enable_loss_parallel: bool @@ -30,14 +30,15 @@ def __post_init__(self): self._validate() def _validate(self): - dp_replicate, dp_shard, cp, tp, pp = ( + dp_replicate, dp_shard, cp, tp, pp, ep = ( self.dp_replicate, self.dp_shard, self.cp, self.tp, self.pp, + self.ep, ) - for d in (dp_replicate, cp, tp, pp): + for d in (dp_replicate, cp, tp, pp, ep): assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." @@ -50,7 +51,84 @@ def _validate(self): f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" ) + if ep > 1: + # EP would borrow all cp and some dp_shard degree + assert ep % cp == 0 and (dp_shard * cp) % ep == 0 + def build_mesh(self, device_type: str) -> DeviceMesh: + # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel + # is not very clean, due to the limited support from DeviceMesh + # for creating two staggered meshes. Will improve. 
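To make the constraint and the mesh factorization concrete, a small sketch with hypothetical degrees (the derived names `dp_shard_mod_ep` and `dp_shard_in_ep` follow the EP-aware `build_mesh` below):

```python
# Hypothetical degrees: dp_shard=4, cp=2, ep=4 satisfies
#   cp <= ep <= dp_shard * cp  and the divisibility check in _validate above.
dp_shard, cp, ep = 4, 2, 4
assert ep % cp == 0 and (dp_shard * cp) % ep == 0

# The EP-aware mesh below factors dp_shard * cp into two derived dims:
dp_shard_mod_ep = dp_shard * cp // ep  # -> 2, FSDP dim for MoE expert params
dp_shard_in_ep = ep // cp              # -> 2, dp_shard ranks folded into EP
assert dp_shard_mod_ep * dp_shard_in_ep == dp_shard
assert dp_shard_in_ep * cp == ep
```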
+ if self.ep > 1: + return self._build_mesh_with_ep(device_type) + else: + return self._build_mesh_without_ep(device_type) + + def _build_mesh_with_ep(self, device_type: str) -> DeviceMesh: + # With ep, dp_shard and ep are derived submeshes: + # dp_shard = dp_shard_mod_ep * dp_shard_in_ep + # ep = dp_shard_in_ep * cp + dp_shard_mod_ep = self.dp_shard * self.cp // self.ep + dp_shard_in_ep = self.ep // self.cp + + dims = [] + names = [] + for d, name in zip( + [ + self.pp, + self.dp_replicate, + dp_shard_mod_ep, + dp_shard_in_ep, + self.cp, + self.tp, + ], + ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], + ): + # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping + # helps the MoE layers do mixed precision training + if d > 1 or name == "dp_shard_mod_ep": + dims.append(d) + names.append(name) + + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + + # Create all the submesh here to ensure all required process groups are + # initialized: + # Mesh for data loading (no communication on this mesh) + dp_mesh_dim_names = [] + # Mesh for param sharding + dp_shard_cp_mesh_dim_names = [] + # Mesh for loss all-reduce + dp_cp_mesh_dim_names = [] + # Mesh for ep + ep_mesh_dim_names = [] + + if self.dp_replicate_enabled: + dp_mesh_dim_names.append("dp_replicate") + dp_cp_mesh_dim_names.append("dp_replicate") + # dp_shard_mod_ep is always needed, even if it's 1 + dp_mesh_dim_names.append("dp_shard_mod_ep") + dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep") + dp_cp_mesh_dim_names.append("dp_shard_mod_ep") + if "dp_shard_in_ep" in names: + dp_mesh_dim_names.append("dp_shard_in_ep") + dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") + dp_cp_mesh_dim_names.append("dp_shard_in_ep") + ep_mesh_dim_names.append("dp_shard_in_ep") + if self.cp_enabled: + dp_shard_cp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") + ep_mesh_dim_names.append("cp") + + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") + mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") + mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") + + return mesh + + def _build_mesh_without_ep(self, device_type: str) -> DeviceMesh: dims = [] names = [] for d, name in zip( @@ -61,17 +139,8 @@ def build_mesh(self, device_type: str) -> DeviceMesh: dims.append(d) names.append(name) - return self._build_mesh(device_type, dims, names, init_device_mesh) - - def _build_mesh( - self, - device_type: str, - dims: list[int], - names: list[str], - init_device_mesh_fn: Callable, - ) -> DeviceMesh: logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh_fn(device_type, dims, mesh_dim_names=names) + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) # Create all the submesh here to ensure all required process groups are # initialized: @@ -143,3 +212,12 @@ def loss_parallel_enabled(self): @cached_property def non_data_parallel_size(self): return self.cp * self.tp * self.pp + + @property + def ep_enabled(self): + return self.ep > 1 + + @property + def dense_params_mesh_ndim(self): + # Note: EP params mesh ndim is 1 more due to the 'ep' mesh + return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 9c86fba1b..3f824d5fe 100644 --- a/torchtitan/distributed/utils.py +++ 
b/torchtitan/distributed/utils.py @@ -307,6 +307,7 @@ def clip_grad_norm_( error_if_nonfinite: bool = False, foreach: bool | None = None, pp_mesh: DeviceMesh | None = None, + parallel_dims: ParallelDims | None = None, ) -> torch.Tensor: """ Clip the gradient norm of an iterable of parameters. @@ -329,11 +330,23 @@ def clip_grad_norm_( fall back to the slow implementation for other device types. Default: ``None`` pp_mesh: pipeline parallel device mesh. If not None, will reduce gradient norm across PP stages. + parallel_dims: ParallelDims object which contains Expert Parallel related info. Returns: Total norm of the parameter gradients (viewed as a single vector). """ + if parallel_dims and parallel_dims.ep_enabled: + return _clip_grad_norm_with_ep( + parameters, + max_norm, + norm_type, + error_if_nonfinite, + foreach, + pp_mesh, + parallel_dims, + ) + if isinstance(parameters, torch.Tensor): parameters = [parameters] else: @@ -353,7 +366,6 @@ def clip_grad_norm_( if isinstance(total_norm, DTensor): # Will reach here if any non-PP parallelism is used. # If only using PP, total_norm will be a local tensor. - total_norm = total_norm.full_tensor() if pp_mesh is not None: @@ -366,3 +378,59 @@ def clip_grad_norm_( torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) return total_norm + + +@torch.no_grad() +def _clip_grad_norm_with_ep( + parameters: torch.Tensor | Iterable[torch.Tensor], + max_norm: float, + norm_type: float, + error_if_nonfinite: bool, + foreach: bool | None, + pp_mesh: DeviceMesh | None, + parallel_dims: ParallelDims, +) -> torch.Tensor: + assert parallel_dims.ep_enabled + + ep_params = [] + non_ep_params = [] + ep_grads = [] + non_ep_grads = [] + + for p in parameters: + if p.grad is None: + continue + assert isinstance(p, DTensor) and isinstance(p.grad, DTensor) + if p.device_mesh.ndim == parallel_dims.dense_params_mesh_ndim: + non_ep_params.append(p) + non_ep_grads.append(p.grad) + else: + ep_params.append(p) + ep_grads.append(p.grad) + ep_grads_total_norm = torch.nn.utils.get_total_norm( + ep_grads, norm_type, error_if_nonfinite, foreach + ).full_tensor() + non_ep_grads_total_norm = torch.nn.utils.get_total_norm( + non_ep_grads, norm_type, error_if_nonfinite, foreach + ).full_tensor() + + if math.isinf(norm_type): + total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm) + else: + total_norm = ( + ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type + ) + total_norm **= 1.0 / norm_type + + if pp_mesh is not None: + if math.isinf(norm_type): + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) + else: + total_norm **= norm_type + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) + total_norm **= 1.0 / norm_type + + torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach) + torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach) + + return total_norm diff --git a/torchtitan/experiments/llama4/README.md b/torchtitan/experiments/llama4/README.md index eb9adbb31..4b42f7c3f 100644 --- a/torchtitan/experiments/llama4/README.md +++ b/torchtitan/experiments/llama4/README.md @@ -6,6 +6,7 @@ https://github.com/pytorch/torchtitan/issues/1118 #### Available features - Llama 4 model (text-only), including a token-choice MoE architecture with efficient bfloat16 Grouped MM kernels and auxiliary-loss-free load balancing - FSDP, TP, PP, CP support +- Expert Parallel support - DCP checkpoint conversion scripts #### Download Llama 4 
tokenizer @@ -20,7 +21,6 @@ python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E - multimodal support - Parallelism - Context Parallel support for FlexAttention and multimodal inputs - - Expert Parallel support - torch.compile - for MoE layers - Quantization diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py index a3fb29375..329c4e9d7 100644 --- a/torchtitan/experiments/llama4/__init__.py +++ b/torchtitan/experiments/llama4/__init__.py @@ -6,7 +6,6 @@ from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers -from torchtitan.components.optimizer import build_optimizers from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer from torchtitan.models.llama3 import pipeline_llama @@ -15,6 +14,7 @@ from .infra.parallelize import parallelize_llama from .model.args import TransformerModelArgs from .model.model import Transformer +from .optimizer import build_llama4_optimizers __all__ = [ "TransformerModelArgs", @@ -98,7 +98,7 @@ config=llama4_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, - build_optimizers_fn=build_optimizers, + build_optimizers_fn=build_llama4_optimizers, build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_tiktoken_tokenizer, diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py index 68f2e7a75..0e8aef8ee 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -6,10 +6,12 @@ from functools import partial -from typing import Optional, Tuple +from typing import Callable import torch +import torch.distributed as dist import torch.nn as nn +from torch.distributed._functional_collectives import all_to_all_single_autograd from torch.distributed.tensor import ( DeviceMesh, distribute_module, @@ -24,40 +26,6 @@ # implementation of Tensor Parallel for the GroupedExperts in MoE class TensorParallel(ParallelStyle): - def __init__( - self, - *, - input_layouts: Optional[Tuple[Optional[Placement]]] = None, - output_layout: Optional[Placement] = None, - use_local_output: bool = True, - ): - super().__init__() - self.input_layouts = input_layouts or (Replicate(), Replicate()) - self.output_layout = output_layout or Replicate() - self.desired_input_layouts = (Replicate(), Replicate()) - self.use_local_output = use_local_output - - @staticmethod - def _prepare_input_fn( - input_layouts, desired_input_layouts, mod, inputs, device_mesh - ): - prepared_inputs = [] - # annotate module input placements/sharding with input_layouts - for inp, input_layout, desired_input_layout in zip( - inputs, input_layouts, desired_input_layouts - ): - if isinstance(inp, torch.Tensor): - if not isinstance(inp, DTensor): - inp = DTensor.from_local( - inp, device_mesh, (input_layout,), run_check=False - ) - if input_layout != desired_input_layout: - inp = inp.redistribute( - placements=(desired_input_layout,), async_op=True - ) - prepared_inputs.append(inp) - return tuple(prepared_inputs) - def _partition_fn(self, name, module, device_mesh): module.register_parameter( "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)])) @@ -71,36 +39,25 @@ def _partition_fn(self, name, module, device_mesh): nn.Parameter(distribute_tensor(module.w3, device_mesh, 
[Shard(2)])), ) # Column-wise sharding - @staticmethod - def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh): - if outputs.placements != (output_layout,): - outputs = outputs.redistribute(placements=(output_layout,), async_op=True) - # back to local tensor - return outputs.to_local() if use_local_output else outputs - def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: return distribute_module( module, device_mesh, self._partition_fn, - partial( - self._prepare_input_fn, self.input_layouts, self.desired_input_layouts - ), - partial(self._prepare_output_fn, self.output_layout, self.use_local_output), ) # NOTE: This is to achieve replicate computation on the gate module in the MoE router. # It does nothing other than (1) setting the module parameters as DTensors on the given mesh # and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back. -# TODO: The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh, +# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh, # which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation. class NoParallel(ParallelStyle): def __init__( self, *, - input_layout: Optional[Placement] = None, - output_layout: Optional[Placement] = None, + input_layout: Placement | None = None, + output_layout: Placement | None = None, use_local_output: bool = True, ): super().__init__() @@ -141,3 +98,197 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ), partial(self._prepare_output_fn, self.output_layout, self.use_local_output), ) + + +class ExpertParallel(ParallelStyle): + def __init__(self): + super().__init__() + self.input_splits = None + self.output_splits = None + + # performing all-to-all dispatch on the input + def _token_dispatch(self, mod, inputs, device_mesh): + # annotate module input placements/sharding with input_layouts + routed_input, num_tokens_per_expert = inputs + + # generate the input splits and output splits for all-to-all + with torch.no_grad(): + num_tokens_per_expert_group = num_tokens_per_expert.new_empty( + num_tokens_per_expert.shape[0] + ) + dist.all_to_all_single( + num_tokens_per_expert_group, + num_tokens_per_expert, + group=device_mesh.get_group(), + ) + # NOTE: this would incur a device-to-host sync + self.input_splits = ( + num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist() + ) + self.output_splits = ( + num_tokens_per_expert_group.view(device_mesh.shape[0], -1) + .sum(dim=1) + .tolist() + ) + + # perform all-to-all + routed_input = all_to_all_single_autograd( + routed_input, + self.output_splits, + self.input_splits, + device_mesh.get_group(), + ) + + # NOTE: After this all-to-all, the routed input is put on proper EP rank. + # However, the num_tokens_per_expert_group is not of the final target format + # [#tokens for local expert 0, #tokens for local expert 1, ...] + # Rather, it is of the format + # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ..., + # #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...] + # We need to perform another shuffle to get the correct format -- this is done via the function + # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens + # each expert gets locally is a multiple of ALIGN_SIZE_M. 
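As a concrete aside, a toy illustration (assuming 2 EP ranks and 2 local experts per rank, with made-up counts) of how `input_splits` is derived from the per-expert token counts in `_token_dispatch`; `output_splits` is computed the same way from the counts received via the first all-to-all:

```python
import torch

# Tokens this rank routes to each of the 4 global experts (made-up counts).
num_tokens_per_expert = torch.tensor([3, 1, 2, 4])
ep_degree = 2  # experts 0-1 live on EP rank 0, experts 2-3 on EP rank 1

input_splits = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
print(input_splits.tolist())  # [4, 6]: send 4 tokens to EP rank 0, 6 to EP rank 1
```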
+ + return routed_input, num_tokens_per_expert_group + + @staticmethod + def _partition_fn(name, mod, device_mesh): + # shard on the expert dimension + for name, param in mod.named_parameters(recurse=False): + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) + mod.register_parameter(name, dist_param) + + # performing all-to-all combine on the output + def _token_combine(self, mod, routed_output, device_mesh): + routed_output = all_to_all_single_autograd( + routed_output, + self.input_splits, + self.output_splits, + device_mesh.get_group(), + ) + return routed_output + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=ExpertParallel._partition_fn, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) + + +# This class is for dp2ep with TP (without TP we can just use ExpertParallel) +class ExpertTensorParallel(ExpertParallel): + def __init__( + self, + tp_mesh: DeviceMesh, + ep_mesh: DeviceMesh, + ): + super().__init__() + # TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh, + # as DeviceMesh doesn't support slicing from a submesh. + self.tp_mesh = tp_mesh + self.ep_mesh = ep_mesh + + def _token_dispatch(self, mod, inputs, device_mesh): + # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh + return super()._token_dispatch(mod, inputs, self.ep_mesh) + + def _partition_fn_2d(self, name, mod, ep_tp_mesh): + mod.register_parameter( + "w1", + nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(2)])), + ) # Column-wise sharding + mod.register_parameter( + "w2", + nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(1)])), + ) # Row-wise sharding + mod.register_parameter( + "w3", + nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(2)])), + ) # Column-wise sharding + + def _token_combine(self, mod, routed_output, device_mesh): + # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh + return super()._token_combine(mod, routed_output, self.ep_mesh) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=self._partition_fn_2d, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) + + +def expert_parallel(func: Callable) -> Callable: + """ + This is a wrapper applied to the GroupedExperts computation, serving + the following three purposes: + 1. Convert parameters from DTensors to plain Tensors, to work with + dynamic-shape inputs which cannot be easily expressed as DTensors. + 2. In Expert Parallel, apply the generate_permute_indices kernel to + permute the inputs to be ordered by local experts (see the _token_dispatch + function in ExpertParallel) and permute the outputs back. + 3. In order to use torch._grouped_mm, we need to make sure the number of + tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices + kernel also helps achieve this via padding, without incurring synchronization + between device and host. Note that this will create side effects when wrapping + the for-loop implementation of GroupedExperts, as it does not need padding. + + Among the above: + 1 and 2 are needed only when expert_parallel_degree > 1. + 3 is needed even for single-device computation. + 2 can be moved to ExpertParallel _token_dispatch if not coupled with 3. 
+ """ + + def wrapper( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(w1, DTensor): + w1 = w1.to_local() + w2 = w2.to_local() + w3 = w3.to_local() + + if num_tokens_per_expert is not None: + from torchtitan.experiments.kernels.moe.indices import ( + generate_permute_indices, + ) + + experts_per_ep_rank = w1.shape[0] + num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank + + ALIGN_SIZE_M = 16 + with torch.no_grad(): + ( + permuted_indices, + num_tokens_per_expert, + _, # offsets, + ) = generate_permute_indices( + num_tokens_per_expert, + experts_per_ep_rank, + num_ep_ranks, + x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M, + ALIGN_SIZE_M, + ) + + x = torch.vstack((x, x.new_zeros((x.shape[-1])))) + input_shape = x.shape + x = x[permuted_indices, :] + + out = func(w1, w2, w3, x, num_tokens_per_expert) + + if num_tokens_per_expert is not None: + out_unpermuted = out.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = out + out = out_unpermuted[:-1] + + return out + + return wrapper diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index c8a968b1c..d681cd6a1 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -8,7 +8,16 @@ import torch import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh - +from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy +from torch.distributed.tensor import Partial, Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + PrepareModuleInputOutput, + RowwiseParallel, + SequenceParallel, +) from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims @@ -16,12 +25,15 @@ apply_ac, apply_compile, apply_ddp, - apply_fsdp, - apply_tp, ) from torchtitan.tools.logging import logger -from ..model.moe import MoE +from .expert_parallel import ( + ExpertParallel, + ExpertTensorParallel, + NoParallel, + TensorParallel, +) def parallelize_llama( @@ -56,7 +68,7 @@ def parallelize_llama( # all-gather happens in high precision. 
enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise - apply_tp( + apply_non_moe_tp( model, world_mesh["tp"], loss_parallel=parallel_dims.loss_parallel_enabled, @@ -64,7 +76,18 @@ def parallelize_llama( enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) - apply_moe_tp(model, world_mesh["tp"]) + # TODO: shall we support tensorwise float8 comms for MoE TP + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + apply_moe_ep_tp( + model, + tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, + ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + ep_tp_mesh=( + world_mesh["ep", "tp"] + if parallel_dims.tp_enabled and parallel_dims.ep_enabled + else None + ), + ) if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) @@ -77,15 +100,21 @@ def parallelize_llama( torch._dynamo.config.capture_scalar_outputs = True dp_mesh: DeviceMesh | None = None - if ( - parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled - ): # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: + # apply FSDP or HSDP, potentially with Context Parallel if parallel_dims.dp_replicate_enabled: dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") else: dp_mesh_dim_names = ("dp_shard_cp",) dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + dp_mod_ep_mesh_dim_names = [] + if parallel_dims.ep_enabled: + if parallel_dims.dp_replicate_enabled: + dp_mod_ep_mesh_dim_names.append("dp_replicate") + dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + apply_fsdp( model, dp_mesh, @@ -94,6 +123,11 @@ def parallelize_llama( pp_enabled=parallel_dims.pp_enabled, cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + dp_mod_ep_mesh=( + world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + if dp_mod_ep_mesh_dim_names + else None + ), ) if parallel_dims.dp_replicate_enabled: @@ -117,64 +151,222 @@ def parallelize_llama( enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, ) - # for MoE auxiliary-loss-free load balancing - if parallel_dims.dp_cp_enabled: - # NOTE: Currently this sync is blocking (thus exposed) and happens on the - # default compute stream. Need to assess if this is OK performance-wise. - dp_cp_mesh = world_mesh["dp_cp"] - - def _sync_tokens_per_expert(module, *_): - assert isinstance(module, MoE) - torch.distributed.all_reduce( - module.tokens_per_expert, group=dp_cp_mesh.get_group() - ) - - for transformer_block in model.layers.values(): - if transformer_block.moe_enabled: - load_balance_coeff = transformer_block.moe.load_balance_coeff - if load_balance_coeff is not None and load_balance_coeff > 0: - # prepend=True so that the sync runs before - # the _update_expert_bias hook in MoE - transformer_block.moe.register_full_backward_hook( - _sync_tokens_per_expert, prepend=True - ) - else: - break - return model -def apply_moe_tp( +def apply_non_moe_tp( model: nn.Module, tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, + enable_async_tp: bool, ): - from torch.distributed.tensor import Partial, Replicate, Shard - from torch.distributed.tensor.parallel import ( - parallelize_module, - PrepareModuleInputOutput, + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. 
Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, ) - from .expert_parallel import NoParallel, TensorParallel + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears with tensorwise scaling. + if enable_float8_tensorwise_tp: + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block for transformer_block in model.layers.values(): - moe_layer_plan = { - # input / output sharding on the seqlen dim - # all-gather for input, reduce-scatter for output - "moe": PrepareModuleInputOutput( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - use_local_input=True, - output_layouts=(Partial(),), - desired_output_layouts=(Shard(1),), + layer_plan = { + "attention_norm": SequenceParallel(), + "attention": prepare_module_input( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), ), - # replicate computation for the router - "moe.router.gate": NoParallel(), - # input Replicate, output Partial - "moe.experts": TensorParallel(output_layout=Partial()), - "moe.shared_expert": TensorParallel(output_layout=Partial()), + "attention.wq": colwise_parallel(), + "attention.wk": colwise_parallel(), + "attention.wv": colwise_parallel(), + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), } + if not transformer_block.moe_enabled: + layer_plan.update( + { + "feed_forward": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": colwise_parallel(), + "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)), + "feed_forward.w3": colwise_parallel(), + } + ) + parallelize_module( module=transformer_block, device_mesh=tp_mesh, - parallelize_plan=moe_layer_plan, + parallelize_plan=layer_plan, + ) + + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", + dp_mod_ep_mesh: DeviceMesh | None = None, +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. 
+ dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + for layer_id, transformer_block in model.layers.items(): + if reshard_after_forward_policy == "always": + reshard_after_forward = True + elif reshard_after_forward_policy == "never": + reshard_after_forward = False + elif reshard_after_forward_policy == "default": + if pp_enabled: + # For PP, do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = False + else: + # As an optimization, do not reshard after forward for the last + # transformer block since FSDP would prefetch it immediately + reshard_after_forward = int(layer_id) < len(model.layers) - 1 + else: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + + # NOTE: in an MoE layer, the router and the shared experts + # are sharded together with the TransformerBlock + if transformer_block.moe_enabled and dp_mod_ep_mesh: + fsdp_mod_ep_config = fsdp_config.copy() + fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + fully_shard( + transformer_block.moe.experts, + **fsdp_mod_ep_config, + reshard_after_forward=reshard_after_forward, + ) + + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) + + +def apply_moe_ep_tp( + model: nn.Module, + tp_mesh: DeviceMesh | None, + ep_mesh: DeviceMesh | None, + ep_tp_mesh: DeviceMesh | None, +): + for transformer_block in model.layers.values(): + if not transformer_block.moe_enabled: + continue + + if tp_mesh is not None: + moe_layer_plan = { + # input / output sharding on the seqlen dim + # all-gather for input, reduce-scatter for output + "moe": PrepareModuleInputOutput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + use_local_input=True, + output_layouts=(Partial(),), + desired_output_layouts=(Shard(1),), + ), + # replicate computation for the router + "moe.router.gate": NoParallel(), + # input Replicate, output Partial + "moe.shared_expert": TensorParallel(), + } + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=moe_layer_plan, + ) + + # if ep_mesh is not None: + experts_mesh, experts_plan = None, None + if ep_mesh is None: + experts_mesh = tp_mesh + # input Replicate, output Partial + experts_plan = TensorParallel() + elif tp_mesh is None: + experts_mesh = ep_mesh + # input / output sharding on the batch / tokens dim + experts_plan = ExpertParallel() + else: + experts_mesh = ep_tp_mesh + 
experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh) + parallelize_module( + module=transformer_block.moe.experts, + device_mesh=experts_mesh, + parallelize_plan=experts_plan, ) diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index a07bf0f7b..d7f0ce3fd 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -8,6 +8,8 @@ import torch.nn.functional as F from torch import nn +from ..infra.expert_parallel import expert_parallel + from .args import TransformerModelArgs @@ -29,49 +31,73 @@ def __init__( def forward( self, x: torch.Tensor, - num_local_tokens_per_expert: torch.Tensor | list[int] | None = None, + num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: - # TODO: keeping this for loop implementation for comparison - # and readability, will remove later - if not self.use_grouped_mm: - if num_local_tokens_per_expert is not None: - # a tuple of tensors indexed by experts - # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x, - split_size_or_sections=num_local_tokens_per_expert, - dim=0, - ) - out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): - w1, w2, w3 = ( - self.w1[expert_idx], - self.w2[expert_idx], - self.w3[expert_idx], - ) - h = F.silu(torch.matmul(x_expert, w1)) - h = h * torch.matmul(x_expert, w3) - h = torch.matmul(h, w2) - # h shape (tokens_per_expert(varying), dim) - out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) - else: - # x shape (num_experts, tokens_per_expert, dim) - h = F.silu(torch.bmm(x, self.w1)) - h = h * torch.bmm(x, self.w3) - # out shape (num_experts, tokens_per_expert, dim) - out = torch.bmm(h, self.w2) - - return out - - # grouped mm implementation - if num_local_tokens_per_expert is not None: - # https://github.com/pytorch/pytorch/pull/150374 - # NOTE: torch._gouped_mm requires bf16 dtypes - # and shapes to be multiple of 8 - offsets = torch.cumsum( - num_local_tokens_per_expert, dim=0, dtype=torch.int32 + if self.use_grouped_mm: + return GroupedExperts._run_experts_grouped_mm( + self.w1, self.w2, self.w3, x, num_tokens_per_expert + ) + else: + return GroupedExperts._run_experts_for_loop( + self.w1, self.w2, self.w3, x, num_tokens_per_expert + ) + + # TODO: keeping this for-loop implementation for comparison + # and readability, may remove later + @expert_parallel + @staticmethod + def _run_experts_for_loop( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = F.silu(torch.matmul(x_expert, w1[expert_idx])) + h = h * torch.matmul(x_expert, w3[expert_idx]) + h = torch.matmul(h, w2[expert_idx]) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = 
torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = F.silu(torch.bmm(x, w1)) + h = h * torch.bmm(x, w3) + # out shape (num_experts, tokens_per_expert, dim) + out = torch.bmm(h, w2) + + return out + + @expert_parallel + @staticmethod + def _run_experts_grouped_mm( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) # grouped mm between a 2D tensor and a 3D tensor assert x.dim() == 2 else: @@ -79,12 +105,9 @@ def forward( # fall back to regular bmm between 3D tensors assert x.dim() == 3 - assert ( - x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16 - ), "torch._grouped_mm only supports bf16 dtypes" - h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets)) - h = h * torch._grouped_mm(x, self.w3, offs=offsets) - out = torch._grouped_mm(h, self.w2, offs=offsets) + h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets)) + h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets) + out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x) return out @@ -120,7 +143,7 @@ def __init__( self.use_sigmoid = use_sigmoid def forward( - self, x: torch.Tensor, expert_bias: torch.Tensor = None + self, x: torch.Tensor, expert_bias: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: @@ -131,7 +154,7 @@ def forward( Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. token_indices (torch.Tensor): Token indices for routed_input with shape ``(bs*slen*top_k,)``. - num_local_tokens_per_expert (torch.Tensor): + num_tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert with shape ``(num_experts,)``. """ # scores shape (bs*slen, num_experts) @@ -146,13 +169,18 @@ def forward( # top scores shape (bs*slen, top_k) # NOTE: The expert_bias is only used for routing. The gating value # top_scores is still derived from the original scores. 
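A toy example with made-up scores of what that NOTE means: the bias only shifts which experts are selected, while the gating values come from the unbiased scores:

```python
import torch

scores = torch.tensor([[0.5, 0.4, 0.1]])       # (bs*slen, num_experts)
expert_bias = torch.tensor([0.0, 0.0, 0.45])   # under-used expert 2 gets a boost
_, selected = torch.topk(scores + expert_bias, k=2, dim=1)
top_scores = scores.gather(dim=1, index=selected)  # gate with the original scores
print(selected.tolist(), top_scores.tolist())  # experts [2, 0]; gates ~[0.1, 0.5]
```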
- _, selected_experts_indices = torch.topk( - scores + expert_bias, k=self.top_k, dim=1 - ) - top_scores = scores.gather(dim=1, index=selected_experts_indices) + if expert_bias is not None: + _, selected_experts_indices = torch.topk( + scores + expert_bias, k=self.top_k, dim=1 + ) + top_scores = scores.gather(dim=1, index=selected_experts_indices) + else: + top_scores, selected_experts_indices = torch.topk( + scores, k=self.top_k, dim=1 + ) # group tokens together by expert indices from 0 to num_experts and pass that to experts forward - num_local_tokens_per_expert = torch.histc( + num_tokens_per_expert = torch.histc( selected_experts_indices.view(-1), bins=self.num_experts, min=0, @@ -165,7 +193,7 @@ def forward( top_scores = top_scores.view(-1)[token_indices_experts_sorted] token_indices_experts_sorted = token_indices_experts_sorted // self.top_k - return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert + return top_scores, token_indices_experts_sorted, num_tokens_per_expert def init_weights(self, init_std: float): nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) @@ -191,12 +219,11 @@ def __init__(self, model_args: TransformerModelArgs): hidden_dim = int(hidden_dim / hidden_dim_denom) hidden_dim += -hidden_dim % model_args.multiple_of - self.use_grouped_mm = model_args.use_grouped_mm self.experts = GroupedExperts( dim=dim, hidden_dim=hidden_dim, num_experts=num_experts, - use_grouped_mm=self.use_grouped_mm, + use_grouped_mm=model_args.use_grouped_mm, ) self.router = TokenChoiceTopKRouter( dim=dim, num_experts=num_experts, top_k=model_args.top_k ) @@ -206,40 +233,31 @@ def __init__(self, model_args: TransformerModelArgs): dim=dim, hidden_dim=hidden_dim, num_experts=1, - use_grouped_mm=self.use_grouped_mm, + use_grouped_mm=model_args.use_grouped_mm, ) if model_args.use_shared_expert else None ) - # auxiliary-loss-free load balancing + # define fields for auxiliary-loss-free load balancing (https://arxiv.org/abs/2408.15664) + # NOTE: tokens_per_expert is accumulated in the model forward pass. + # expert_bias is updated outside the model in an optimizer step pre hook + # to work with gradient accumulation.
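The bias update itself is the sign-based rule from the removed `_update_expert_bias` backward hook, now moved into an optimizer step pre-hook (see `optimizer.py` below). A toy run with made-up counts and an exaggerated coefficient:

```python
import torch

load_balance_coeff = 0.25  # exaggerated toy value for readability
tokens_per_expert = torch.tensor([10.0, 2.0, 6.0, 6.0])  # mean is 6
delta = load_balance_coeff * torch.sign(tokens_per_expert.mean() - tokens_per_expert)
delta = delta - delta.mean()  # mean-centered, as in _update_expert_bias
print(delta.tolist())  # [-0.25, 0.25, 0.0, 0.0]: over-used down, under-used up
```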
self.load_balance_coeff = model_args.load_balance_coeff - # the fields below are defined even when load_balance_coeff is None - # to make initialization and checkpointing code simpler - self.register_buffer( - "expert_bias", - torch.zeros(num_experts, dtype=torch.float32), - persistent=True, - ) - self.register_buffer( - "tokens_per_expert", - torch.zeros(num_experts, dtype=torch.float32), - persistent=True, - ) - - # NOTE: forward hook, forward pre hook, or backward pre hook - # would conflict with activation checkpointing - if self.load_balance_coeff is not None and self.load_balance_coeff > 0: - self.register_full_backward_hook(self._update_expert_bias) - - def _update_expert_bias(self, *_): - expert_bias_delta = self.load_balance_coeff * torch.sign( - self.tokens_per_expert.mean() - self.tokens_per_expert - ) - expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() - self.expert_bias.add_(expert_bias_delta) - - self.tokens_per_expert.zero_() + if self.load_balance_coeff is not None: + assert self.load_balance_coeff > 0.0 + self.register_buffer( + "expert_bias", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + self.register_buffer( + "tokens_per_expert", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + else: + self.expert_bias = None def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -252,15 +270,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: bs, slen, dim = x.shape # top_scores and selected_indices shape (bs*slen*top_k,) - # num_local_tokens_per_expert shape (num_experts,) + # num_tokens_per_expert shape (num_experts,) ( top_scores, token_indices, - num_local_tokens_per_expert, + num_tokens_per_expert, ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) - # will be used to update the expert bias for load balancing - self.tokens_per_expert += num_local_tokens_per_expert + # tokens_per_expert will be used to update the expert bias for load balancing. + # Prevent extra local tokens accumulation on evaluation or activation recomputation. + if self.load_balance_coeff is not None and torch.is_grad_enabled(): + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) # shape (bs*slen*top_k, dim) token_indices = token_indices.reshape(-1, 1).expand(-1, dim) @@ -275,41 +296,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x.dtype ) - if self.use_grouped_mm: - # NOTE: In order to use torch._grouped_mm, we need to make sure - # the number of tokens each expert gets is a multiple of 16. - # The following kernel helps achieve this via padding, without - # incurring synchronization between device and host. 
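The multiple-of-16 requirement mentioned in the removed comment is now handled inside the `expert_parallel` wrapper via `generate_permute_indices`. A pure-Python sketch of just the alignment idea (the real kernel also returns permuted indices and offsets and avoids any host sync; the rounding shown is an assumption for illustration):

```python
ALIGN_SIZE_M = 16
num_tokens_per_expert = [5, 20, 17]  # made-up per-expert counts

# Round each expert's token count up to a multiple of ALIGN_SIZE_M so that
# torch._grouped_mm sees aligned group sizes; extra slots point at a padding row.
aligned = [-(-n // ALIGN_SIZE_M) * ALIGN_SIZE_M for n in num_tokens_per_expert]
print(aligned)  # [16, 32, 32]
```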
- from torchtitan.experiments.kernels.moe.indices import ( - generate_permute_indices, - ) - - ALIGN_SIZE_M = 16 - - with torch.no_grad(): - ( - permuted_indices, - num_local_tokens_per_expert, - _, - ) = generate_permute_indices( - num_local_tokens_per_expert, - self.experts.num_experts, - 1, - token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M, - ALIGN_SIZE_M, - ) - token_indices = torch.vstack( - (token_indices, token_indices.new_zeros((dim))) - ) - token_indices = token_indices[permuted_indices, :] - routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim)))) - routed_input = routed_input[permuted_indices, :] - else: - # NOTE: this would incur a synchronization between device and host - num_local_tokens_per_expert = num_local_tokens_per_expert.tolist() - # shape (bs*slen*top_k, dim) - routed_output = self.experts(routed_input, num_local_tokens_per_expert) + routed_output = self.experts(routed_input, num_tokens_per_expert) # shared expert if self.shared_expert is not None: @@ -333,10 +321,11 @@ def init_weights( if self.shared_expert is not None: self.shared_expert.init_weights(init_std) - with torch.device(buffer_device): - self.expert_bias = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) - self.tokens_per_expert = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) + if self.load_balance_coeff is not None: + with torch.device(buffer_device): + self.expert_bias = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) + self.tokens_per_expert = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) diff --git a/torchtitan/experiments/llama4/optimizer.py b/torchtitan/experiments/llama4/optimizer.py new file mode 100644 index 000000000..d4829de88 --- /dev/null +++ b/torchtitan/experiments/llama4/optimizer.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.components.ft import FTManager +from torchtitan.components.optimizer import build_optimizers, OptimizersContainer +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims + + +# for MoE auxiliary-loss-free load balancing +def _update_expert_bias( + model_parts: list[nn.Module], + world_mesh: dict[str, DeviceMesh], + parallel_dims: ParallelDims, +): + dp_cp_mesh = world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None + # TODO: Currently this sync is blocking (thus exposed) and happens on the + # default compute stream. Need to assess if this is OK performance-wise. 
+ for model_part in model_parts: + for transformer_block in model_part.layers.values(): + if transformer_block.moe_enabled: + moe = transformer_block.moe + if moe.load_balance_coeff is None: + return + + if dp_cp_mesh is not None: + torch.distributed.all_reduce( + moe.tokens_per_expert, group=dp_cp_mesh.get_group() + ) + + with torch.no_grad(): + expert_bias_delta = moe.load_balance_coeff * torch.sign( + moe.tokens_per_expert.mean() - moe.tokens_per_expert + ) + expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() + moe.expert_bias.add_(expert_bias_delta) + moe.tokens_per_expert.zero_() + + +def build_llama4_optimizers( + model_parts: list[nn.Module], + job_config: JobConfig, + parallel_dims: ParallelDims, + world_mesh: DeviceMesh, + ft_manager: FTManager, +) -> OptimizersContainer: + optimizers = build_optimizers( + model_parts=model_parts, + job_config=job_config, + parallel_dims=parallel_dims, + world_mesh=world_mesh, + ft_manager=ft_manager, + ) + + optimizers.register_step_pre_hook( + lambda *args, **kwargs: _update_expert_bias( + model_parts, world_mesh=world_mesh, parallel_dims=parallel_dims + ) + ) + + return optimizers diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index 41a8459ea..7fbe95e19 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -52,6 +52,7 @@ tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 context_parallel_degree = 1 +expert_parallel_degree = 1 [checkpoint] enable_checkpoint = false diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 90a98f9b0..df395adcb 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -90,9 +90,8 @@ def parallelize_llama( if job_config.training.compile: apply_compile(model) - if ( - parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled - ): # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.fsdp_enabled: + # apply FSDP or HSDP, potentially with Context Parallel if parallel_dims.dp_replicate_enabled: dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") else: diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index ddb6961e5..aeff047b5 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -13,6 +13,7 @@ import torch import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.pipelining.schedules import _PipelineSchedule from torchtitan.components.dataloader import BaseDataLoader @@ -23,6 +24,7 @@ from torchtitan.components.optimizer import OptimizersContainer from torchtitan.components.tokenizer import Tokenizer from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims DeviceType = int | str | torch.device @@ -71,7 +73,8 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: TokenizerBuilder: TypeAlias = Callable[..., Tokenizer] MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor] OptimizersBuilder: TypeAlias = Callable[ - [list[nn.Module], JobConfig, FTManager], OptimizersContainer + [list[nn.Module], JobConfig, ParallelDims, DeviceMesh, FTManager], + OptimizersContainer, ] LRSchedulersBuilder: TypeAlias = Callable[ [OptimizersContainer, JobConfig], LRSchedulersContainer diff 
--git a/torchtitan/train.py b/torchtitan/train.py index ca1480f2e..e6a1ffa7d 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -89,6 +89,7 @@ def __init__(self, job_config: JobConfig): cp=parallelism_config.context_parallel_degree, tp=parallelism_config.tensor_parallel_degree, pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, world_size=world_size, enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) @@ -280,7 +281,7 @@ def __init__(self, job_config: JobConfig): # build optimizer after applying parallelisms to the model self.optimizers = self.train_spec.build_optimizers_fn( - self.model_parts, job_config, self.ft_manager + self.model_parts, job_config, parallel_dims, world_mesh, self.ft_manager ) self.lr_schedulers = self.train_spec.build_lr_schedulers_fn( self.optimizers, job_config @@ -436,6 +437,7 @@ def train_step( self.job_config.training.max_norm, foreach=True, pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None, + parallel_dims=parallel_dims, ) self.checkpointer.maybe_wait_for_staging() self.optimizers.step()
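For reference, the `parallel_dims` argument that `train_step` now passes makes `clip_grad_norm_` route to `_clip_grad_norm_with_ep`, which reduces the EP and non-EP gradient norms separately and then recombines them. A quick sanity check with made-up norms that the recombination matches a single global L2 norm:

```python
import math

ep_norm, non_ep_norm, p = 3.0, 4.0, 2.0   # made-up partial norms, L2 (p = 2)
total = (ep_norm**p + non_ep_norm**p) ** (1.0 / p)
assert math.isclose(total, 5.0)  # equals the norm of all grads taken together
# For norm_type == inf, the combination is max(ep_norm, non_ep_norm) instead.
```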