
Commit 2ef0e26

compile, but there's a hang for grouped_mm autotuning

committed · 1 parent be15836 · commit 2ef0e26

File tree

2 files changed: +107 −76 lines changed

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 36 additions & 5 deletions
@@ -21,11 +21,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims

-from torchtitan.models.llama3.infra.parallelize import (
-    apply_ac,
-    apply_compile,
-    apply_ddp,
-)
+from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.tools.logging import logger

 from .expert_parallel import (
@@ -36,6 +32,41 @@
 )


+def apply_compile(model: nn.Module):
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
+    torch._dynamo.config.fail_on_recompile_limit_hit = True
+    for layer_id, transformer_block in model.layers.named_children():
+        if transformer_block.moe_enabled:
+            # compile the experts directly which can be wrapped by fsdp
+            moe = transformer_block.moe
+
+            # transformer_block.moe.experts = torch.compile(transformer_block.moe.experts, fullgraph=True))
+            moe.experts = torch.compile(moe.experts, fullgraph=True)
+            moe.router = torch.compile(moe.router, fullgraph=True)
+            moe.shared_expert = torch.compile(moe.shared_expert, fullgraph=True)
+        else:
+            transformer_block = torch.compile(transformer_block, fullgraph=True)
+        model.layers.register_module(layer_id, transformer_block)
+
+    # def _compile_child(parent:nn.Module, child_name: str, child: nn.Module):
+    #     parent.register_module(child_name, torch.compile(child, fullgraph=True))
+
+    # torch._dynamo.config.fail_on_recompile_limit_hit = True
+    # for layer_id, transformer_block in model.layers.named_children():
+    #     if transformer_block.moe_enabled:
+    #         # compile the experts directly which can be wrapped by fsdp
+    #         moe = transformer_block.moe
+    #         # for submod_id, submod in moe.named_children():
+    #         #     _compile_child(moe, submod_id, submod)
+    #     else:
+    #         _compile_child(transformer_block, layer_id, transformer_block)
+
+    logger.info("Compiling each TransformerBlock with torch.compile")
+
+
 def parallelize_llama(
     model: nn.Module,
     world_mesh: DeviceMesh,
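
Note on the new apply_compile: for MoE blocks it compiles the router, experts, and shared expert as separate graphs (the diff's own comment notes the experts are compiled directly so FSDP can still wrap them), while dense TransformerBlocks are compiled whole. Below is a minimal, self-contained sketch of that pattern on a toy module tree; ToyMoE, ToyBlock, and ToyModel are hypothetical stand-ins for torchtitan's real classes, not code from this commit.

import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.router = nn.Linear(dim, 4)
        self.experts = nn.Linear(dim, dim)
        self.shared_expert = nn.Linear(dim, dim)

    def forward(self, x):
        # routing omitted for brevity; just exercise each submodule
        return self.experts(x) + self.shared_expert(x) + 0.0 * self.router(x).sum()

class ToyBlock(nn.Module):
    def __init__(self, dim: int, moe_enabled: bool):
        super().__init__()
        self.moe_enabled = moe_enabled
        self.moe = ToyMoE(dim) if moe_enabled else None
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x):
        return self.moe(x) if self.moe_enabled else self.mlp(x)

class ToyModel(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleDict(
            {"0": ToyBlock(dim, moe_enabled=True), "1": ToyBlock(dim, moe_enabled=False)}
        )

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x

model = ToyModel()
for layer_id, block in model.layers.named_children():
    if block.moe_enabled:
        # compile MoE submodules individually, mirroring apply_compile above
        block.moe.experts = torch.compile(block.moe.experts, fullgraph=True)
        block.moe.router = torch.compile(block.moe.router, fullgraph=True)
        block.moe.shared_expert = torch.compile(block.moe.shared_expert, fullgraph=True)
    else:
        # dense blocks are compiled whole and re-registered under the same name
        model.layers.register_module(layer_id, torch.compile(block, fullgraph=True))

out = model(torch.randn(2, 8))  # compilation is triggered lazily on this first forward

Compiling the leaf submodules (rather than the whole MoE block) keeps the module boundaries that FSDP wraps intact, at the cost of several smaller graphs per layer.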

torchtitan/experiments/llama4/model/moe.py

Lines changed: 71 additions & 71 deletions
@@ -13,6 +13,75 @@
 from .args import TransformerModelArgs


+# TODO: keeping this for-loop implementation for comparison
+# and readability, may remove later
+@expert_parallel
+def _run_experts_for_loop(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        # NOTE: this would incur a synchronization between device and host
+        num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+        # side-effect code due to the usage of generate_permute_indices
+        num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+        # a tuple of tensors indexed by experts
+        # each with shape (tokens_per_expert(varying), dim)
+        x = torch.split(
+            x[: sum(num_tokens_per_expert)],
+            split_size_or_sections=num_tokens_per_expert,
+            dim=0,
+        )
+        out_experts_splits = []
+        for expert_idx, x_expert in enumerate(x):
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
+            h = h * torch.matmul(x_expert, w3[expert_idx])
+            h = torch.matmul(h, w2[expert_idx])
+            # h shape (tokens_per_expert(varying), dim)
+            out_experts_splits.append(h)
+        out = torch.cat(out_experts_splits, dim=0)
+
+        # side-effect code due to the usage of generate_permute_indices
+        out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+    else:
+        # x shape (num_experts, tokens_per_expert, dim)
+        h = F.silu(torch.bmm(x, w1))
+        h = h * torch.bmm(x, w3)
+        # out shape (num_experts, tokens_per_expert, dim)
+        out = torch.bmm(h, w2)
+
+    return out
+
+
+@expert_parallel
+def _run_experts_grouped_mm(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+        # grouped mm between a 2D tensor and a 3D tensor
+        assert x.dim() == 2
+    else:
+        offsets = None
+        # fall back to regular bmm between 3D tensors
+        assert x.dim() == 3
+
+    h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
+    h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
+    out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
+
+    return out
+
+
 class GroupedExperts(nn.Module):
     def __init__(
         self,
@@ -34,83 +103,14 @@ def forward(
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
         if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
+            return _run_experts_grouped_mm(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
         else:
-            return GroupedExperts._run_experts_for_loop(
+            return _run_experts_for_loop(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )

-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
-
-        return out
-
-    @expert_parallel
-    @staticmethod
-    def _run_experts_grouped_mm(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-            # grouped mm between a 2D tensor and a 3D tensor
-            assert x.dim() == 2
-        else:
-            offsets = None
-            # fall back to regular bmm between 3D tensors
-            assert x.dim() == 3
-
-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
-        return out
-
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
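
Note on the offsets convention in _run_experts_grouped_mm: when num_tokens_per_expert is given, tokens are assumed to be pre-sorted by expert in a 2D activation tensor, and offsets is the int32 cumulative sum of per-expert counts, so rows offsets[i-1]:offsets[i] belong to expert i. The sketch below is not part of the commit; it replays that convention with ordinary matmuls instead of torch._grouped_mm (the private grouped-mm op needs bf16 inputs and recent GPU support), with made-up shapes for illustration.

import torch

num_experts, dim, hidden = 3, 4, 8
num_tokens_per_expert = torch.tensor([2, 0, 5])  # tokens already sorted by expert
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)  # tensor([2, 2, 7])

x = torch.randn(int(num_tokens_per_expert.sum()), dim)  # 2D activations: (total_tokens, dim)
w1 = torch.randn(num_experts, dim, hidden)              # 3D weights: one (dim, hidden) per expert

# Reference for a grouped mm between a 2D activation and a 3D weight:
# rows offsets[i-1]:offsets[i] of x are multiplied by w1[i].
out = torch.empty(x.shape[0], hidden)
start = 0
for expert_idx, end in enumerate(offsets.tolist()):
    out[start:end] = x[start:end] @ w1[expert_idx]
    start = end

# With num_tokens_per_expert=None (offsets=None), the function instead expects a
# 3D x of shape (num_experts, tokens_per_expert, dim) and falls back to bmm.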
