Skip to content

Commit 7325e54

Browse files
committed
Add option to exclude low flop mms from every-other-mm sac policy
ghstack-source-id: fabf2e0 Pull-Request-resolved: #1372
1 parent 01f4e50 commit 7325e54

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

torchtitan/config_manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,12 @@ class ActivationCheckpoint:
486486
Selective activation checkpointing options ['int', 'op'].
487487
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
488488
"""
489+
selective_op_ac_mm_flops_threshold: int = 0
490+
"""
491+
When selective_ac_option is 'op', this threshold is used to determine whether to
492+
save a given mm, e.g. 1e5 means excluding mms with flops < 1e5, and then saving
493+
every other mm from the remaining mms.
494+
"""
489495

490496

491497
@dataclass

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,15 @@ def _get_custom_policy(meta):
265265
def _custom_policy(ctx, func, *args, **kwargs):
266266
mode = "recompute" if ctx.is_recompute else "forward"
267267
mm_count_key = f"{mode}_mm_count"
268+
269+
if func == torch.ops.aten.mm.default:
270+
m, k = args[0].shape
271+
k2, n = args[1].shape
272+
assert k == k2
273+
flops = m * n * 2 * k
274+
if flops < ac_config.selective_op_ac_mm_flops_threshold:
275+
return CheckpointPolicy.PREFER_RECOMPUTE
276+
268277
if func == torch.ops.aten.mm.default:
269278
meta[mm_count_key] += 1
270279
# Saves output of all compute ops, except every second mm

0 commit comments

Comments
 (0)