Skip to content

Commit 948a807

Browse files
committed
Add option to exclude low flop mms from every-other-mm sac policy
ghstack-source-id: 9c99bff Pull-Request-resolved: #1372
1 parent 01f4e50 commit 948a807

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

torchtitan/config_manager.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,15 @@ class ActivationCheckpoint:
486486
Selective activation checkpointing options ['int', 'op'].
487487
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
488488
"""
489+
selective_op_ac_mm_flops_threshold: int = 0
490+
"""
491+
When selective_ac_option is 'op', this threshold is used to determine whether to
492+
save a given mm.
493+
494+
For example:
495+
- 0 means no threshold; every other mm is saved
496+
- 1e5 means every other mm is saved, excluding mms with flops < 1e5 (those are marked for recompute).
497+
"""
489498

490499

491500
@dataclass

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,15 @@ def _get_custom_policy(meta):
265265
def _custom_policy(ctx, func, *args, **kwargs):
266266
mode = "recompute" if ctx.is_recompute else "forward"
267267
mm_count_key = f"{mode}_mm_count"
268+
269+
if func == torch.ops.aten.mm.default:
270+
m, k = args[0].shape
271+
k2, n = args[1].shape
272+
assert k == k2
273+
flops = m * n * 2 * k
274+
if flops < ac_config.selective_op_ac_mm_flops_threshold:
275+
return CheckpointPolicy.PREFER_RECOMPUTE
276+
268277
if func == torch.ops.aten.mm.default:
269278
meta[mm_count_key] += 1
270279
# Saves output of all compute ops, except every second mm

0 commit comments

Comments
 (0)