
Commit 06725f1

compile, with pytorch/pytorch 26807dcf277
1 parent 27e3ad8 commit 06725f1

2 files changed: 89 additions & 77 deletions


torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 15 additions & 5 deletions
@@ -21,11 +21,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 
-from torchtitan.models.llama3.infra.parallelize import (
-    apply_ac,
-    apply_compile,
-    apply_ddp,
-)
+from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.tools.logging import logger
 
 from .expert_parallel import (
@@ -36,6 +32,20 @@
 )
 
 
+def apply_compile(model: nn.Module):
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
+    torch._dynamo.config.fail_on_recompile_limit_hit = True
+    for layer_id, transformer_block in model.layers.named_children():
+        # NOTE: we allow graph breaks on FSDP hooks for MoE experts, see `set_fullgraph(False)` in moe.py
+        transformer_block = torch.compile(transformer_block, fullgraph=True)
+        model.layers.register_module(layer_id, transformer_block)
+
+    logger.info("Compiling each TransformerBlock with torch.compile")
+
+
 def parallelize_llama(
     model: nn.Module,
     parallel_dims: ParallelDims,
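
For context, here is a minimal, self-contained sketch of the per-block compilation pattern introduced above: each submodule is compiled separately so dynamo traces one block's repeated structure instead of the whole model. ToyBlock and ToyModel are hypothetical stand-ins for torchtitan's TransformerBlock and its parent module; this is an illustration under those assumptions, not the torchtitan code itself.

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    # hypothetical stand-in for a TransformerBlock
    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


class ToyModel(nn.Module):
    # hypothetical stand-in for the model that owns `layers` as an nn.ModuleDict
    def __init__(self, num_layers: int = 4, dim: int = 64):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): ToyBlock(dim) for i in range(num_layers)})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.layers.values():
            x = block(x)
        return x


model = ToyModel()
# compile each block and swap the compiled module back into the parent,
# mirroring the named_children()/register_module() loop in apply_compile
for layer_id, block in model.layers.named_children():
    model.layers.register_module(layer_id, torch.compile(block, fullgraph=True))

out = model(torch.randn(8, 64))
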

torchtitan/experiments/llama4/model/moe.py

Lines changed: 74 additions & 72 deletions
@@ -13,6 +13,75 @@
 from .args import TransformerModelArgs
 
 
+# TODO: keeping this for-loop implementation for comparison
+# and readability, may remove later
+@expert_parallel
+def _run_experts_for_loop(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        # NOTE: this would incur a synchronization between device and host
+        num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+        # side-effect code due to the usage of generate_permute_indices
+        num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+        # a tuple of tensors indexed by experts
+        # each with shape (tokens_per_expert(varying), dim)
+        x = torch.split(
+            x[: sum(num_tokens_per_expert)],
+            split_size_or_sections=num_tokens_per_expert,
+            dim=0,
+        )
+        out_experts_splits = []
+        for expert_idx, x_expert in enumerate(x):
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
+            h = h * torch.matmul(x_expert, w3[expert_idx])
+            h = torch.matmul(h, w2[expert_idx])
+            # h shape (tokens_per_expert(varying), dim)
+            out_experts_splits.append(h)
+        out = torch.cat(out_experts_splits, dim=0)
+
+        # side-effect code due to the usage of generate_permute_indices
+        out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+    else:
+        # x shape (num_experts, tokens_per_expert, dim)
+        h = F.silu(torch.bmm(x, w1))
+        h = h * torch.bmm(x, w3)
+        # out shape (num_experts, tokens_per_expert, dim)
+        out = torch.bmm(h, w2)
+
+    return out
+
+
+@expert_parallel
+def _run_experts_grouped_mm(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+        # grouped mm between a 2D tensor and a 3D tensor
+        assert x.dim() == 2
+    else:
+        offsets = None
+        # fall back to regular bmm between 3D tensors
+        assert x.dim() == 3
+
+    h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
+    h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
+    out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
+
+    return out
+
+
 class GroupedExperts(nn.Module):
     def __init__(
         self,
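
To make the token-splitting math in _run_experts_for_loop concrete, here is a standalone sketch with toy sizes. It covers only the branch where num_tokens_per_expert is provided, and drops the @expert_parallel wrapper and the generate_permute_indices padding handling, so it illustrates the shapes rather than the parallelized path; all sizes and names below are made up for the example.

import torch
import torch.nn.functional as F

num_experts, dim, hidden_dim = 4, 16, 32
w1 = torch.randn(num_experts, dim, hidden_dim)
w2 = torch.randn(num_experts, hidden_dim, dim)
w3 = torch.randn(num_experts, dim, hidden_dim)

# tokens are assumed pre-sorted by expert; counts per expert sum to the token count
num_tokens_per_expert = [3, 5, 2, 6]
x = torch.randn(sum(num_tokens_per_expert), dim)

outs = []
for expert_idx, x_expert in enumerate(torch.split(x, num_tokens_per_expert, dim=0)):
    h = F.silu(x_expert @ w1[expert_idx]) * (x_expert @ w3[expert_idx])  # gated activation
    outs.append(h @ w2[expert_idx])                                      # back to model dim
out = torch.cat(outs, dim=0)                                             # (num_tokens, dim)
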
@@ -28,89 +97,21 @@ def __init__(
         self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
         self.use_grouped_mm = use_grouped_mm
 
+    @torch._dynamo.set_fullgraph(True)
     def forward(
         self,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
         if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
+            return _run_experts_grouped_mm(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
         else:
-            return GroupedExperts._run_experts_for_loop(
+            return _run_experts_for_loop(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
 
-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
-
-        return out
-
-    @expert_parallel
-    @staticmethod
-    def _run_experts_grouped_mm(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-            # grouped mm between a 2D tensor and a 3D tensor
-            assert x.dim() == 2
-        else:
-            offsets = None
-            # fall back to regular bmm between 3D tensors
-            assert x.dim() == 3
-
-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
-        return out
-
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
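
The offsets computed in _run_experts_grouped_mm mark where each expert's rows end in the flattened token dimension. The sketch below reproduces that grouping with plain matmuls for clarity; torch._grouped_mm itself is a private op that expects bfloat16 inputs and suitable hardware (as the diff's .bfloat16() casts suggest), so it is deliberately not called here, and the sizes are toy values.

import torch

num_tokens_per_expert = torch.tensor([3, 5, 2, 6])
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)  # [3, 8, 10, 16]

num_experts, dim, hidden = 4, 16, 32
x = torch.randn(int(offsets[-1]), dim)      # 2D: tokens sorted by expert
w1 = torch.randn(num_experts, dim, hidden)  # 3D: one weight matrix per expert

start = 0
group_outputs = []
for expert_idx, end in enumerate(offsets.tolist()):
    # rows [start, end) belong to this expert; a grouped mm does all groups in one call
    group_outputs.append(x[start:end] @ w1[expert_idx])
    start = end
h = torch.cat(group_outputs, dim=0)  # (16, hidden), the shape the grouped mm would produce
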
@@ -297,7 +298,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         # shape (bs*slen*top_k, dim)
-        routed_output = self.experts(routed_input, num_tokens_per_expert)
+        with torch._dynamo.set_fullgraph(False):
+            routed_output = self.experts(routed_input, num_tokens_per_expert)
 
         # shared expert
         if self.shared_expert is not None:
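
Putting the pieces together, a hedged sketch of the fullgraph toggling this commit relies on: each TransformerBlock is compiled with fullgraph=True, GroupedExperts.forward is pinned to a full graph via the decorator, and only the region around the experts call may graph-break (e.g. on FSDP hooks). This assumes the torch._dynamo.set_fullgraph API from the recent PyTorch nightly referenced in the commit message; Experts and Layer below are hypothetical stand-ins, not the torchtitan modules.

import torch
import torch.nn as nn


class Experts(nn.Module):
    @torch._dynamo.set_fullgraph(True)  # the expert math itself must stay in one graph
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = Experts()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # graph breaks (e.g. FSDP hooks around the experts) are tolerated only in this region
        with torch._dynamo.set_fullgraph(False):
            return self.experts(x)


layer = torch.compile(Layer(), fullgraph=True)
out = layer(torch.randn(4, 8))
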
