
[llama4][auxiliary-loss-free load balancing] update expert_bias without backward hooks #1304

Open · wants to merge 2 commits into `main`
3 changes: 3 additions & 0 deletions torchtitan/config_manager.py
@@ -361,6 +361,9 @@ class Parallelism:
The default value is 'allgather'.
"""

enable_tp2ep: bool = False
"""Whether to use expert parallelism instead of tensor parallelism for shared experts."""


@dataclass
class Checkpoint:
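For context, the new flag is threaded from the parallelism config into `ParallelDims` (see the `train.py` change at the end of this diff). A minimal construction sketch; the parallel degrees below are made up, and fields other than those visible in this diff are assumed from the existing `ParallelDims` dataclass:

```python
from torchtitan.distributed.parallel_dims import ParallelDims

# Illustrative only: 4 ranks split as dp_shard=2 x tp=2, with the new flag on.
parallel_dims = ParallelDims(
    dp_replicate=1,
    dp_shard=2,
    cp=1,
    tp=2,
    pp=1,
    world_size=4,
    enable_loss_parallel=True,
    enable_tp2ep=True,  # use expert parallelism instead of TP for shared experts
)
```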
9 changes: 9 additions & 0 deletions torchtitan/distributed/parallel_dims.py
@@ -25,6 +25,7 @@ class ParallelDims:
pp: int
world_size: int
enable_loss_parallel: bool
enable_tp2ep: bool

def __post_init__(self):
self._validate()
@@ -81,17 +82,23 @@ def _build_mesh(
dp_shard_cp_mesh_dim_names = []
# Mesh for loss all-reduce
dp_cp_mesh_dim_names = []
dp_cp_tp2ep_mesh_dim_names = []

if self.dp_replicate_enabled:
dp_mesh_dim_names.append("dp_replicate")
dp_cp_mesh_dim_names.append("dp_replicate")
dp_cp_tp2ep_mesh_dim_names.append("dp_replicate")
if self.dp_shard_enabled:
dp_mesh_dim_names.append("dp_shard")
dp_shard_cp_mesh_dim_names.append("dp_shard")
dp_cp_mesh_dim_names.append("dp_shard")
dp_cp_tp2ep_mesh_dim_names.append("dp_shard")
if self.cp_enabled:
dp_shard_cp_mesh_dim_names.append("cp")
dp_cp_mesh_dim_names.append("cp")
dp_cp_tp2ep_mesh_dim_names.append("cp")
if self.tp_enabled and self.enable_tp2ep:
dp_cp_tp2ep_mesh_dim_names.append("tp")

if dp_mesh_dim_names != []:
mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
@@ -101,6 +108,8 @@
)
if dp_cp_mesh_dim_names != []:
mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
if dp_cp_tp2ep_mesh_dim_names != []:
mesh[tuple(dp_cp_tp2ep_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp_tp2ep")

return mesh

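To make the new flatten concrete, here is a standalone mirror of the dim-name bookkeeping above (the helper is illustrative and not part of the PR; it needs no process groups):

```python
# Which mesh dims get flattened into "dp_cp_tp2ep" for a given configuration.
def dp_cp_tp2ep_dims(
    dp_replicate_enabled: bool,
    dp_shard_enabled: bool,
    cp_enabled: bool,
    tp_enabled: bool,
    enable_tp2ep: bool,
) -> list[str]:
    dims = []
    if dp_replicate_enabled:
        dims.append("dp_replicate")
    if dp_shard_enabled:
        dims.append("dp_shard")
    if cp_enabled:
        dims.append("cp")
    if tp_enabled and enable_tp2ep:
        dims.append("tp")
    return dims

print(dp_cp_tp2ep_dims(False, True, True, True, True))   # ['dp_shard', 'cp', 'tp']
print(dp_cp_tp2ep_dims(False, True, True, True, False))  # ['dp_shard', 'cp']
```

With `enable_tp2ep` off, the reduction mesh matches the existing `dp_cp` mesh; with it on, the token-count all-reduce also covers the TP dim.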
6 changes: 3 additions & 3 deletions torchtitan/experiments/llama4/__init__.py
@@ -12,7 +12,7 @@
from torchtitan.models.llama3 import pipeline_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize_llama import parallelize_llama
from .infra.parallelize_llama import parallelize_llama, update_router_expert_bias
from .model.args import TransformerModelArgs
from .model.model import Transformer

@@ -103,5 +103,5 @@
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_tiktoken_tokenizer,
build_loss_fn=build_cross_entropy_loss,
)
)
finalize_model_grads_func=update_router_expert_bias,
))
113 changes: 91 additions & 22 deletions torchtitan/experiments/llama4/infra/parallelize_llama.py
@@ -117,28 +117,6 @@ def parallelize_llama(
enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
)

# for MoE auxiliary-loss-free load balancing
if dp_mesh is not None:
# NOTE: Currently this sync is blocking (thus exposed) and happens on the
# default compute stream. Need to assess if this is OK performance-wise.
def _sync_tokens_per_expert(module, *_):
assert isinstance(module, MoE)
torch.distributed.all_reduce(
module.tokens_per_expert, group=dp_mesh.get_group()
)

for transformer_block in model.layers.values():
if transformer_block.moe_enabled:
load_balance_coeff = transformer_block.moe.load_balance_coeff
if load_balance_coeff is not None and load_balance_coeff > 0:
# prepend=True so that the sync runs before
# the _update_expert_bias hook in MoE
transformer_block.moe.register_full_backward_hook(
_sync_tokens_per_expert, prepend=True
)
else:
break

return model


@@ -176,3 +154,94 @@ def apply_moe_tp(
device_mesh=tp_mesh,
parallelize_plan=moe_layer_plan,
)


def get_updated_expert_bias(
tokens_per_expert,
expert_bias,
expert_bias_update_rate,
reduce_mesh: DeviceMesh | None,
):
"""Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1#

Args:
tokens_per_expert (torch.Tensor): The number of tokens assigned to each expert.
expert_bias (torch.Tensor): The bias for each expert.
expert_bias_update_rate (float): The update rate for the expert bias.
reduce_mesh (DeviceMesh | None): Mesh whose process group is used to all-reduce
tokens_per_expert across DP x CP x TP2EP ranks before computing the update;
skipped if None.
"""

with torch.no_grad():
# All-reduce across the DP x CP x TP2EP group
if reduce_mesh is not None:
torch.distributed.all_reduce(
tokens_per_expert,
group=reduce_mesh.get_group(),
)
average_tokens = tokens_per_expert.sum(
dim=-1, keepdim=True) / tokens_per_expert.shape[-1]
offset = average_tokens - tokens_per_expert
updated_expert_bias = expert_bias + torch.sign(
offset) * expert_bias_update_rate
return updated_expert_bias


# for MoE auxiliary-loss-free load balancing
def update_router_expert_bias(model: torch.nn.Module,
world_mesh: DeviceMesh):
"""
Update the expert bias of the router once per global batch.
This requires an all-reduce of tokens_per_expert across DP x CP x TP2EP ranks.
"""
tokens_per_expert_list = []
expert_bias_list = []
global_load_balance_coeff = None

if hasattr(model, "layers"):
layers = model.layers
elif hasattr(model, "model") and hasattr(model.model, "layers"):
layers = model.model.layers
else:
raise NotImplementedError(
"Model structure not recognized for MoE expert bias update."
)
for transformer_block in layers.values():
if transformer_block.moe_enabled:
load_balance_coeff = transformer_block.moe.load_balance_coeff
if load_balance_coeff is not None and load_balance_coeff > 0:
if global_load_balance_coeff is None:
global_load_balance_coeff = load_balance_coeff
else:
assert (
global_load_balance_coeff == load_balance_coeff
), "All MoE layers must have the same load balance coefficient."
tokens_per_expert_list.append(
transformer_block.moe.tokens_per_expert)
expert_bias_list.append(transformer_block.moe.expert_bias)
else:
break

if len(expert_bias_list) == 0:
return

if world_mesh is not None and world_mesh.size() > 1:
try:
load_balance_reduce_mesh = world_mesh["dp_cp_tp2ep"]
except KeyError:
load_balance_reduce_mesh = None
else:
load_balance_reduce_mesh = None

stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0)
stacked_expert_bias = torch.stack(expert_bias_list, dim=0)
stacked_updated_expert_bias = get_updated_expert_bias(
stacked_tokens_per_expert,
stacked_expert_bias,
global_load_balance_coeff,
load_balance_reduce_mesh,
)

for tokens_per_expert, expert_bias, updated_expert_bias in zip(
tokens_per_expert_list, expert_bias_list,
stacked_updated_expert_bias):
tokens_per_expert.zero_()
expert_bias.copy_(updated_expert_bias)
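For reviewers, a tiny self-contained walk-through of the update rule in `get_updated_expert_bias`; the counts and update rate below are made up, whereas in training they come from the router buffers and `load_balance_coeff`:

```python
import torch

# Token counts one MoE layer routed to each of 4 experts over the global batch
# (i.e. already all-reduced across the dp_cp_tp2ep group).
tokens_per_expert = torch.tensor([10.0, 2.0, 0.0, 4.0])
expert_bias = torch.zeros(4)
update_rate = 1e-3  # stands in for load_balance_coeff

average_tokens = tokens_per_expert.sum(dim=-1, keepdim=True) / tokens_per_expert.shape[-1]
offset = average_tokens - tokens_per_expert            # [-6., 2., 4., 0.]
updated_bias = expert_bias + torch.sign(offset) * update_rate
print(updated_bias)  # tensor([-0.0010,  0.0010,  0.0010,  0.0000])
```

Overloaded experts have their bias nudged down and underloaded experts up, which is the auxiliary-loss-free balancing rule of arXiv:2408.15664.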
19 changes: 5 additions & 14 deletions torchtitan/experiments/llama4/model/moe.py
@@ -214,6 +214,7 @@ def __init__(self, model_args: TransformerModelArgs):

# auxiliary-loss-free load balancing
self.load_balance_coeff = model_args.load_balance_coeff
self.expert_bias_enabled = self.load_balance_coeff is not None and self.load_balance_coeff > 0.0
# the fields below are defined even when load_balance_coeff is None
# to make initialization and checkpointing code simpler
self.register_buffer(
@@ -227,19 +228,6 @@
persistent=True,
)

# NOTE: forward hook, forward pre hook, or backward pre hook
# would conflict with activation checkpointing
if self.load_balance_coeff is not None and self.load_balance_coeff > 0:
self.register_full_backward_hook(self._update_expert_bias)

def _update_expert_bias(self, *_):
expert_bias_delta = self.load_balance_coeff * torch.sign(
self.tokens_per_expert.mean() - self.tokens_per_expert
)
expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
self.expert_bias.add_(expert_bias_delta)

self.tokens_per_expert.zero_()

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
@@ -260,7 +248,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
) = self.router(x.reshape(bs * slen, dim), self.expert_bias)

# will be used to update the expert bias for load balancing
self.tokens_per_expert += num_local_tokens_per_expert
# Prevent extra accumulation of local token counts during evaluation or activation recomputation
if self.expert_bias_enabled and torch.is_grad_enabled():
with torch.no_grad():
self.tokens_per_expert.add_(num_local_tokens_per_expert)

# shape (bs*slen*top_k, dim)
token_indices = token_indices.reshape(-1, 1).expand(-1, dim)
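A minimal sketch of what the new guard in `forward` buys: under `torch.no_grad()` (e.g. evaluation), `torch.is_grad_enabled()` returns False, so the counter is not bumped again. The helper below is invented for illustration and is not part of the PR:

```python
import torch

tokens_per_expert = torch.zeros(4)
expert_bias_enabled = True

def maybe_accumulate(num_local_tokens_per_expert: torch.Tensor) -> None:
    # Mirrors the guard added to MoE.forward above.
    if expert_bias_enabled and torch.is_grad_enabled():
        with torch.no_grad():
            tokens_per_expert.add_(num_local_tokens_per_expert)

counts = torch.tensor([3.0, 1.0, 0.0, 2.0])
maybe_accumulate(counts)           # training-style call: counts accumulate
with torch.no_grad():
    maybe_accumulate(counts)       # eval-style call: skipped, no double counting
print(tokens_per_expert)           # tensor([3., 1., 0., 2.])
```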
3 changes: 3 additions & 0 deletions torchtitan/protocols/train_spec.py
@@ -13,6 +13,7 @@

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.pipelining.schedules import _PipelineSchedule

from torchtitan.components.dataloader import BaseDataLoader
@@ -74,6 +75,7 @@ def from_model_args(cls, args: BaseModelArgs) -> nn.Module:
[OptimizersContainer, JobConfig], LRSchedulersContainer
]
LossFunctionBuilder: TypeAlias = Callable[..., LossFunction]
FinalizeModelGradsFunc: TypeAlias = Callable[[nn.Module, DeviceMesh], None]


@dataclass
@@ -89,6 +91,7 @@ class TrainSpec:
build_tokenizer_fn: TokenizerBuilder | None
build_loss_fn: LossFunctionBuilder
build_metrics_processor_fn: MetricsProcessorBuilder | None = None
finalize_model_grads_func: FinalizeModelGradsFunc | None = None


_train_specs = {}
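Any callable matching the new `FinalizeModelGradsFunc` alias can be wired into a `TrainSpec`, as the llama4 spec above does with `update_router_expert_bias`. A hypothetical no-op stub (the function name is invented):

```python
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh


def finalize_model_grads_noop(model: nn.Module, world_mesh: DeviceMesh) -> None:
    # Called once per model part after forward/backward and before gradient
    # clipping (see the train.py change below).
    return None


# e.g. TrainSpec(..., finalize_model_grads_func=finalize_model_grads_noop)
```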
7 changes: 7 additions & 0 deletions torchtitan/train.py
@@ -91,6 +91,7 @@ def __init__(self, job_config: JobConfig):
pp=parallelism_config.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=not parallelism_config.disable_loss_parallel,
enable_tp2ep=parallelism_config.enable_tp2ep,
)
dist_utils.init_distributed(job_config)

@@ -425,6 +426,12 @@ def train_step(
loss = self.forward_backward_step(input_dict, labels)
accumulated_losses.append(loss.detach())

# TODO: this can be placed inside PP but might break gradient accumulation
if self.train_spec.finalize_model_grads_func is not None:
for m in self.model_parts:
self.train_spec.finalize_model_grads_func(
m, self.world_mesh)

dist_utils.clip_grad_norm_(
[p for m in self.model_parts for p in m.parameters()],
self.job_config.training.max_norm,