diff --git a/.github/workflows/integration_test_8gpu.yaml b/.github/workflows/integration_test_8gpu.yaml
index bf534746f..d5db8b0eb 100644
--- a/.github/workflows/integration_test_8gpu.yaml
+++ b/.github/workflows/integration_test_8gpu.yaml
@@ -36,6 +36,8 @@ jobs:
         pip config --user set global.progress_bar off

+        git clone https://github.com/soulitzer/ac-experimental.git && cd ac-experimental && pip install -e . && cd ..
+
         python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
         USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
diff --git a/torchtitan/models/llama3/parallelize_llama.py b/torchtitan/models/llama3/parallelize_llama.py
index 64dfcad23..28bb6d0ba 100644
--- a/torchtitan/models/llama3/parallelize_llama.py
+++ b/torchtitan/models/llama3/parallelize_llama.py
@@ -12,10 +12,36 @@
 import torch
 import torch.nn as nn
 from torch.distributed._composable.replicate import replicate
+
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    checkpoint_wrapper as ptd_checkpoint_wrapper,
+    ActivationWrapper,
 )
+
+class CheckpointWrapper(ActivationWrapper):
+    def __init__(self, mod: torch.nn.Module, **kwargs):
+        super().__init__(mod)
+        self._checkpoint_wrapped_module = mod
+        self._make_policy_fn = kwargs.get("make_policy_fn", None)
+
+    def forward(self, *args, **kwargs):
+        from ac_experimental import apply_ac_policy_fn
+
+        if self._make_policy_fn is None:
+            return apply_ac_policy_fn(
+                self._checkpoint_wrapped_module, *args, **kwargs, policy_fn="recompute_all"
+            )
+        else:
+            # Pass is_factory=True so that a new instance of policy_fn is created per AC invocation
+            return apply_ac_policy_fn(
+                self._checkpoint_wrapped_module, *args, **kwargs, policy_fn=self._make_policy_fn, is_factory=True
+            )
+
+
+def ptd_checkpoint_wrapper(mod, **kwargs):
+    return CheckpointWrapper(mod, **kwargs)
+
+
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Replicate, Shard
@@ -226,6 +252,29 @@ def apply_tp(
     torch.ops.aten.max.default,
 }

+from torch.utils.checkpoint import CheckpointPolicy
+
+# If you want your policy to have state, pass a class. Make sure to
+# create it in global scope to avoid new instances triggering recompiles.
+class CustomPolicy:
+    def __init__(self):
+        super().__init__()
+        self.meta = dict()
+
+    def __call__(self, ctx, out, func, *args, **kwargs):
+        mm_count_key = f"mm_count"
+        if func == torch.ops.aten.mm.default:
+            self.meta[mm_count_key] = self.meta.get(mm_count_key, 0) + 1
+
+        # Saves output of all compute ops, except every second mm
+        to_save = func in _save_list and not (
+            func == torch.ops.aten.mm.default and self.meta[mm_count_key] % 2 == 0
+        )
+        return (
+            CheckpointPolicy.MUST_SAVE
+            if to_save
+            else CheckpointPolicy.PREFER_RECOMPUTE
+        )

 def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
     valid_ac_modes = ("full", "selective")
@@ -246,38 +295,11 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
             f"Valid options: 'op' or a positive int representing layer frequency"
         )
     if use_op_sac:
-        from torch.utils.checkpoint import (
-            CheckpointPolicy,
-            create_selective_checkpoint_contexts,
-        )
-
-        def _get_custom_policy(meta):
-            def _custom_policy(ctx, func, *args, **kwargs):
-                mode = "recompute" if ctx.is_recompute else "forward"
-                mm_count_key = f"{mode}_mm_count"
-                if func == torch.ops.aten.mm.default:
-                    meta[mm_count_key] += 1
-                # Saves output of all compute ops, except every second mm
-                to_save = func in _save_list and not (
-                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
-                )
-                return (
-                    CheckpointPolicy.MUST_SAVE
-                    if to_save
-                    else CheckpointPolicy.PREFER_RECOMPUTE
-                )
-
-            return _custom_policy
-
-        def selective_checkpointing_context_fn():
-            meta = defaultdict(int)
-            return create_selective_checkpoint_contexts(_get_custom_policy(meta))
-
         return ptd_checkpoint_wrapper(
             module,
-            context_fn=selective_checkpointing_context_fn,
-            preserve_rng_state=False,
+            make_policy_fn=CustomPolicy,
         )
+
     elif use_layer_sac:
         # Checkpoint every `ac_freq` of the modules passed to this function
         ac_freq = int(ac_config.selective_ac_option)
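
A minimal standalone sketch (not part of the diff above) of how the new stateful policy decides what to keep: it mirrors the CustomPolicy added in parallelize_llama.py, with _save_list reduced to the two ops visible in this diff, and calls the policy directly with placeholder ctx/out arguments instead of going through ac_experimental.

# Illustration only: condensed copy of the CustomPolicy added above.
# _save_list here is a stand-in containing just the ops visible in this diff.
import torch
from torch.utils.checkpoint import CheckpointPolicy

_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten.max.default,
}


class CustomPolicy:
    def __init__(self):
        self.meta = dict()

    def __call__(self, ctx, out, func, *args, **kwargs):
        # Count mm calls so that every second one is recomputed instead of saved.
        if func == torch.ops.aten.mm.default:
            self.meta["mm_count"] = self.meta.get("mm_count", 0) + 1
        to_save = func in _save_list and not (
            func == torch.ops.aten.mm.default and self.meta["mm_count"] % 2 == 0
        )
        return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE


policy = CustomPolicy()
ops = [torch.ops.aten.mm.default] * 3 + [torch.ops.aten.max.default]
print([policy(None, None, op).name for op in ops])
# Expected: ['MUST_SAVE', 'PREFER_RECOMPUTE', 'MUST_SAVE', 'MUST_SAVE']

In the actual change, a fresh CustomPolicy instance is created per AC invocation (is_factory=True in CheckpointWrapper.forward), so the mm counter resets for every checkpointed transformer block rather than persisting across forward and recompute as in the removed context_fn-based version.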