Skip to content

Add option to exclude low flop mms from every-other-mm sac policy #1372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: gh/soulitzer/2/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,18 @@ class ActivationCheckpoint:
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
"""

selective_op_ac_mm_flops_threshold: int = 0
"""
When selective_ac_option is 'op', this threshold is used to determine whether to
save a given mm, e.g. 1e5 means excluding mms flops < 1e5, and then saving
every other mm from the remaining mms.
"""

log_mm_flops: bool = False
"""
Whether to log the distribution of mm flops. This can be useful for determining
the appropriate threshold for selective_op_ac_mm_flops_threshold.
"""

@dataclass
class Float8:
Expand Down
6 changes: 4 additions & 2 deletions torchtitan/experiments/llama4/train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "none" # ["none", "selective", "full"]
selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
selective_op_ac_mm_flops_threshold = 0 # checking if everything is recomputed
log_mm_flops = true
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, didn't mean to include this


[float8]
enable_fsdp_float8_all_gather = false
Expand Down
64 changes: 62 additions & 2 deletions torchtitan/models/llama3/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict
from decimal import Decimal
import functools

import torch
import torch.nn as nn
Expand Down Expand Up @@ -237,6 +239,28 @@ def apply_tp(
}


def _format_mm_flops_table(entries):
header = ("MM Shape", "FLOPs", "Count")
rows = [header] + [(k[0], k[1], str(v)) for k, v in entries.items()]
col0 = max(len(r[0]) for r in rows)
col1 = max(len(r[1]) for r in rows)
col2 = max(len(r[2]) for r in rows)
lines = [
f"| {'MM Shape'.ljust(col0)} | {'FLOPs'.ljust(col1)} | {'Count'.ljust(col2)} |",
f"| {'-' * col0} | {'-' * col1} | {'-' * col2} |",
]
for s, fl, cnt in rows[1:]:
lines.append(f"| {s.ljust(col0)} | {fl.ljust(col1)} | {cnt.ljust(col2)} |")
return "\n".join(lines)


def _wrap_with_disable_early_stop(fn):
def inner(*args, **kwargs):
with torch.utils.checkpoint.set_checkpoint_early_stop(False):
return fn(*args, **kwargs)
return inner


def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
valid_ac_modes = ("full", "selective")
if ac_config.mode not in valid_ac_modes:
Expand Down Expand Up @@ -264,12 +288,38 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
def _get_custom_policy(meta):
def _custom_policy(ctx, func, *args, **kwargs):
mode = "recompute" if ctx.is_recompute else "forward"

mm_count_key_filtered = f"{mode}_mm_count_filtered"
mm_count_key = f"{mode}_mm_count"

if func == torch.ops.aten.mm.default:
meta[mm_count_key] += 1

m, k = args[0].shape
k2, n = args[1].shape
assert k == k2
flops = m * n * 2 * k

if ac_config.log_mm_flops and ctx.is_recompute:
shape_str = f"({m}x{k}) x ({k2}x{n})"
flops_str = f"{Decimal(flops):.2E}"
key = (shape_str, flops_str)
meta[key] += 1
if meta["recompute_mm_count"] == meta["forward_mm_count"]:
table = _format_mm_flops_table({k:v for k,v in meta.items() if "mm_count" not in k})
logger.info("\n%s", table)

# Filter out ops below are certain flop threshold. See discussion for why we
# recompute instead of save here:
# https://github.com/pytorch/torchtitan/pull/1372#discussion_r2193722200
if flops < ac_config.selective_op_ac_mm_flops_threshold:
return CheckpointPolicy.PREFER_RECOMPUTE
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a strong reason we use "always recompute" instead of "always save" for these small matmuls?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I'm guessing because matmul FLOPs required grows cubically / O(n^3), but the output activation grows only quadratically / O(n^2), so when n is small the FLOPS required to recompute are relatively small compared to the size of the output activation. In contrast, when n is larger, the FLOPs to recompute has grown cubically but memory saved has only grown quadratically, so the trade-off for saving instead of recomputing becomes more favorable.

If so, I think this change is a good idea, as long as it is configurable and has a sensible default that is documented clearly.

Copy link
Author

@soulitzer soulitzer Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a strong reason we use "always recompute" instead of "always save" for these small matmuls?

Good question, relative to non-matmuls, perhaps small matmuls still do a disproportionate amount of compute relative to output size, so there could definitely be benefit to saving as well.

Although as Daniel mentions, relative to large matmuls, small matmuls indeed have a more favorable memory/compute trade off for recomputing, so intuitively it might be more pareto optimal to recompute, e.g.

If we were to follow a greedy heuristic of only saving the op with the next best trade-off, since we're currently only saving every other large matmul, the most profitable next tensor to save should be another large matmul, rather spending that same amount of memory saving a bunch of smaller ones.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

intuitively it might be more pareto optimal to recompute, e.g.
If we were to follow a greedy heuristic of only saving the op with the next best trade-off, since we're currently only saving every other large matmul, the most profitable next tensor to save should be another large matmul

oh this is great argument, I got the idea.


meta[mm_count_key_filtered] += 1

# Saves output of all compute ops, except every second mm
to_save = func in _save_list and not (
func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
func == torch.ops.aten.mm.default and meta[mm_count_key_filtered] % 2 == 0
)
return (
CheckpointPolicy.MUST_SAVE
Expand All @@ -283,9 +333,17 @@ def selective_checkpointing_context_fn():
meta = defaultdict(int)
return create_selective_checkpoint_contexts(_get_custom_policy(meta))

checkpoint_fn = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
if ac_config.log_mm_flops:
# If early-stop is enabled, fewer mm are recomputed than in forward. Disabling
# this will slightly alter perf, but allows us to deterministically know when to
# log the mm flops table rather than having to spam it for every mm call.
checkpoint_fn = _wrap_with_disable_early_stop(checkpoint_fn)

return ptd_checkpoint_wrapper(
module,
context_fn=selective_checkpointing_context_fn,
checkpoint_fn=checkpoint_fn,
preserve_rng_state=False,
)
elif use_layer_sac:
Expand All @@ -305,7 +363,9 @@ def apply_ac(model: nn.Module, ac_config):
transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
model.layers.register_module(layer_id, transformer_block)

logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
logger.info(f"Applied {ac_config.mode} checkpointing to the model")
if ac_config.selective_ac_option == "op" and ac_config.log_mm_flops:
logger.info("Logging enabled for mm flops.")


def apply_compile(model: nn.Module):
Expand Down
Loading