Commit b27dae7

Add option for selective op AC to filter mm shapes based on fqn

1 parent 01f4e50

3 files changed, +304 -4 lines changed
Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn
from torch.utils.flop_counter import FlopCounterMode

from torchtitan.config_manager import ActivationCheckpoint as ACConfig
from torchtitan.models.llama3.infra.parallelize import apply_ac


class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({"0": TransformerBlock()})

    def forward(self, x):
        return self.layers["0"](x)


class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = nn.Module()
        self.moe.router = nn.Module()
        self.moe.router.gate = nn.Linear(512, 512, bias=False)
        self.attention = nn.Module()
        self.attention.wq = nn.Linear(512, 512, bias=False)
        self.output = nn.Linear(512, 1024, bias=False)

    def forward(self, x):
        gate_out = self.moe.router.gate(x)
        wq_out = self.attention.wq(gate_out)
        final_out = self.output(wq_out)
        return final_out.sum()


class TestApplyAC(unittest.TestCase):
    def test_flops(self):
        def get_bw_flops(model_fn):
            x = torch.randn(512, 512, requires_grad=True)
            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
                out = model_fn(x)
                out.backward()

            x = torch.randn(512, 512, requires_grad=True)
            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
                out = model_fn(x)
                with FlopCounterMode(display=False) as mode:
                    out.backward()
            return mode.get_total_flops() / (512**3 * 2)

        # 1. No AC
        model_no_ac = TestModule()
        flops_no_ac = get_bw_flops(model_no_ac)

        # 2. SAC
        # Per-op SAC's policy is to save every other mm
        model_selective_ac = TestModule()
        ac_config_no_force = ACConfig(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=[],  # Empty list
        )
        apply_ac(model_selective_ac, ac_config_no_force)
        flops_selective_ac = get_bw_flops(model_selective_ac)

        # 3. Per-op SAC with force recompute "moe.router.gate"
        # This leads to two mms being recomputed since they share the same shape!
        model_with_force_first = TestModule()
        ac_config_with_force_first = ACConfig(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
        )
        apply_ac(model_with_force_first, ac_config_with_force_first)
        flops_with_force_first = get_bw_flops(model_with_force_first)

        # 4. Per-op SAC with force recompute "output"
        model_with_force_last = TestModule()
        ac_config_with_force_last = ACConfig(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
        )
        apply_ac(model_with_force_last, ac_config_with_force_last)
        flops_with_force_last = get_bw_flops(model_with_force_last)

        # 5. Full AC
        model_with_full_ac = TestModule()
        ac_config_full_ac = ACConfig(
            mode="full",
        )
        apply_ac(model_with_full_ac, ac_config_full_ac)
        flops_full_ac = get_bw_flops(model_with_full_ac)

        self.assertEqual(flops_no_ac, 8.0)
        self.assertEqual(flops_selective_ac, 9.0)
        self.assertEqual(flops_with_force_first, 10.0)
        self.assertEqual(flops_with_force_last, 11.0)
        self.assertEqual(flops_full_ac, 12.0)

    def test_mem(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is unavailable")

        def get_act_mem(model_fn):
            x = torch.randn(512, 512, requires_grad=True, device="cuda")
            out = model_fn(x)
            out.backward()
            start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]

            out = model_fn(x)
            cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
            act_mem = (cur_mem - start_mem) / (1024 * 1024)  # MB
            out.backward()
            return act_mem

        # 1. No AC
        model_no_ac = TestModule().cuda()
        mem_no_ac = get_act_mem(model_no_ac)

        # 2. SAC
        # Per-op SAC's policy is to save every other mm
        model_selective_ac = TestModule().cuda()
        ac_config_no_force = ACConfig(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=[],  # Empty list
        )
        apply_ac(model_selective_ac, ac_config_no_force)
        mem_selective_ac = get_act_mem(model_selective_ac)

        # 3. Per-op SAC with force recompute "moe.router.gate"
        # This leads to two mms being recomputed since they share the same shape!
        model_with_force_first = TestModule().cuda()
        ac_config_with_force_first = ACConfig(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
        )
        apply_ac(model_with_force_first, ac_config_with_force_first)
        mem_with_force_first = get_act_mem(model_with_force_first)

        # 4. Per-op SAC with force recompute "output"
        model_with_force_last = TestModule().cuda()
        ac_config_with_force_last = ACConfig(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
        )
        apply_ac(model_with_force_last, ac_config_with_force_last)
        mem_with_force_last = get_act_mem(model_with_force_last)

        # 5. Full AC
        model_with_full_ac = TestModule().cuda()
        ac_config_full_ac = ACConfig(
            mode="full",
        )
        apply_ac(model_with_full_ac, ac_config_full_ac)
        mem_full_ac = get_act_mem(model_with_full_ac)

        self.assertEqual(mem_no_ac, 2.0)
        self.assertEqual(mem_selective_ac, 3.0)
        self.assertEqual(mem_with_force_first, 2.0)
        self.assertEqual(mem_with_force_last, 1.0)
        self.assertEqual(mem_full_ac, 0.0)
        # Note: SAC > no-AC here because it unnecessarily saves "output"
        # even though it is not needed for recomputation, and "output" is double
        # the size of the other two mms.

    def test_correctness(self):
        model_no_ac = TestModule()

        model_selective_ac = TestModule()
        model_selective_ac.load_state_dict(model_no_ac.state_dict())
        apply_ac(
            model_selective_ac,
            ACConfig(
                mode="selective",
                selective_ac_option="op",
                per_op_sac_force_recompute_mm_shapes_by_fqns=[],
            ),
        )
        model_force_first = TestModule()
        model_force_first.load_state_dict(model_no_ac.state_dict())
        apply_ac(
            model_force_first,
            ACConfig(
                mode="selective",
                selective_ac_option="op",
                per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
            ),
        )

        model_force_last = TestModule()
        model_force_last.load_state_dict(model_no_ac.state_dict())
        apply_ac(
            model_force_last,
            ACConfig(
                mode="selective",
                selective_ac_option="op",
                per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
            ),
        )

        def run_fwd_bwd(model, batch):
            model.zero_grad(set_to_none=True)
            xin = batch.clone().detach().requires_grad_(True)
            out = model(xin)  # scalar
            out.backward()

            grad_in = xin.grad.detach().clone()
            grad_params = [
                p.grad.detach().clone() if isinstance(p.grad, torch.Tensor) else None
                for p in model.parameters()
            ]
            return out.detach(), grad_in, grad_params

        batch = torch.randn(64, 512)

        out_ref, gin_ref, gparams_ref = run_fwd_bwd(model_no_ac, batch)
        out_sel, gin_sel, gparams_sel = run_fwd_bwd(model_selective_ac, batch)
        out_f1, gin_f1, gparams_f1 = run_fwd_bwd(model_force_first, batch)
        out_fl, gin_fl, gparams_fl = run_fwd_bwd(model_force_last, batch)

        for other_out in (out_sel, out_f1, out_fl):
            torch.testing.assert_close(out_ref, other_out)

        for other_gin in (gin_sel, gin_f1, gin_fl):
            torch.testing.assert_close(gin_ref, other_gin)

        for g_ref, g_sel, g_f1, g_fl in zip(
            gparams_ref, gparams_sel, gparams_f1, gparams_fl
        ):
            # Skip wrapper / missing grads
            if not (
                torch.is_tensor(g_ref)
                and torch.is_tensor(g_sel)
                and torch.is_tensor(g_f1)
                and torch.is_tensor(g_fl)
            ):
                continue

            torch.testing.assert_close(g_ref, g_sel)
            torch.testing.assert_close(g_ref, g_f1)
            torch.testing.assert_close(g_ref, g_fl)


if __name__ == "__main__":
    unittest.main()
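
The FLOP assertions in `test_flops` can be sanity-checked by hand. A rough sketch of the arithmetic behind those numbers (an editorial aid, not part of the committed test), normalizing by `512**3 * 2` so a 512x512 @ 512x512 mm costs 1 unit and the 512x512 @ 512x1024 output mm costs 2:

```python
# Per-mm forward cost in normalized units; each mm's backward is two mms of the same cost.
fwd = {"moe.router.gate": 1, "attention.wq": 1, "output": 2}

bwd_no_ac = 2 * sum(fwd.values())                                   # 8.0
bwd_sac = bwd_no_ac + fwd["attention.wq"]                           # 9.0  (default policy recomputes every 2nd mm)
bwd_force_gate = (
    bwd_no_ac + fwd["moe.router.gate"] + fwd["attention.wq"]
)                                                                   # 10.0 (wq shares the gate's (512, 512) shape)
bwd_force_output = bwd_no_ac + fwd["attention.wq"] + fwd["output"]  # 11.0
bwd_full_ac = bwd_no_ac + sum(fwd.values())                         # 12.0 (full AC recomputes the whole forward)
```

The memory assertions follow the same bookkeeping in MB of saved fp32 activations (a 512x512 output is 1 MB, the 512x1024 output is 2 MB), which is also why plain per-op SAC lands above no-AC here, as the trailing note in `test_mem` explains.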

torchtitan/config_manager.py

Lines changed: 14 additions & 0 deletions
@@ -487,6 +487,20 @@ class ActivationCheckpoint:
         'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
     """
 
+    per_op_sac_force_recompute_mm_shapes_by_fqns: list[str] = field(
+        default_factory=lambda: ["moe.router.gate"]
+    )
+    """
+    When per-op selective ac is used, this list of fully qualified names (relative
+    to the module at which AC is applied) is used to determine which mm shapes to
+    force recompute, rather than being considered by the rest of the sac policy,
+    e.g. save every other mm. Only nn.Linear modules are supported today.
+
+    Note: this config applies to all mms, not only those belonging to the specified
+    fqns, e.g. if "moe.router.gate", corresponding to Linear(in, out), is specified,
+    ANY mm with shape matching (*, in) x (in, out) will be force recomputed.
+    """
+
 
 @dataclass
 class Float8:
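
The new field is consumed by `apply_ac` in the parallelize module below. A minimal usage sketch, mirroring the unit test above; the `model` object here is assumed to expose a `layers` container the way the Llama3 model does:

```python
from torchtitan.config_manager import ActivationCheckpoint as ACConfig
from torchtitan.models.llama3.infra.parallelize import apply_ac

ac_config = ACConfig(
    mode="selective",
    selective_ac_option="op",
    # Force-recompute the mm of every nn.Linear whose FQN (relative to each
    # transformer block) contains "moe.router.gate" -- and, per the note above,
    # any other mm that happens to share the same rhs shape.
    per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
)
apply_ac(model, ac_config)  # `model` must expose `model.layers`
```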

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 34 additions & 4 deletions
@@ -27,7 +27,11 @@
     SequenceParallel,
 )
 
-from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
+from torchtitan.config_manager import (
+    ActivationCheckpoint as ACConfig,
+    JobConfig,
+    TORCH_DTYPE_MAP,
+)
 from torchtitan.distributed import ParallelDims
 from torchtitan.tools.logging import logger
 
@@ -237,7 +241,9 @@ def apply_tp(
     }
 
 
-def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
+def _apply_ac_to_transformer_block(
+    module: nn.Module, ac_config: ACConfig, base_fqn: str
+):
     valid_ac_modes = ("full", "selective")
     if ac_config.mode not in valid_ac_modes:
         raise ValueError(
@@ -261,11 +267,33 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
             create_selective_checkpoint_contexts,
         )
 
+        mm_recompute_shapes = set()
+        if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
+            for module_fqn, submod in module.named_modules():
+                fqn = f"{base_fqn}.{module_fqn}"
+                if not any(
+                    filter_fqn in fqn
+                    for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
+                ):
+                    continue
+                if not isinstance(submod, nn.Linear):
+                    raise ValueError(
+                        "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match "
+                        f"a nn.Linear, but got: {submod}"
+                    )
+                out_f, in_f = submod.weight.shape
+                mm_recompute_shapes.add((in_f, out_f))
+            logger.debug(
+                f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
+            )
+
         def _get_custom_policy(meta):
             def _custom_policy(ctx, func, *args, **kwargs):
                 mode = "recompute" if ctx.is_recompute else "forward"
                 mm_count_key = f"{mode}_mm_count"
                 if func == torch.ops.aten.mm.default:
+                    if args[1].shape in mm_recompute_shapes:
+                        return CheckpointPolicy.PREFER_RECOMPUTE
                     meta[mm_count_key] += 1
                 # Saves output of all compute ops, except every second mm
                 to_save = func in _save_list and not (
@@ -299,10 +327,12 @@ def selective_checkpointing_context_fn():
     return module
 
 
-def apply_ac(model: nn.Module, ac_config):
+def apply_ac(model: nn.Module, ac_config: ACConfig):
     """Apply activation checkpointing to the model."""
     for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
+        transformer_block = _apply_ac_to_transformer_block(
+            transformer_block, ac_config, f"layers.{layer_id}"
+        )
         model.layers.register_module(layer_id, transformer_block)
 
     logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
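
Note that the filter keys on mm shapes rather than on the matched modules themselves: each matching `nn.Linear` contributes its transposed weight shape `(in_features, out_features)` to `mm_recompute_shapes`, and the custom policy then prefers recomputation for any `aten.mm` whose right-hand operand has that shape. A small sketch of the shape derivation and its consequence (illustrative only):

```python
import torch.nn as nn

gate = nn.Linear(512, 512, bias=False)   # matched by the fqn filter "moe.router.gate"
out_f, in_f = gate.weight.shape          # nn.Linear stores weight as (out_features, in_features)
rhs_shape = (in_f, out_f)                # (512, 512): the shape of weight.t() as seen by aten.mm

# Consequence: attention.wq is also a Linear(512, 512), so its mm has the same rhs
# shape and is force-recomputed as well -- exactly what the unit test's
# flops_with_force_first == 10.0 assertion exercises.
```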
