
Commit bfb3a32

Add option for selective op AC to filter mm shapes based on fqn
1 parent 01f4e50 commit bfb3a32

File tree

3 files changed: +263 −4 lines changed

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn
from torch.utils.flop_counter import FlopCounterMode

from torchtitan.config_manager import ActivationCheckpoint
from torchtitan.models.llama3.infra.parallelize import apply_ac


class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({"0": TransformerBlock()})

    def forward(self, x):
        return self.layers["0"](x)


class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = nn.Module()
        self.moe.router = nn.Module()
        self.moe.router.gate = nn.Linear(512, 512, bias=False)
        self.attention = nn.Module()
        self.attention.wq = nn.Linear(512, 512, bias=False)
        self.output = nn.Linear(512, 1024, bias=False)

    def forward(self, x):
        gate_out = self.moe.router.gate(x)
        wq_out = self.attention.wq(gate_out)
        final_out = self.output(wq_out)
        return final_out.sum()


class TestApplyAC(unittest.TestCase):
    def test_flops(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is unavailable")

        def get_bw_flops(model_fn):
            x = torch.randn(512, 512, requires_grad=True, device="cuda")
            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
                out = model_fn(x)
                out.backward()

            x = torch.randn(512, 512, requires_grad=True, device="cuda")
            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
                out = model_fn(x)
                with FlopCounterMode(display=False) as mode:
                    out.backward()
            return mode.get_total_flops() / (512**3 * 2)

        def get_act_mem(model_fn):
            x = torch.randn(512, 512, requires_grad=True, device="cuda")
            out = model_fn(x)
            out.backward()
            start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]

            out = model_fn(x)
            cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
            act_mem = (cur_mem - start_mem) / (1024 * 1024)  # MB
            out.backward()
            return act_mem

        # 1. No AC
        model_no_ac = TestModule().cuda()
        flops_no_ac = get_bw_flops(model_no_ac)
        mem_no_ac = get_act_mem(model_no_ac)

        # 2. SAC
        # Per-op SAC's policy is to save every other mm
        model_selective_ac = TestModule().cuda()
        ac_config_no_force = ActivationCheckpoint(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=[],  # Empty list
        )
        apply_ac(model_selective_ac, ac_config_no_force)
        flops_selective_ac = get_bw_flops(model_selective_ac)
        mem_selective_ac = get_act_mem(model_selective_ac)

        # 3. Per-op SAC with force recompute "moe.router.gate"
        # This leads to two mms being recomputed since they share the same shape!
        model_with_force_first = TestModule().cuda()
        ac_config_with_force_first = ActivationCheckpoint(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
        )
        apply_ac(model_with_force_first, ac_config_with_force_first)
        flops_with_force_first = get_bw_flops(model_with_force_first)
        mem_with_force_first = get_act_mem(model_with_force_first)

        # 4. Per-op SAC with force recompute "output"
        model_with_force_last = TestModule().cuda()
        ac_config_with_force_last = ActivationCheckpoint(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
        )
        apply_ac(model_with_force_last, ac_config_with_force_last)
        flops_with_force_last = get_bw_flops(model_with_force_last)
        mem_with_force_last = get_act_mem(model_with_force_last)

        # 5. Full AC
        model_with_full_ac = TestModule().cuda()
        ac_config_full_ac = ActivationCheckpoint(
            mode="full",
        )
        apply_ac(model_with_full_ac, ac_config_full_ac)
        flops_full_ac = get_bw_flops(model_with_full_ac)
        mem_full_ac = get_act_mem(model_with_full_ac)

        self.assertEqual(flops_no_ac, 8.0)
        self.assertEqual(flops_selective_ac, 9.0)
        self.assertEqual(flops_with_force_first, 10.0)
        self.assertEqual(flops_with_force_last, 11.0)
        self.assertEqual(flops_full_ac, 12.0)

        self.assertEqual(mem_no_ac, 2.0)
        self.assertEqual(mem_selective_ac, 3.0)
        self.assertEqual(mem_with_force_first, 2.0)
        self.assertEqual(mem_with_force_last, 1.0)
        self.assertEqual(mem_full_ac, 0.0)
        # Note: SAC > no-AC here because it unnecessarily saves "output",
        # even though that is not needed for recomputation, and "output" is
        # double the size of the other two mms.

    def test_correctness(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is unavailable")

        model_no_ac = TestModule().cuda()

        model_selective_ac = TestModule().cuda()
        model_selective_ac.load_state_dict(model_no_ac.state_dict())
        apply_ac(
            model_selective_ac,
            ActivationCheckpoint(
                mode="selective",
                selective_ac_option="op",
                per_op_sac_force_recompute_mm_shapes_by_fqns=[],
            ),
        )
        model_force_first = TestModule().cuda()
        model_force_first.load_state_dict(model_no_ac.state_dict())
        apply_ac(
            model_force_first,
            ActivationCheckpoint(
                mode="selective",
                selective_ac_option="op",
                per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
            ),
        )

        model_force_last = TestModule().cuda()
        model_force_last.load_state_dict(model_no_ac.state_dict())
        apply_ac(
            model_force_last,
            ActivationCheckpoint(
                mode="selective",
                selective_ac_option="op",
                per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
            ),
        )

        def run_fwd_bwd(model, batch):
            model.zero_grad(set_to_none=True)
            xin = batch.clone().detach().requires_grad_(True)
            out = model(xin)  # scalar
            out.backward()

            grad_in = xin.grad.detach().clone()
            grad_params = [
                p.grad.detach().clone() if isinstance(p.grad, torch.Tensor) else None
                for p in model.parameters()
            ]
            return out.detach(), grad_in, grad_params

        batch = torch.randn(64, 512, device="cuda")

        out_ref, gin_ref, gparams_ref = run_fwd_bwd(model_no_ac, batch)
        out_sel, gin_sel, gparams_sel = run_fwd_bwd(model_selective_ac, batch)
        out_f1, gin_f1, gparams_f1 = run_fwd_bwd(model_force_first, batch)
        out_fl, gin_fl, gparams_fl = run_fwd_bwd(model_force_last, batch)

        for other_out in (out_sel, out_f1, out_fl):
            torch.testing.assert_close(out_ref, other_out)

        for other_gin in (gin_sel, gin_f1, gin_fl):
            torch.testing.assert_close(gin_ref, other_gin)

        for g_ref, g_sel, g_f1, g_fl in zip(
            gparams_ref, gparams_sel, gparams_f1, gparams_fl
        ):
            # Skip wrapper / missing grads
            if not (
                torch.is_tensor(g_ref)
                and torch.is_tensor(g_sel)
                and torch.is_tensor(g_f1)
                and torch.is_tensor(g_fl)
            ):
                continue

            torch.testing.assert_close(g_ref, g_sel)
            torch.testing.assert_close(g_ref, g_f1)
            torch.testing.assert_close(g_ref, g_fl)


if __name__ == "__main__":
    unittest.main()
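
The asserted FLOP ratios follow from the three matmuls in TransformerBlock. A back-of-the-envelope check (illustrative only, not part of the commit; one "unit" is 2 * 512**3 FLOPs, the cost of one (512, 512) x (512, 512) mm, which is also the normalization used by get_bw_flops above):

fwd = 1 + 1 + 2           # gate (512x512), wq (512x512), output (512x1024)
bwd = 2 * fwd             # each linear's backward runs two mms of the same cost
assert bwd == 8           # flops_no_ac: FlopCounterMode only wraps backward
assert bwd + 1 == 9       # plain per-op SAC recomputes every second mm (wq)
assert bwd + 2 == 10      # forcing "moe.router.gate" also catches wq (same shape)
assert bwd + 3 == 11      # forcing "output" (2 units) plus the every-other-mm wq (1 unit)
assert bwd + fwd == 12    # full AC recomputes the entire forward

The activation-memory assertions follow the same bookkeeping: each saved (512, 512) fp32 activation is 1 MB and the (512, 1024) "output" activation is 2 MB, which is why plain per-op SAC ends up above no-AC in this toy model.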

torchtitan/config_manager.py

Lines changed: 14 additions & 0 deletions
@@ -487,6 +487,20 @@ class ActivationCheckpoint:
     'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
     """
 
+    per_op_sac_force_recompute_mm_shapes_by_fqns: list[str] = field(
+        default_factory=lambda: ["moe.router.gate"]
+    )
+    """
+    When per-op selective ac is used, this list of fully qualified names (relative
+    to the module at which AC is applied) is used to determine which mm shapes to
+    force recompute, rather than being considered by the rest of the SAC policy,
+    e.g. save every other mm. Only nn.Linear modules are supported today.
+
+    Note: this config is not limited to mms from modules matching the specified
+    fqns; e.g. if "moe.router.gate", corresponding to Linear(in, out), is specified,
+    ANY mm with shape matching (*, in) x (in, out) will be force recomputed.
+    """
+
 
 @dataclass
 class Float8:
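
For reference, a minimal sketch of how the new field is consumed end-to-end (it mirrors the unit test above and reuses its TestModule; any module that exposes a `.layers` container works, since apply_ac rewrites those blocks in place):

from torchtitan.config_manager import ActivationCheckpoint
from torchtitan.models.llama3.infra.parallelize import apply_ac

# Force-recompute any mm whose rhs shape matches the (512, 512) weight of the
# module(s) whose FQN contains "moe.router.gate", instead of letting the
# "save every other mm" policy decide for them.
ac_config = ActivationCheckpoint(
    mode="selective",
    selective_ac_option="op",
    per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
)

model = TestModule().cuda()  # from the unit test above
apply_ac(model, ac_config)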

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 30 additions & 4 deletions
@@ -27,7 +27,7 @@
     SequenceParallel,
 )
 
-from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
+from torchtitan.config_manager import ActivationCheckpoint, JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 from torchtitan.tools.logging import logger
 
@@ -237,7 +237,9 @@ def apply_tp(
     }
 
 
-def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
+def _apply_ac_to_transformer_block(
+    module: nn.Module, ac_config: ActivationCheckpoint, base_fqn: str
+):
     valid_ac_modes = ("full", "selective")
     if ac_config.mode not in valid_ac_modes:
         raise ValueError(
@@ -261,11 +263,33 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
             create_selective_checkpoint_contexts,
         )
 
+        mm_recompute_shapes = set()
+        if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
+            for module_fqn, submod in module.named_modules():
+                fqn = f"{base_fqn}.{module_fqn}"
+                if not any(
+                    filter_fqn in fqn
+                    for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
+                ):
+                    continue
+                if not isinstance(submod, nn.Linear):
+                    raise ValueError(
+                        "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match "
+                        f"a nn.Linear, but got: {submod}"
+                    )
+                out_f, in_f = submod.weight.shape
+                mm_recompute_shapes.add((in_f, out_f))
+            logger.debug(
+                f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
+            )
+
         def _get_custom_policy(meta):
             def _custom_policy(ctx, func, *args, **kwargs):
                 mode = "recompute" if ctx.is_recompute else "forward"
                 mm_count_key = f"{mode}_mm_count"
                 if func == torch.ops.aten.mm.default:
+                    if args[1].shape in mm_recompute_shapes:
+                        return CheckpointPolicy.PREFER_RECOMPUTE
                     meta[mm_count_key] += 1
                 # Saves output of all compute ops, except every second mm
                 to_save = func in _save_list and not (
@@ -299,10 +323,12 @@ def selective_checkpointing_context_fn():
     return module
 
 
-def apply_ac(model: nn.Module, ac_config):
+def apply_ac(model: nn.Module, ac_config: ActivationCheckpoint):
     """Apply activation checkpointing to the model."""
     for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
+        transformer_block = _apply_ac_to_transformer_block(
+            transformer_block, ac_config, f"layers.{layer_id}"
+        )
         model.layers.register_module(layer_id, transformer_block)
 
     logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
