diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index d681cd6a1..03c97fcb1 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -21,11 +21,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 
-from torchtitan.models.llama3.infra.parallelize import (
-    apply_ac,
-    apply_compile,
-    apply_ddp,
-)
+from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.tools.logging import logger
 
 from .expert_parallel import (
@@ -36,6 +32,28 @@
 )
 
 
+def apply_compile(model: nn.Module):
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
+    torch._dynamo.config.fail_on_recompile_limit_hit = True
+    for layer_id, transformer_block in model.layers.named_children():
+        if transformer_block.moe_enabled:
+            moe = transformer_block.moe
+            # Individually compile modules to keep fullgraph=True on FSDP wrapped experts
+            moe.experts = torch.compile(moe.experts, fullgraph=True)
+            moe.shared_expert = torch.compile(moe.shared_expert, fullgraph=True)
+
+            # Separately compile the code around the FSDP wrapped experts
+            moe.router = torch.compile(moe.router, fullgraph=True)
+        else:
+            transformer_block = torch.compile(transformer_block, fullgraph=True)
+            model.layers.register_module(layer_id, transformer_block)
+
+    logger.info("Compiling each TransformerBlock with torch.compile")
+
+
 def parallelize_llama(
     model: nn.Module,
     world_mesh: DeviceMesh,
diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py
index d7f0ce3fd..dbf9623a8 100644
--- a/torchtitan/experiments/llama4/model/moe.py
+++ b/torchtitan/experiments/llama4/model/moe.py
@@ -13,6 +13,75 @@
 from .args import TransformerModelArgs
 
 
+# TODO: keeping this for-loop implementation for comparison
+# and readability, may remove later
+@expert_parallel
+def _run_experts_for_loop(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        # NOTE: this would incur a synchronization between device and host
+        num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+        # side-effect code due to the usage of generate_permute_indices
+        num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+        # a tuple of tensors indexed by experts
+        # each with shape (tokens_per_expert(varying), dim)
+        x = torch.split(
+            x[: sum(num_tokens_per_expert)],
+            split_size_or_sections=num_tokens_per_expert,
+            dim=0,
+        )
+        out_experts_splits = []
+        for expert_idx, x_expert in enumerate(x):
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
+            h = h * torch.matmul(x_expert, w3[expert_idx])
+            h = torch.matmul(h, w2[expert_idx])
+            # h shape (tokens_per_expert(varying), dim)
+            out_experts_splits.append(h)
+        out = torch.cat(out_experts_splits, dim=0)
+
+        # side-effect code due to the usage of generate_permute_indices
+        out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+    else:
+        # x shape (num_experts, tokens_per_expert, dim)
+        h = F.silu(torch.bmm(x, w1))
+        h = h * torch.bmm(x, w3)
+        # out shape (num_experts, tokens_per_expert, dim)
+        out = torch.bmm(h, w2)
+
+    return out
+
+
+@expert_parallel
+def _run_experts_grouped_mm(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+        # grouped mm between a 2D tensor and a 3D tensor
+        assert x.dim() == 2
+    else:
+        offsets = None
+        # fall back to regular bmm between 3D tensors
+        assert x.dim() == 3
+
+    h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
+    h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
+    out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
+
+    return out
+
+
 class GroupedExperts(nn.Module):
     def __init__(
         self,
@@ -34,83 +103,14 @@ def forward(
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
         if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
+            return _run_experts_grouped_mm(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
         else:
-            return GroupedExperts._run_experts_for_loop(
+            return _run_experts_for_loop(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
 
-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
-
-        return out
-
-    @expert_parallel
-    @staticmethod
-    def _run_experts_grouped_mm(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-            # grouped mm between a 2D tensor and a 3D tensor
-            assert x.dim() == 2
-        else:
-            offsets = None
-            # fall back to regular bmm between 3D tensors
-            assert x.dim() == 3
-
-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
-        return out
-
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
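
Reviewer note: the moe.py change moves `_run_experts_for_loop` and `_run_experts_grouped_mm` from `@expert_parallel @staticmethod` members of `GroupedExperts` to module-level functions, so `@expert_parallel` now decorates plain functions rather than `staticmethod` objects, which appears intended to keep them cleanly traceable by the `torch.compile(..., fullgraph=True)` calls that the new llama4-specific `apply_compile` applies per MoE submodule (experts, shared expert, router) instead of per TransformerBlock. A minimal call-site sketch, assuming the same gating pattern as llama3's `parallelize_llama`; the `job_config.training.compile` flag name is an assumption, not part of this patch:

```python
# Hypothetical call site inside parallelize_llama (not part of this patch);
# mirrors how llama3 gates compilation, flag name assumed.
if job_config.training.compile:
    apply_compile(model)
```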
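For readers skimming the diff, here is a minimal standalone sketch (not from the patch; all shapes are invented) of the SwiGLU expert computation that both implementations reduce to on the dense path, i.e. when `num_tokens_per_expert is None`:

```python
# Standalone sketch of the dense (num_tokens_per_expert=None) expert math;
# sizes are illustrative only.
import torch
import torch.nn.functional as F

num_experts, tokens_per_expert, dim, hidden_dim = 4, 8, 16, 32
x = torch.randn(num_experts, tokens_per_expert, dim)
w1 = torch.randn(num_experts, dim, hidden_dim)
w2 = torch.randn(num_experts, hidden_dim, dim)
w3 = torch.randn(num_experts, dim, hidden_dim)

# h: (num_experts, tokens_per_expert, hidden_dim) -- SwiGLU gating
h = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
# out: (num_experts, tokens_per_expert, dim) -- down projection
out = torch.bmm(h, w2)
assert out.shape == x.shape
```

On the ragged path (when `num_tokens_per_expert` is given), `_run_experts_for_loop` splits a 2D token tensor per expert and loops over `torch.matmul` calls, at the cost of a device-to-host sync from `.tolist()`, while `_run_experts_grouped_mm` feeds the same 2D tensor to `torch._grouped_mm` with cumulative-sum offsets and no sync.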