From fdea77827cfd50f6093cd92e3b612f79a509d2e5 Mon Sep 17 00:00:00 2001
From: Han Wang
Date: Mon, 16 Jun 2025 06:14:35 +0000
Subject: [PATCH 1/2] feat: update expert_bias without backward hooks

---
 torchtitan/config_manager.py                  |   3 +
 torchtitan/distributed/parallel_dims.py       |   9 ++
 torchtitan/experiments/llama4/__init__.py     |   6 +-
 .../llama4/infra/parallelize_llama.py         | 113 ++++++++++++++----
 torchtitan/experiments/llama4/model/moe.py    |  19 +--
 torchtitan/protocols/train_spec.py            |   3 +
 torchtitan/train.py                           |   7 ++
 7 files changed, 121 insertions(+), 39 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index f12b21ba5..5720d249c 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -361,6 +361,9 @@ class Parallelism:
     The default value is 'allgather'.
     """
 
+    enable_tp2ep: bool = False
+    """Whether to use expert parallelism instead of tensor parallelism for shared experts."""
+
 
 @dataclass
 class Checkpoint:
diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index 5f8bc5025..6d7bfe0f6 100644
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -25,6 +25,7 @@ class ParallelDims:
     pp: int
     world_size: int
     enable_loss_parallel: bool
+    enable_tp2ep: bool
 
     def __post_init__(self):
         self._validate()
@@ -81,17 +82,23 @@ def _build_mesh(
         dp_shard_cp_mesh_dim_names = []
         # Mesh for loss all-reduce
         dp_cp_mesh_dim_names = []
+        dp_cp_tp2ep_mesh_dim_names = []
 
         if self.dp_replicate_enabled:
             dp_mesh_dim_names.append("dp_replicate")
             dp_cp_mesh_dim_names.append("dp_replicate")
+            dp_cp_tp2ep_mesh_dim_names.append("dp_replicate")
         if self.dp_shard_enabled:
             dp_mesh_dim_names.append("dp_shard")
             dp_shard_cp_mesh_dim_names.append("dp_shard")
             dp_cp_mesh_dim_names.append("dp_shard")
+            dp_cp_tp2ep_mesh_dim_names.append("dp_shard")
         if self.cp_enabled:
             dp_shard_cp_mesh_dim_names.append("cp")
             dp_cp_mesh_dim_names.append("cp")
+            dp_cp_tp2ep_mesh_dim_names.append("cp")
+        if self.tp_enabled and self.enable_tp2ep:
+            dp_cp_tp2ep_mesh_dim_names.append("tp")
 
         if dp_mesh_dim_names != []:
             mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
@@ -101,6 +108,8 @@ def _build_mesh(
             )
         if dp_cp_mesh_dim_names != []:
             mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
+        if dp_cp_tp2ep_mesh_dim_names != []:
+            mesh[tuple(dp_cp_tp2ep_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp_tp2ep")
 
         return mesh
diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py
index 595a0733e..ee44478f3 100644
--- a/torchtitan/experiments/llama4/__init__.py
+++ b/torchtitan/experiments/llama4/__init__.py
@@ -12,7 +12,7 @@
 from torchtitan.models.llama3 import pipeline_llama
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
 
-from .infra.parallelize_llama import parallelize_llama
+from .infra.parallelize_llama import parallelize_llama, update_router_expert_bias
 from .model.args import TransformerModelArgs
 from .model.model import Transformer
 
@@ -103,5 +103,5 @@
         build_dataloader_fn=build_hf_dataloader,
         build_tokenizer_fn=build_tiktoken_tokenizer,
         build_loss_fn=build_cross_entropy_loss,
-    )
-)
+        finalize_model_grads_func=update_router_expert_bias,
+    ))
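For intuition, the flattened "dp_cp_tp2ep" mesh collects every dimension over which per-step routing statistics must be summed. A toy, pure-Python sketch of the dimension selection added to ParallelDims._build_mesh above (the degrees and the plain dict are made up for illustration; the real code checks self.dp_replicate_enabled and friends):

# Hypothetical 8-GPU layout: dp_shard=2, cp=2, tp=2, with TP reused as expert parallelism.
degrees = {"dp_replicate": 1, "dp_shard": 2, "cp": 2, "tp": 2}
enable_tp2ep = True

dp_cp_tp2ep_mesh_dim_names = []
if degrees["dp_replicate"] > 1:
    dp_cp_tp2ep_mesh_dim_names.append("dp_replicate")
if degrees["dp_shard"] > 1:
    dp_cp_tp2ep_mesh_dim_names.append("dp_shard")
if degrees["cp"] > 1:
    dp_cp_tp2ep_mesh_dim_names.append("cp")
if degrees["tp"] > 1 and enable_tp2ep:
    dp_cp_tp2ep_mesh_dim_names.append("tp")

print(dp_cp_tp2ep_mesh_dim_names)   # ['dp_shard', 'cp', 'tp'] -> an 8-rank all-reduce group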
diff --git a/torchtitan/experiments/llama4/infra/parallelize_llama.py b/torchtitan/experiments/llama4/infra/parallelize_llama.py
index 785d9d8a5..3c80a3aab 100644
--- a/torchtitan/experiments/llama4/infra/parallelize_llama.py
+++ b/torchtitan/experiments/llama4/infra/parallelize_llama.py
@@ -117,28 +117,6 @@ def parallelize_llama(
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
 
-    # for MoE auxiliary-loss-free load balancing
-    if dp_mesh is not None:
-        # NOTE: Currently this sync is blocking (thus exposed) and happens on the
-        # default compute stream. Need to assess if this is OK performance-wise.
-        def _sync_tokens_per_expert(module, *_):
-            assert isinstance(module, MoE)
-            torch.distributed.all_reduce(
-                module.tokens_per_expert, group=dp_mesh.get_group()
-            )
-
-        for transformer_block in model.layers.values():
-            if transformer_block.moe_enabled:
-                load_balance_coeff = transformer_block.moe.load_balance_coeff
-                if load_balance_coeff is not None and load_balance_coeff > 0:
-                    # prepend=True so that the sync runs before
-                    # the _update_expert_bias hook in MoE
-                    transformer_block.moe.register_full_backward_hook(
-                        _sync_tokens_per_expert, prepend=True
-                    )
-                else:
-                    break
-
     return model
@@ -176,3 +154,94 @@ def apply_moe_tp(
         device_mesh=tp_mesh,
         parallelize_plan=moe_layer_plan,
     )
+
+
+def get_updated_expert_bias(
+    tokens_per_expert,
+    expert_bias,
+    expert_bias_update_rate,
+    reduce_mesh: DeviceMesh | None,
+):
+    """Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1#
+
+    Args:
+        tokens_per_expert (torch.Tensor): The number of tokens assigned to each expert.
+        expert_bias (torch.Tensor): The bias for each expert.
+        expert_bias_update_rate (float): The update rate for the expert bias.
+    """
+
+    with torch.no_grad():
+        # All Reduce Across TPxCPxDP group
+        if reduce_mesh is not None:
+            torch.distributed.all_reduce(
+                tokens_per_expert,
+                group=reduce_mesh.get_group(),
+            )
+        average_tokens = tokens_per_expert.sum(
+            dim=-1, keepdim=True) / tokens_per_expert.shape[-1]
+        offset = average_tokens - tokens_per_expert
+        updated_expert_bias = expert_bias + torch.sign(
+            offset) * expert_bias_update_rate
+        return updated_expert_bias
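The sign-based rule implemented by get_updated_expert_bias is easiest to see on concrete numbers. A minimal standalone sketch (toy token counts and a made-up update rate; no device mesh, so the all-reduce branch is skipped):

import torch

# Suppose one MoE layer with 4 experts routed [10, 2, 5, 3] tokens in the last global batch.
tokens_per_expert = torch.tensor([10.0, 2.0, 5.0, 3.0])
expert_bias = torch.zeros(4)
expert_bias_update_rate = 1e-3  # plays the role of load_balance_coeff

average_tokens = tokens_per_expert.sum(dim=-1, keepdim=True) / tokens_per_expert.shape[-1]
offset = average_tokens - tokens_per_expert          # positive for under-loaded experts
updated_expert_bias = expert_bias + torch.sign(offset) * expert_bias_update_rate
print(updated_expert_bias)  # tensor([-0.0010,  0.0010,  0.0000,  0.0010])
# The over-loaded expert 0 is nudged down, under-loaded experts 1 and 3 up,
# and the exactly-average expert 2 is left unchanged.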
+
+
+# for MoE auxiliary-loss-free load balancing
+def update_router_expert_bias(model: torch.nn.Module,
+                              world_mesh: DeviceMesh):
+    """
+    Update the expert bias of the router for a global batch.
+    This requires an all-reduce of tokens_per_expert across DPxCPxTP2EP ranks.
+    """
+    tokens_per_expert_list = []
+    expert_bias_list = []
+    global_load_balance_coeff = None
+
+    if hasattr(model, "layers"):
+        layers = model.layers
+    elif hasattr(model, "model") and hasattr(model.model, "layers"):
+        layers = model.model.layers
+    else:
+        raise NotImplementedError(
+            "Model structure not recognized for MoE expert bias update."
+        )
+    for transformer_block in layers.values():
+        if transformer_block.moe_enabled:
+            load_balance_coeff = transformer_block.moe.load_balance_coeff
+            if load_balance_coeff is not None and load_balance_coeff > 0:
+                if global_load_balance_coeff is None:
+                    global_load_balance_coeff = load_balance_coeff
+                else:
+                    assert (
+                        global_load_balance_coeff == load_balance_coeff
+                    ), "All MoE layers must have the same load balance coefficient."
+                tokens_per_expert_list.append(
+                    transformer_block.moe.tokens_per_expert)
+                expert_bias_list.append(transformer_block.moe.expert_bias)
+        else:
+            break
+
+    if len(expert_bias_list) == 0:
+        return
+
+    if world_mesh is not None and world_mesh.size() > 1:
+        try:
+            load_balance_reduce_mesh = world_mesh["dp_cp_tp2ep"]
+        except KeyError:
+            load_balance_reduce_mesh = None
+    else:
+        load_balance_reduce_mesh = None
+
+    stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0)
+    stacked_expert_bias = torch.stack(expert_bias_list, dim=0)
+    stacked_updated_expert_bias = get_updated_expert_bias(
+        stacked_tokens_per_expert,
+        stacked_expert_bias,
+        global_load_balance_coeff,
+        load_balance_reduce_mesh,
+    )
+
+    for tokens_per_expert, expert_bias, updated_expert_bias in zip(
+            tokens_per_expert_list, expert_bias_list,
+            stacked_updated_expert_bias):
+        tokens_per_expert.zero_()
+        expert_bias.copy_(updated_expert_bias)
diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py
index 0dad02d25..cbfd0ffc0 100644
--- a/torchtitan/experiments/llama4/model/moe.py
+++ b/torchtitan/experiments/llama4/model/moe.py
@@ -214,6 +214,7 @@ def __init__(self, model_args: TransformerModelArgs):
 
         # auxiliary-loss-free load balancing
         self.load_balance_coeff = model_args.load_balance_coeff
+        self.expert_bias_enabled = self.load_balance_coeff is not None and self.load_balance_coeff > 0.0
         # the fields below are defined even when load_balance_coeff is None
         # to make initialization and checkpointing code simpler
         self.register_buffer(
@@ -227,19 +228,6 @@ def __init__(self, model_args: TransformerModelArgs):
             persistent=True,
         )
 
-        # NOTE: forward hook, forward pre hook, or backward pre hook
-        # would conflict with activation checkpointing
-        if self.load_balance_coeff is not None and self.load_balance_coeff > 0:
-            self.register_full_backward_hook(self._update_expert_bias)
-
-    def _update_expert_bias(self, *_):
-        expert_bias_delta = self.load_balance_coeff * torch.sign(
-            self.tokens_per_expert.mean() - self.tokens_per_expert
-        )
-        expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
-        self.expert_bias.add_(expert_bias_delta)
-
-        self.tokens_per_expert.zero_()
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -260,7 +248,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ) = self.router(x.reshape(bs * slen, dim), self.expert_bias)
 
         # will be used to update the expert bias for load balancing
-        self.tokens_per_expert += num_local_tokens_per_expert
+        # Prevent extra local token accumulation during evaluation or activation recomputation
+        if self.expert_bias_enabled and torch.is_grad_enabled():
+            with torch.no_grad():
+                self.tokens_per_expert.add_(num_local_tokens_per_expert)
 
         # shape (bs*slen*top_k, dim)
         token_indices = token_indices.reshape(-1, 1).expand(-1, dim)
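To make the guard concrete: under torch.no_grad() (evaluation, generation), torch.is_grad_enabled() is False, so non-training forwards do not pollute the token counters. A standalone toy sketch of just that guard (the buffer and counts here are made up; the real accumulation lives in MoE.forward above):

import torch

tokens_per_expert = torch.zeros(4)                       # stands in for the MoE buffer
num_local_tokens_per_expert = torch.tensor([3.0, 1.0, 0.0, 2.0])
expert_bias_enabled = True

def record_token_counts():
    # mirrors the guarded accumulation added to MoE.forward
    if expert_bias_enabled and torch.is_grad_enabled():
        with torch.no_grad():
            tokens_per_expert.add_(num_local_tokens_per_expert)

record_token_counts()            # training forward: counts are recorded
with torch.no_grad():            # evaluation forward
    record_token_counts()        # skipped, is_grad_enabled() is False
print(tokens_per_expert)         # tensor([3., 1., 0., 2.]) -- counted exactly once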
diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py
index 922f9a8d0..1396b10e6 100644
--- a/torchtitan/protocols/train_spec.py
+++ b/torchtitan/protocols/train_spec.py
@@ -13,6 +13,7 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.pipelining.schedules import _PipelineSchedule
 
 from torchtitan.components.dataloader import BaseDataLoader
@@ -74,6 +75,7 @@ def from_model_args(cls, args: BaseModelArgs) -> nn.Module:
     [OptimizersContainer, JobConfig], LRSchedulersContainer
 ]
 LossFunctionBuilder: TypeAlias = Callable[..., LossFunction]
+FinalizeModelGradsFunc: TypeAlias = Callable[[nn.Module, DeviceMesh], None]
 
 
 @dataclass
@@ -89,6 +91,7 @@ class TrainSpec:
     build_tokenizer_fn: TokenizerBuilder | None
     build_loss_fn: LossFunctionBuilder
     build_metrics_processor_fn: MetricsProcessorBuilder | None = None
+    finalize_model_grads_func: FinalizeModelGradsFunc | None = None
 
 
 _train_specs = {}
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 9340671d7..6d8a86fdd 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -91,6 +91,7 @@ def __init__(self, job_config: JobConfig):
             pp=parallelism_config.pipeline_parallel_degree,
             world_size=world_size,
             enable_loss_parallel=not parallelism_config.disable_loss_parallel,
+            enable_tp2ep=parallelism_config.enable_tp2ep,
         )
         dist_utils.init_distributed(job_config)
 
@@ -425,6 +426,12 @@ def train_step(
             loss = self.forward_backward_step(input_dict, labels)
             accumulated_losses.append(loss.detach())
 
+        # TODO: this can be placed inside PP but might break gradient accumulation
+        if self.train_spec.finalize_model_grads_func is not None:
+            for m in self.model_parts:
+                self.train_spec.finalize_model_grads_func(
+                    m, self.world_mesh)
+
         dist_utils.clip_grad_norm_(
             [p for m in self.model_parts for p in m.parameters()],
             self.job_config.training.max_norm,

From 354f336711d9ac8a377a9608771eda96569f61d5 Mon Sep 17 00:00:00 2001
From: Han Wang
Date: Mon, 16 Jun 2025 06:57:13 +0000
Subject: [PATCH 2/2] chore: comments for the all_reduce group

---
 torchtitan/experiments/llama4/infra/parallelize_llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtitan/experiments/llama4/infra/parallelize_llama.py b/torchtitan/experiments/llama4/infra/parallelize_llama.py
index 3c80a3aab..cc0847b62 100644
--- a/torchtitan/experiments/llama4/infra/parallelize_llama.py
+++ b/torchtitan/experiments/llama4/infra/parallelize_llama.py
@@ -171,7 +171,7 @@ def get_updated_expert_bias(
     """
 
     with torch.no_grad():
-        # All Reduce Across TPxCPxDP group
+        # All Reduce Across DPxCPxTP2EP group
         if reduce_mesh is not None:
             torch.distributed.all_reduce(
                 tokens_per_expert,
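The finalize_model_grads_func hook wired into train_step above runs once per optimizer step, after all micro-batch backwards and before gradient clipping, which is what keeps it compatible with gradient accumulation. A framework-free sketch of that ordering (the tiny model, SGD optimizer, and no-op finalize function are stand-ins; torchtitan's Trainer does the same with self.model_parts, self.world_mesh, and dist_utils.clip_grad_norm_):

import torch
from torch import nn

model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

def finalize_model_grads(m: nn.Module, world_mesh=None) -> None:
    # stand-in for update_router_expert_bias: all-reduce routing stats, nudge expert_bias
    pass

for micro_batch in torch.randn(2, 8, 4):        # two micro-batches accumulate gradients
    loss = model(micro_batch).square().mean()
    loss.backward()

finalize_model_grads(model, world_mesh=None)     # exactly once per global batch
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()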