Skip to content

Add option to exclude low flop mms from every-other-mm sac policy #1372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,12 @@ class ActivationCheckpoint:
Selective activation checkpointing options ['int', 'op'].
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
"""
selective_op_ac_mm_flops_threshold: int = 0
"""
When selective_ac_option is 'op', this threshold is used to determine whether to
save a given mm, e.g. 1e5 means excluding mms flops < 1e5, and then saving
every other mm from the remaining mms.
"""


@dataclass
Expand Down
9 changes: 9 additions & 0 deletions torchtitan/models/llama3/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,15 @@ def _get_custom_policy(meta):
def _custom_policy(ctx, func, *args, **kwargs):
mode = "recompute" if ctx.is_recompute else "forward"
mm_count_key = f"{mode}_mm_count"

if func == torch.ops.aten.mm.default:
m, k = args[0].shape
k2, n = args[1].shape
assert k == k2
flops = m * n * 2 * k
if flops < ac_config.selective_op_ac_mm_flops_threshold:
return CheckpointPolicy.PREFER_RECOMPUTE
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a strong reason we use "always recompute" instead of "always save" for these small matmuls?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I'm guessing because matmul FLOPs required grows cubically / O(n^3), but the output activation grows only quadratically / O(n^2), so when n is small the FLOPS required to recompute are relatively small compared to the size of the output activation. In contrast, when n is larger, the FLOPs to recompute has grown cubically but memory saved has only grown quadratically, so the trade-off for saving instead of recomputing becomes more favorable.

If so, I think this change is a good idea, as long as it is configurable and has a sensible default that is documented clearly.

Copy link
Contributor Author

@soulitzer soulitzer Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a strong reason we use "always recompute" instead of "always save" for these small matmuls?

Good question, relative to non-matmuls, perhaps small matmuls still do a disproportionate amount of compute relative to output size, so there could definitely be benefit to saving as well.

Although as Daniel mentions, relative to large matmuls, small matmuls indeed have a more favorable memory/compute trade off for recomputing, so intuitively it might be more pareto optimal to recompute, e.g.

If we were to follow a greedy heuristic of only saving the op with the next best trade-off, since we're currently only saving every other large matmul, the most profitable next tensor to save should be another large matmul, rather spending that same amount of memory saving a bunch of smaller ones.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

intuitively it might be more pareto optimal to recompute, e.g.
If we were to follow a greedy heuristic of only saving the op with the next best trade-off, since we're currently only saving every other large matmul, the most profitable next tensor to save should be another large matmul

oh this is great argument, I got the idea.


if func == torch.ops.aten.mm.default:
meta[mm_count_key] += 1
# Saves output of all compute ops, except every second mm
Expand Down