[WIP] Compile for dp2ep #1365

Draft: wants to merge 2 commits into main
28 changes: 23 additions & 5 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -21,11 +21,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 
-from torchtitan.models.llama3.infra.parallelize import (
-    apply_ac,
-    apply_compile,
-    apply_ddp,
-)
+from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.tools.logging import logger
 
 from .expert_parallel import (
@@ -36,6 +32,28 @@
 )
 
 
+def apply_compile(model: nn.Module):
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
+    torch._dynamo.config.fail_on_recompile_limit_hit = True
+    for layer_id, transformer_block in model.layers.named_children():
+        if transformer_block.moe_enabled:
+            moe = transformer_block.moe
+            # Individually compile modules to keep fullgraph=True on FSDP wrapped experts
+            moe.experts = torch.compile(moe.experts, fullgraph=True)
+            moe.shared_expert = torch.compile(moe.shared_expert, fullgraph=True)
+
+            # Separately compile the code around the FSDP wrapped experts
+            moe.router = torch.compile(moe.router, fullgraph=True)
+        else:
+            transformer_block = torch.compile(transformer_block, fullgraph=True)
+        model.layers.register_module(layer_id, transformer_block)
+
+    logger.info("Compiling each TransformerBlock with torch.compile")
+
+
 def parallelize_llama(
     model: nn.Module,
     world_mesh: DeviceMesh,
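For intuition, the per-block pattern above can be reproduced standalone. Below is a minimal sketch (the toy Block module and sizes are hypothetical, not part of this PR): compiling each repeated submodule separately lets torch.compile trace a single block and reuse the compiled graph across identical layers, rather than tracing the whole model at once.

    import torch
    import torch.nn as nn

    class Block(nn.Module):  # stand-in for a TransformerBlock
        def __init__(self, dim: int):
            super().__init__()
            self.ffn = nn.Sequential(
                nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.ffn(x)

    model = nn.Sequential(*[Block(64) for _ in range(4)])

    # Compile each block individually, mirroring apply_compile above;
    # identical blocks can share one compiled graph.
    for i in range(len(model)):
        model[i] = torch.compile(model[i], fullgraph=True)

    out = model(torch.randn(8, 64))  # first call triggers compilation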
142 changes: 71 additions & 71 deletions torchtitan/experiments/llama4/model/moe.py
@@ -13,6 +13,75 @@
 from .args import TransformerModelArgs
 
 
+# TODO: keeping this for-loop implementation for comparison
+# and readability, may remove later
+@expert_parallel
+def _run_experts_for_loop(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        # NOTE: this would incur a synchronization between device and host
+        num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+        # side-effect code due to the usage of generate_permute_indices
+        num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+        # a tuple of tensors indexed by experts
+        # each with shape (tokens_per_expert(varying), dim)
+        x = torch.split(
+            x[: sum(num_tokens_per_expert)],
+            split_size_or_sections=num_tokens_per_expert,
+            dim=0,
+        )
+        out_experts_splits = []
+        for expert_idx, x_expert in enumerate(x):
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
+            h = h * torch.matmul(x_expert, w3[expert_idx])
+            h = torch.matmul(h, w2[expert_idx])
+            # h shape (tokens_per_expert(varying), dim)
+            out_experts_splits.append(h)
+        out = torch.cat(out_experts_splits, dim=0)
+
+        # side-effect code due to the usage of generate_permute_indices
+        out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+    else:
+        # x shape (num_experts, tokens_per_expert, dim)
+        h = F.silu(torch.bmm(x, w1))
+        h = h * torch.bmm(x, w3)
+        # out shape (num_experts, tokens_per_expert, dim)
+        out = torch.bmm(h, w2)
+
+    return out
+
+
+@expert_parallel
+def _run_experts_grouped_mm(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+        # grouped mm between a 2D tensor and a 3D tensor
+        assert x.dim() == 2
+    else:
+        offsets = None
+        # fall back to regular bmm between 3D tensors
+        assert x.dim() == 3
+
+    h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
+    h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
+    out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
+
+    return out
+
+
 class GroupedExperts(nn.Module):
     def __init__(
         self,
@@ -34,83 +103,14 @@ def forward(
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
         if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
+            return _run_experts_grouped_mm(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
         else:
-            return GroupedExperts._run_experts_for_loop(
+            return _run_experts_for_loop(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
 
-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
-
-        return out
-
-    @expert_parallel
-    @staticmethod
-    def _run_experts_grouped_mm(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-            # grouped mm between a 2D tensor and a 3D tensor
-            assert x.dim() == 2
-        else:
-            offsets = None
-            # fall back to regular bmm between 3D tensors
-            assert x.dim() == 3
-
-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
-        return out
-
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
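The for-loop and batched formulations compute the same result. Below is a minimal self-contained check (toy shapes, CPU, public ops only; torch._grouped_mm is a private API, so this sketch only illustrates how its offs argument is derived and verifies that the batched bmm branch matches a per-expert loop):

    import torch
    import torch.nn.functional as F

    # Offsets for the grouped-mm path are cumulative token counts per expert.
    num_tokens_per_expert = torch.tensor([2, 3, 1])
    offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
    # offsets == tensor([2, 5, 6]): with a 2D x of 6 rows, the grouped mm
    # multiplies x[0:2] @ w[0], x[2:5] @ w[1], x[5:6] @ w[2].

    # Dense branch (num_tokens_per_expert=None): bmm over stacked experts
    # equals a per-expert loop of SwiGLU matmul chains.
    E, T, D, H = 4, 8, 16, 32  # experts, tokens per expert, model dim, hidden dim
    x = torch.randn(E, T, D)
    w1, w3 = torch.randn(E, D, H), torch.randn(E, D, H)
    w2 = torch.randn(E, H, D)

    h = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
    out_bmm = torch.bmm(h, w2)

    out_loop = torch.stack(
        [(F.silu(x[e] @ w1[e]) * (x[e] @ w3[e])) @ w2[e] for e in range(E)]
    )
    torch.testing.assert_close(out_bmm, out_loop)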