
Commit 2c8024a (parent 01f4e50)

Add option to exclude low flop mms from every-other-mm sac policy

ghstack-source-id: c233523
Pull-Request-resolved: #1372

File tree: 3 files changed, +78 -4 lines changed

torchtitan/config_manager.py
torchtitan/experiments/llama4/train_configs/debug_model.toml
torchtitan/models/llama3/infra/parallelize.py

torchtitan/config_manager.py

Lines changed: 12 additions & 0 deletions
@@ -487,6 +487,18 @@ class ActivationCheckpoint:
     'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
     """

+    selective_op_ac_mm_flops_threshold: int = 0
+    """
+    When selective_ac_option is 'op', this threshold is used to determine whether to
+    save a given mm, e.g. 1e5 means excluding mms with flops < 1e5, and then saving
+    every other mm from the remaining mms.
+    """
+
+    log_mm_flops: bool = False
+    """
+    Whether to log the distribution of mm flops. This can be useful for determining
+    the appropriate threshold for selective_op_ac_mm_flops_threshold.
+    """

 @dataclass
 class Float8:
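For intuition on threshold magnitudes (the shapes here are purely illustrative, not taken from this commit): a (4096x4096) x (4096x4096) mm costs 2 * 4096 * 4096 * 4096 ≈ 1.4e11 flops, while an (8x64) x (64x8) mm costs 2 * 8 * 64 * 8 ≈ 8.2e3 flops, so a threshold of 1e5 would exclude only very small mms from the every-other-mm alternation. See also the standalone sketch after the parallelize.py diff below.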

torchtitan/experiments/llama4/train_configs/debug_model.toml

Lines changed: 4 additions & 2 deletions
@@ -63,8 +63,10 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

 [activation_checkpoint]
-mode = "none" # ["none", "selective", "full"]
-selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy
+mode = "selective" # ["none", "selective", "full"]
+selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
+selective_op_ac_mm_flops_threshold = 0 # checking if everything is recomputed
+log_mm_flops = true

 [float8]
 enable_fsdp_float8_all_gather = false

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 62 additions & 2 deletions
@@ -8,6 +8,8 @@
 # training techniques (e.g. activation checkpointing and compile) to the Llama model.

 from collections import defaultdict
+from decimal import Decimal
+import functools

 import torch
 import torch.nn as nn
@@ -237,6 +239,28 @@ def apply_tp(
 }


+def _format_mm_flops_table(entries):
+    header = ("MM Shape", "FLOPs", "Count")
+    rows = [header] + [(k[0], k[1], str(v)) for k, v in entries.items()]
+    col0 = max(len(r[0]) for r in rows)
+    col1 = max(len(r[1]) for r in rows)
+    col2 = max(len(r[2]) for r in rows)
+    lines = [
+        f"| {'MM Shape'.ljust(col0)} | {'FLOPs'.ljust(col1)} | {'Count'.ljust(col2)} |",
+        f"| {'-' * col0} | {'-' * col1} | {'-' * col2} |",
+    ]
+    for s, fl, cnt in rows[1:]:
+        lines.append(f"| {s.ljust(col0)} | {fl.ljust(col1)} | {cnt.ljust(col2)} |")
+    return "\n".join(lines)
+
+
+def _wrap_with_disable_early_stop(fn):
+    def inner(*args, **kwargs):
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            return fn(*args, **kwargs)
+    return inner
+
+
 def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
     valid_ac_modes = ("full", "selective")
     if ac_config.mode not in valid_ac_modes:
@@ -264,12 +288,38 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
         def _get_custom_policy(meta):
             def _custom_policy(ctx, func, *args, **kwargs):
                 mode = "recompute" if ctx.is_recompute else "forward"
+
+                mm_count_key_filtered = f"{mode}_mm_count_filtered"
                 mm_count_key = f"{mode}_mm_count"
+
                 if func == torch.ops.aten.mm.default:
                     meta[mm_count_key] += 1
+
+                    m, k = args[0].shape
+                    k2, n = args[1].shape
+                    assert k == k2
+                    flops = m * n * 2 * k
+
+                    if ac_config.log_mm_flops and ctx.is_recompute:
+                        shape_str = f"({m}x{k}) x ({k2}x{n})"
+                        flops_str = f"{Decimal(flops):.2E}"
+                        key = (shape_str, flops_str)
+                        meta[key] += 1
+                        if meta["recompute_mm_count"] == meta["forward_mm_count"]:
+                            table = _format_mm_flops_table({k:v for k,v in meta.items() if "mm_count" not in k})
+                            logger.info("\n%s", table)
+
+                    # Filter out ops below a certain flop threshold. See discussion for why we
+                    # recompute instead of save here:
+                    # https://github.com/pytorch/torchtitan/pull/1372#discussion_r2193722200
+                    if flops < ac_config.selective_op_ac_mm_flops_threshold:
+                        return CheckpointPolicy.PREFER_RECOMPUTE
+
+                    meta[mm_count_key_filtered] += 1
+
                 # Saves output of all compute ops, except every second mm
                 to_save = func in _save_list and not (
-                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+                    func == torch.ops.aten.mm.default and meta[mm_count_key_filtered] % 2 == 0
                 )
                 return (
                     CheckpointPolicy.MUST_SAVE
@@ -283,9 +333,17 @@ def selective_checkpointing_context_fn():
             meta = defaultdict(int)
             return create_selective_checkpoint_contexts(_get_custom_policy(meta))

+        checkpoint_fn = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
+        if ac_config.log_mm_flops:
+            # If early-stop is enabled, fewer mms are recomputed than in forward. Disabling
+            # this will slightly alter perf, but allows us to deterministically know when to
+            # log the mm flops table rather than having to spam it for every mm call.
+            checkpoint_fn = _wrap_with_disable_early_stop(checkpoint_fn)
+
         return ptd_checkpoint_wrapper(
             module,
             context_fn=selective_checkpointing_context_fn,
+            checkpoint_fn=checkpoint_fn,
             preserve_rng_state=False,
         )
     elif use_layer_sac:
@@ -305,7 +363,9 @@ def apply_ac(model: nn.Module, ac_config):
         transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
         model.layers.register_module(layer_id, transformer_block)

-    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
+    logger.info(f"Applied {ac_config.mode} checkpointing to the model")
+    if ac_config.selective_ac_option == "op" and ac_config.log_mm_flops:
+        logger.info("Logging enabled for mm flops.")


 def apply_compile(model: nn.Module):
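For intuition about how the threshold interacts with the every-other-mm alternation, here is a minimal standalone sketch, not part of this commit; the helper names, shapes, and the 1e5 threshold are illustrative only. It uses the same flop formula as _custom_policy above and replays the save/recompute decision for a few mm shapes:

# Standalone sketch (hypothetical helpers): replay the threshold + every-other-mm
# decision for some illustrative mm shapes.

def mm_flops(m: int, k: int, n: int) -> int:
    # Same count as _custom_policy above: an (m x k) @ (k x n) mm does m * n * 2 * k flops.
    return m * n * 2 * k


def policy_decisions(shapes, threshold):
    decisions = []
    filtered_count = 0  # counterpart of meta["<mode>_mm_count_filtered"]
    for m, k, n in shapes:
        flops = mm_flops(m, k, n)
        if flops < threshold:
            # Below the threshold: always recomputed, never counted toward the alternation.
            decisions.append((m, k, n, flops, "recompute (below threshold)"))
            continue
        filtered_count += 1
        if filtered_count % 2 == 0:
            # Every second remaining mm is recomputed, mirroring mm_count_key_filtered % 2 == 0.
            decisions.append((m, k, n, flops, "recompute (every second mm)"))
        else:
            decisions.append((m, k, n, flops, "save"))
    return decisions


if __name__ == "__main__":
    # Illustrative shapes and threshold, not taken from any real model.
    shapes = [(8, 8, 8), (2048, 4096, 1024), (2048, 1024, 4096), (4096, 4096, 4096)]
    for m, k, n, flops, decision in policy_decisions(shapes, threshold=1e5):
        print(f"({m}x{k}) x ({k}x{n}): {flops:.2e} flops -> {decision}")

Running this prints that the (8x8) x (8x8) mm (~1.0e3 flops) is always recomputed, while the three larger mms alternate save, recompute, save.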
