From e3a906eb7e1e464f1775fd32c36e1520cae8f66c Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Tue, 8 Jul 2025 11:59:00 -0700
Subject: [PATCH 1/4] Update

[ghstack-poisoned]
---
 torchtitan/config_manager.py                  | 9 +++++++++
 torchtitan/models/llama3/infra/parallelize.py | 9 +++++++++
 2 files changed, 18 insertions(+)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index d40a5982f..2de89bcd7 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -486,6 +486,15 @@ class ActivationCheckpoint:
     Selective activation checkpointing options ['int', 'op'].
     'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
     """
+    selective_op_ac_mm_flops_threshold: int = 0
+    """
+    When selective_ac_option is 'op', this threshold is used to determine whether to
+    apply save a given mm.
+
+    For example:
+    - 0 means no threshold; every other mm is saved
+    - 1e5 means every other mm is saved, excluding mm with flops > 1e5.
+    """
 
 
 @dataclass
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index df395adcb..5fa3d946f 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -265,6 +265,15 @@ def _get_custom_policy(meta):
             def _custom_policy(ctx, func, *args, **kwargs):
                 mode = "recompute" if ctx.is_recompute else "forward"
                 mm_count_key = f"{mode}_mm_count"
+
+                if func == torch.ops.aten.mm.default:
+                    m, k = args[0].shape
+                    k2, n = args[1].shape
+                    assert k == k2
+                    flops = m * n * 2 * k
+                    if flops < ac_config.selective_op_ac_mm_flops_threshold:
+                        return CheckpointPolicy.PREFER_RECOMPUTE
+
                 if func == torch.ops.aten.mm.default:
                     meta[mm_count_key] += 1
                 # Saves output of all compute ops, except every second mm

From 9fa998a90ab3a9a007785d0cb6a0da22856a5363 Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Tue, 8 Jul 2025 12:07:13 -0700
Subject: [PATCH 2/4] Update

[ghstack-poisoned]
---
 torchtitan/config_manager.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 2de89bcd7..c6fcb1ecb 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -489,11 +489,8 @@ class ActivationCheckpoint:
     selective_op_ac_mm_flops_threshold: int = 0
     """
     When selective_ac_option is 'op', this threshold is used to determine whether to
-    apply save a given mm.
-
-    For example:
-    - 0 means no threshold; every other mm is saved
-    - 1e5 means every other mm is saved, excluding mm with flops > 1e5.
+    save a given mm, e.g. 1e5 means every other mm is saved, excluding mm with
+    flops < 1e5.
     """
 
 

From 69172fc00354e85b35e08cf9eb541978379cf323 Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Tue, 8 Jul 2025 12:13:15 -0700
Subject: [PATCH 3/4] Update

[ghstack-poisoned]
---
 torchtitan/config_manager.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index c6fcb1ecb..2af0c5b4a 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -489,8 +489,8 @@ class ActivationCheckpoint:
     selective_op_ac_mm_flops_threshold: int = 0
    """
     When selective_ac_option is 'op', this threshold is used to determine whether to
-    save a given mm, e.g. 1e5 means every other mm is saved, excluding mm with
-    flops < 1e5.
+    save a given mm, e.g. 1e5 means excluding mms with flops < 1e5, and then saving
+    every other mm from the remaining mms.
     """
 
 
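To make the final docstring above concrete: every mm whose flop count (2 * m * k * n for an (m, k) x (k, n) matmul) falls below the threshold is always recomputed, and only the surviving mms participate in the existing save-every-other-mm rule. The sketch below is illustrative only and not part of the patches; the helper name and shapes are made up.

# Illustrative sketch (not from the patches): threshold filtering composed with
# the save-every-other-mm rule described in the docstring above.
def classify_mms(mm_shapes, flops_threshold):
    """mm_shapes: iterable of (m, k, n). Returns a 'save'/'recompute' decision per mm."""
    decisions = []
    kept = 0  # counts only mms at or above the threshold
    for m, k, n in mm_shapes:
        flops = 2 * m * k * n  # same flop formula the policy applies to aten.mm
        if flops < flops_threshold:
            decisions.append("recompute")  # filtered out: always recomputed
            continue
        kept += 1
        # Of the remaining mms, every second one is recomputed, the rest are saved.
        decisions.append("recompute" if kept % 2 == 0 else "save")
    return decisions

# With threshold 0 nothing is filtered, so mms simply alternate save/recompute.
print(classify_mms([(8, 16, 8), (1024, 1024, 1024), (2048, 2048, 2048)], 1e9))
# -> ['recompute', 'save', 'recompute']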
""" From 29df79e5629103b636fc6ff34b763525496bc74c Mon Sep 17 00:00:00 2001 From: soulitzer Date: Thu, 10 Jul 2025 09:20:27 -0700 Subject: [PATCH 4/4] Update [ghstack-poisoned] --- torchtitan/config_manager.py | 6 ++ .../llama4/train_configs/debug_model.toml | 6 +- torchtitan/models/llama3/infra/parallelize.py | 59 +++++++++++++++++-- 3 files changed, 65 insertions(+), 6 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 2af0c5b4a..0b6983bcc 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -486,6 +486,7 @@ class ActivationCheckpoint: Selective activation checkpointing options ['int', 'op']. 'int' (e.g., 2) for every nth layer, or 'op' for op level ac. """ + selective_op_ac_mm_flops_threshold: int = 0 """ When selective_ac_option is 'op', this threshold is used to determine whether to @@ -493,6 +494,11 @@ class ActivationCheckpoint: every other mm from the remaining mms. """ + log_mm_flops: bool = False + """ + Whether to log the distribution of mm flops. This can be useful for determining + the appropriate threshold for selective_op_ac_mm_flops_threshold. + """ @dataclass class Float8: diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index 7fbe95e19..03ea33465 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -63,8 +63,10 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] -mode = "none" # ["none", "selective", "full"] -selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy +selective_op_ac_mm_flops_threshold = 0 # checking if everything is recomputed +log_mm_flops = true [float8] enable_fsdp_float8_all_gather = false diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 5fa3d946f..aabc6cdf9 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -8,6 +8,8 @@ # training techniques (e.g. activation checkpointing and compile) to the Llama model. 
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 5fa3d946f..aabc6cdf9 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -8,6 +8,8 @@
 # training techniques (e.g. activation checkpointing and compile) to the Llama model.
 
 from collections import defaultdict
+from decimal import Decimal
+import functools
 
 import torch
 import torch.nn as nn
@@ -237,6 +239,28 @@ def apply_tp(
 }
 
 
+def _format_mm_flops_table(entries):
+    header = ("MM Shape", "FLOPs", "Count")
+    rows = [header] + [(k[0], k[1], str(v)) for k, v in entries.items()]
+    col0 = max(len(r[0]) for r in rows)
+    col1 = max(len(r[1]) for r in rows)
+    col2 = max(len(r[2]) for r in rows)
+    lines = [
+        f"| {'MM Shape'.ljust(col0)} | {'FLOPs'.ljust(col1)} | {'Count'.ljust(col2)} |",
+        f"| {'-' * col0} | {'-' * col1} | {'-' * col2} |",
+    ]
+    for s, fl, cnt in rows[1:]:
+        lines.append(f"| {s.ljust(col0)} | {fl.ljust(col1)} | {cnt.ljust(col2)} |")
+    return "\n".join(lines)
+
+
+def _wrap_with_disable_early_stop(fn):
+    def inner(*args, **kwargs):
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            return fn(*args, **kwargs)
+    return inner
+
+
 def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
     valid_ac_modes = ("full", "selective")
     if ac_config.mode not in valid_ac_modes:
@@ -264,21 +288,38 @@ def _get_custom_policy(meta):
         def _get_custom_policy(meta):
             def _custom_policy(ctx, func, *args, **kwargs):
                 mode = "recompute" if ctx.is_recompute else "forward"
+
+                mm_count_key_filtered = f"{mode}_mm_count_filtered"
                 mm_count_key = f"{mode}_mm_count"
 
                 if func == torch.ops.aten.mm.default:
+                    meta[mm_count_key] += 1
+
                     m, k = args[0].shape
                     k2, n = args[1].shape
                     assert k == k2
                     flops = m * n * 2 * k
+
+                    if ac_config.log_mm_flops and ctx.is_recompute:
+                        shape_str = f"({m}x{k}) x ({k2}x{n})"
+                        flops_str = f"{Decimal(flops):.2E}"
+                        key = (shape_str, flops_str)
+                        meta[key] += 1
+                        if meta["recompute_mm_count"] == meta["forward_mm_count"]:
+                            table = _format_mm_flops_table({k:v for k,v in meta.items() if "mm_count" not in k})
+                            logger.info("\n%s", table)
+
+                    # Filter out mms below a certain flop threshold. See discussion for why we
+                    # recompute instead of save here:
+                    # https://github.com/pytorch/torchtitan/pull/1372#discussion_r2193722200
                     if flops < ac_config.selective_op_ac_mm_flops_threshold:
                         return CheckpointPolicy.PREFER_RECOMPUTE
 
-                if func == torch.ops.aten.mm.default:
-                    meta[mm_count_key] += 1
+                    meta[mm_count_key_filtered] += 1
+
                 # Saves output of all compute ops, except every second mm
                 to_save = func in _save_list and not (
-                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+                    func == torch.ops.aten.mm.default and meta[mm_count_key_filtered] % 2 == 0
                 )
                 return (
                     CheckpointPolicy.MUST_SAVE
@@ -292,9 +333,17 @@ def selective_checkpointing_context_fn():
             meta = defaultdict(int)
             return create_selective_checkpoint_contexts(_get_custom_policy(meta))
 
+        checkpoint_fn = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
+        if ac_config.log_mm_flops:
+            # If early-stop is enabled, fewer mms are recomputed than in forward. Disabling
+            # this will slightly alter perf, but allows us to deterministically know when to
+            # log the mm flops table rather than having to spam it for every mm call.
+            checkpoint_fn = _wrap_with_disable_early_stop(checkpoint_fn)
+
         return ptd_checkpoint_wrapper(
             module,
             context_fn=selective_checkpointing_context_fn,
+            checkpoint_fn=checkpoint_fn,
             preserve_rng_state=False,
         )
     elif use_layer_sac:
@@ -314,7 +363,9 @@ def apply_ac(model: nn.Module, ac_config):
         transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
         model.layers.register_module(layer_id, transformer_block)
 
-    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
+    logger.info(f"Applied {ac_config.mode} checkpointing to the model")
+    if ac_config.selective_ac_option == "op" and ac_config.log_mm_flops:
+        logger.info("Logging enabled for mm flops.")
 
 
 def apply_compile(model: nn.Module):
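For readers unfamiliar with the underlying mechanism: the policy function above plugs into PyTorch's selective activation checkpointing API (create_selective_checkpoint_contexts and CheckpointPolicy in torch.utils.checkpoint, available in recent PyTorch releases), which torchtitan wires into each transformer block via ptd_checkpoint_wrapper. The sketch below is a self-contained, simplified illustration of a flops-thresholded mm policy on a toy function; the threshold value, shapes, and function names are hypothetical, and this is not the torchtitan implementation.

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

FLOPS_THRESHOLD = 1e9  # hypothetical cutoff; in practice, tune it from the logged table

def policy_fn(ctx, func, *args, **kwargs):
    # Simplified stand-in: save only the large mms, recompute everything else.
    if func == torch.ops.aten.mm.default:
        m, k = args[0].shape
        _, n = args[1].shape
        if 2 * m * k * n >= FLOPS_THRESHOLD:
            return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def context_fn():
    # Fresh selective-checkpoint contexts for each checkpointed region.
    return create_selective_checkpoint_contexts(policy_fn)

def block(x, w1, w2):
    return torch.relu(x @ w1) @ w2

x = torch.randn(1024, 1024, requires_grad=True)
w1 = torch.randn(1024, 1024, requires_grad=True)
w2 = torch.randn(1024, 1024, requires_grad=True)
out = checkpoint(block, x, w1, w2, use_reentrant=False, context_fn=context_fn)
out.sum().backward()

In the patch itself, the same mechanism runs per transformer block, with the threshold filter and the every-other-mm alternation combined in _custom_policy, and the flops table is emitted from the recompute pass once its mm count catches up with the forward pass.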