From f2dae4f5dde4b181ec0842866d5f52569ce4420f Mon Sep 17 00:00:00 2001 From: soulitzer Date: Fri, 13 Jun 2025 05:55:47 -0700 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- torchtitan/models/llama3/parallelize_llama.py | 73 ++++++++++++++++--- 1 file changed, 62 insertions(+), 11 deletions(-) diff --git a/torchtitan/models/llama3/parallelize_llama.py b/torchtitan/models/llama3/parallelize_llama.py index 64dfcad23..7e0cca5c3 100644 --- a/torchtitan/models/llama3/parallelize_llama.py +++ b/torchtitan/models/llama3/parallelize_llama.py @@ -12,10 +12,35 @@ import torch import torch.nn as nn from torch.distributed._composable.replicate import replicate + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper as ptd_checkpoint_wrapper, + ActivationWrapper, + # checkpoint_wrapper as ptd_checkpoint_wrapper, ) + +class CheckpointWrapper(ActivationWrapper): + def __init__(self, mod: torch.nn.Module, **kwargs): + super().__init__(mod) + self._checkpoint_wrapped_module = mod + self._policy_fn = kwargs.get("policy_fn", None) + + def forward(self, *args, **kwargs): + from ac_experimental import apply_ac_policy + + policy = self._policy_fn if self._policy_fn is not None else "recompute_all" + + with apply_ac_policy(policy_fn=policy): + return self._checkpoint_wrapped_module(*args, **kwargs) + # return apply_ac_policy_fn( + # self._checkpoint_wrapped_module, *args, **kwargs, policy_fn="recompute_all" + # ) + + +def ptd_checkpoint_wrapper(mod, **kwargs): + return CheckpointWrapper(mod, **kwargs) + + from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy from torch.distributed.tensor import Replicate, Shard @@ -251,9 +276,39 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config): create_selective_checkpoint_contexts, ) - def _get_custom_policy(meta): - def _custom_policy(ctx, func, *args, **kwargs): - mode = "recompute" if ctx.is_recompute else "forward" + # def _get_custom_policy(meta): + # def _custom_policy(ctx, func, *args, **kwargs): + # mode = "recompute" if ctx.is_recompute else "forward" + # mm_count_key = f"{mode}_mm_count" + # if func == torch.ops.aten.mm.default: + # meta[mm_count_key] += 1 + # # Saves output of all compute ops, except every second mm + # # to_save = func in _save_list and not ( + # # func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + # # ) + # return ( + # CheckpointPolicy.MUST_SAVE + # if func in _save_list + # else CheckpointPolicy.PREFER_RECOMPUTE + # ) + + # return _custom_policy + + # def selective_checkpointing_context_fn(): + # meta = defaultdict(int) + # return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + # return ptd_checkpoint_wrapper( + # module, + # context_fn=selective_checkpointing_context_fn, + # preserve_rng_state=False, + # ) + + def _get_custom_policy(): + meta = defaultdict(int) + + def _custom_policy(ctx, out, func, *args, **kwargs): + mode = "forward" # recompute" if ctx.is_recompute else "forward" mm_count_key = f"{mode}_mm_count" if func == torch.ops.aten.mm.default: meta[mm_count_key] += 1 @@ -263,21 +318,17 @@ def _custom_policy(ctx, func, *args, **kwargs): ) return ( CheckpointPolicy.MUST_SAVE - if to_save + if func in _save_list else CheckpointPolicy.PREFER_RECOMPUTE ) return _custom_policy - def selective_checkpointing_context_fn(): - meta = defaultdict(int) - return create_selective_checkpoint_contexts(_get_custom_policy(meta)) - return ptd_checkpoint_wrapper( module, - 
context_fn=selective_checkpointing_context_fn, - preserve_rng_state=False, + policy_fn=_get_custom_policy(), ) + elif use_layer_sac: # Checkpoint every `ac_freq` of the modules passed to this function ac_freq = int(ac_config.selective_ac_option) From 0d07755753531522a586626836ad4617182dc263 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Fri, 13 Jun 2025 09:07:25 -0700 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- torchtitan/models/llama3/parallelize_llama.py | 99 +++++++------------ 1 file changed, 35 insertions(+), 64 deletions(-) diff --git a/torchtitan/models/llama3/parallelize_llama.py b/torchtitan/models/llama3/parallelize_llama.py index 7e0cca5c3..28bb6d0ba 100644 --- a/torchtitan/models/llama3/parallelize_llama.py +++ b/torchtitan/models/llama3/parallelize_llama.py @@ -15,7 +15,6 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( ActivationWrapper, - # checkpoint_wrapper as ptd_checkpoint_wrapper, ) @@ -23,18 +22,20 @@ class CheckpointWrapper(ActivationWrapper): def __init__(self, mod: torch.nn.Module, **kwargs): super().__init__(mod) self._checkpoint_wrapped_module = mod - self._policy_fn = kwargs.get("policy_fn", None) + self._make_policy_fn = kwargs.get("make_policy_fn", None) def forward(self, *args, **kwargs): - from ac_experimental import apply_ac_policy + from ac_experimental import apply_ac_policy_fn - policy = self._policy_fn if self._policy_fn is not None else "recompute_all" - - with apply_ac_policy(policy_fn=policy): - return self._checkpoint_wrapped_module(*args, **kwargs) - # return apply_ac_policy_fn( - # self._checkpoint_wrapped_module, *args, **kwargs, policy_fn="recompute_all" - # ) + if self._make_policy_fn is None: + return apply_ac_policy_fn( + self._checkpoint_wrapped_module, *args, **kwargs, policy_fn="recompute_all" + ) + else: + # Pass is_factory=True so that a new instance of policy_fn is created per AC invocation + return apply_ac_policy_fn( + self._checkpoint_wrapped_module, *args, **kwargs, policy_fn=self._make_policy_fn, is_factory=True + ) def ptd_checkpoint_wrapper(mod, **kwargs): @@ -251,6 +252,29 @@ def apply_tp( torch.ops.aten.max.default, } +from torch.utils.checkpoint import CheckpointPolicy + +# If you want your policy to have state, pass a class. Make sure to +# create it in global scope to avoid new instances triggering recompiles. 
+class CustomPolicy: + def __init__(self): + super().__init__() + self.meta = dict() + + def __call__(self, ctx, out, func, *args, **kwargs): + mm_count_key = f"mm_count" + if func == torch.ops.aten.mm.default: + self.meta[mm_count_key] = self.meta.get(mm_count_key, 0) + 1 + + # Saves output of all compute ops, except every second mm + to_save = func in _save_list and not ( + func == torch.ops.aten.mm.default and self.meta[mm_count_key] % 2 == 0 + ) + return ( + CheckpointPolicy.MUST_SAVE + if to_save + else CheckpointPolicy.PREFER_RECOMPUTE + ) def _apply_ac_to_transformer_block(module: nn.Module, ac_config): valid_ac_modes = ("full", "selective") @@ -271,62 +295,9 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config): f"Valid options: 'op' or a positive int representing layer frequency" ) if use_op_sac: - from torch.utils.checkpoint import ( - CheckpointPolicy, - create_selective_checkpoint_contexts, - ) - - # def _get_custom_policy(meta): - # def _custom_policy(ctx, func, *args, **kwargs): - # mode = "recompute" if ctx.is_recompute else "forward" - # mm_count_key = f"{mode}_mm_count" - # if func == torch.ops.aten.mm.default: - # meta[mm_count_key] += 1 - # # Saves output of all compute ops, except every second mm - # # to_save = func in _save_list and not ( - # # func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 - # # ) - # return ( - # CheckpointPolicy.MUST_SAVE - # if func in _save_list - # else CheckpointPolicy.PREFER_RECOMPUTE - # ) - - # return _custom_policy - - # def selective_checkpointing_context_fn(): - # meta = defaultdict(int) - # return create_selective_checkpoint_contexts(_get_custom_policy(meta)) - - # return ptd_checkpoint_wrapper( - # module, - # context_fn=selective_checkpointing_context_fn, - # preserve_rng_state=False, - # ) - - def _get_custom_policy(): - meta = defaultdict(int) - - def _custom_policy(ctx, out, func, *args, **kwargs): - mode = "forward" # recompute" if ctx.is_recompute else "forward" - mm_count_key = f"{mode}_mm_count" - if func == torch.ops.aten.mm.default: - meta[mm_count_key] += 1 - # Saves output of all compute ops, except every second mm - to_save = func in _save_list and not ( - func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 - ) - return ( - CheckpointPolicy.MUST_SAVE - if func in _save_list - else CheckpointPolicy.PREFER_RECOMPUTE - ) - - return _custom_policy - return ptd_checkpoint_wrapper( module, - policy_fn=_get_custom_policy(), + make_policy_fn=CustomPolicy, ) elif use_layer_sac: From 68b2264d01e22fff063638007844b29a7c59fcd1 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Fri, 13 Jun 2025 09:22:18 -0700 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- .github/workflows/integration_test_8gpu.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/integration_test_8gpu.yaml b/.github/workflows/integration_test_8gpu.yaml index bf534746f..d5db8b0eb 100644 --- a/.github/workflows/integration_test_8gpu.yaml +++ b/.github/workflows/integration_test_8gpu.yaml @@ -36,6 +36,8 @@ jobs: pip config --user set global.progress_bar off + git clone https://github.com/soulitzer/ac-experimental.git && cd ac-experimental && pip install -e . && cd .. + python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
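
Note on the API used above: ac_experimental is an external prototype package (cloned from https://github.com/soulitzer/ac-experimental in the CI change); the apply_ac_policy_fn signature and its is_factory flag are taken from this patch rather than from documented PyTorch API. For reference, below is a minimal standalone sketch of the same "save compute ops, but recompute every second mm" policy written against the stock torch.utils.checkpoint SAC API that these patches migrate away from. The class name EveryOtherMMPolicy and the toy fn are illustrative only; the policy logic mirrors the original torchtitan _custom_policy.

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Ops whose outputs are worth saving; stands in for the fuller _save_list
# in parallelize_llama.py.
_save_list = {torch.ops.aten.mm.default}


class EveryOtherMMPolicy:
    # Stateful policy: save ops in _save_list, but recompute every second mm.
    # Forward and recompute keep separate counters (via ctx.is_recompute) so
    # both passes agree on which mms were saved.
    def __init__(self):
        self.counts = {"forward": 0, "recompute": 0}

    def __call__(self, ctx, func, *args, **kwargs):
        mode = "recompute" if ctx.is_recompute else "forward"
        if func == torch.ops.aten.mm.default:
            self.counts[mode] += 1
        to_save = func in _save_list and not (
            func == torch.ops.aten.mm.default and self.counts[mode] % 2 == 0
        )
        return (
            CheckpointPolicy.MUST_SAVE
            if to_save
            else CheckpointPolicy.PREFER_RECOMPUTE
        )


def context_fn():
    # Fresh policy instance per checkpointed region, analogous to the
    # is_factory=True behavior relied on above.
    return create_selective_checkpoint_contexts(EveryOtherMMPolicy())


def fn(x, w1, w2):
    # Two mms: the first is saved, the second is recomputed in backward.
    return torch.mm(torch.mm(x, w1), w2).sum()


x, w1, w2 = (torch.randn(8, 8, requires_grad=True) for _ in range(3))
out = checkpoint(fn, x, w1, w2, use_reentrant=False, context_fn=context_fn)
out.backward()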