Skip to content

Commit 3c4d97d

Browse files
committed
Add option for selective op AC to filter mm shapes based on fqn
1 parent 01f4e50 commit 3c4d97d

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

torchtitan/config_manager.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,20 @@ class ActivationCheckpoint:
487487
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
488488
"""
489489

490+
selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field(
491+
default_factory=lambda: []
492+
)
493+
"""
494+
When per-op selective ac is used, this list of fully qualified names (relative
495+
to the module at which AC is applied) is used to determine which mm shapes to
496+
force recompute, rather than being considered by the rest of the SAC policy, e.g.
497+
save every other mm. Only nn.Linear modules are supported today.
498+
499+
Note: this config applies to mms not limited to those matching the specified
500+
fqns, e.g. if "moe.router.gate", corresponding to Linear(in, out), is specified,
501+
ANY mm with shape matching (*, in) x (in, out) will be force recomputed.
502+
"""
503+
490504

491505
@dataclass
492506
class Float8:

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,32 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
261261
create_selective_checkpoint_contexts,
262262
)
263263

264+
mm_recompute_shapes = set()
265+
if len(ac_config.selective_op_ac_force_recompute_mm_shapes_by_fqns) > 0:
266+
for fqn, submod in module.named_modules():
267+
if (
268+
fqn
269+
not in ac_config.selective_op_ac_force_recompute_mm_shapes_by_fqns
270+
):
271+
continue
272+
if not isinstance(submod, nn.Linear):
273+
raise ValueError(
274+
"selective_op_ac_force_recompute_mm_shapes_by_fqns expected to match "
275+
f"a nn.Linear, but got: {submod}"
276+
)
277+
out_f, in_f = submod.weight.shape
278+
mm_recompute_shapes.add((in_f, out_f))
279+
logger.debug(
280+
f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
281+
)
282+
264283
def _get_custom_policy(meta):
265284
def _custom_policy(ctx, func, *args, **kwargs):
266285
mode = "recompute" if ctx.is_recompute else "forward"
267286
mm_count_key = f"{mode}_mm_count"
268287
if func == torch.ops.aten.mm.default:
288+
if args[1].shape in mm_recompute_shapes:
289+
return CheckpointPolicy.PREFER_RECOMPUTE
269290
meta[mm_count_key] += 1
270291
# Saves output of all compute ops, except every second mm
271292
to_save = func in _save_list and not (

0 commit comments

Comments
 (0)