Add option for selective op AC to filter mm shapes based on fqn #1380

Open: wants to merge 1 commit into main
219 changes: 219 additions & 0 deletions tests/unit_tests/test_activation_checkpoint.py
@@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn
from torch.utils.flop_counter import FlopCounterMode

from torchtitan.config_manager import ActivationCheckpoint
from torchtitan.models.llama3.infra.parallelize import apply_ac


class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleDict({"0": TransformerBlock()})

def forward(self, x):
return self.layers["0"](x)


class TransformerBlock(nn.Module):
def __init__(self):
super().__init__()
self.moe = nn.Module()
self.moe.router = nn.Module()
self.moe.router.gate = nn.Linear(512, 512, bias=False)
self.attention = nn.Module()
self.attention.wq = nn.Linear(512, 512, bias=False)
self.output = nn.Linear(512, 1024, bias=False)

def forward(self, x):
gate_out = self.moe.router.gate(x)
wq_out = self.attention.wq(gate_out)
final_out = self.output(wq_out)
return final_out.sum()


class TestApplyAC(unittest.TestCase):
def test_flops(self):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is unavailable")
Contributor comment:
similar question: Does AC / FlopCounterMode require GPU to run?


def get_bw_flops(model_fn):
x = torch.randn(512, 512, requires_grad=True, device="cuda")
with torch.utils.checkpoint.set_checkpoint_early_stop(False):
out = model_fn(x)
out.backward()

x = torch.randn(512, 512, requires_grad=True, device="cuda")
with torch.utils.checkpoint.set_checkpoint_early_stop(False):
out = model_fn(x)
with FlopCounterMode(display=False) as mode:
out.backward()
return mode.get_total_flops() / (512**3 * 2)

def get_act_mem(model_fn):
x = torch.randn(512, 512, requires_grad=True, device="cuda")
out = model_fn(x)
out.backward()
start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]

out = model_fn(x)
cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
act_mem = (cur_mem - start_mem) / (1024 * 1024) # → MB
out.backward()
return act_mem

# 1. No AC
model_no_ac = TestModule().cuda()
flops_no_ac = get_bw_flops(model_no_ac)
mem_no_ac = get_act_mem(model_no_ac)

# 2. SAC
# Per-op SAC's policy is to save every other mm
model_selective_ac = TestModule().cuda()
ac_config_no_force = ActivationCheckpoint(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
)
apply_ac(model_selective_ac, ac_config_no_force)
flops_selective_ac = get_bw_flops(model_selective_ac)
mem_selective_ac = get_act_mem(model_selective_ac)

# 3. Per-op SAC with force recompute "moe.router.gate"
# This leads to two mms being recomputed since they share the same shape!
model_with_force_first = TestModule().cuda()
ac_config_with_force_first = ActivationCheckpoint(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
)
apply_ac(model_with_force_first, ac_config_with_force_first)
flops_with_force_first = get_bw_flops(model_with_force_first)
mem_with_force_first = get_act_mem(model_with_force_first)

# 4. Per-op SAC with force recompute "output"
model_with_force_last = TestModule().cuda()
ac_config_with_force_last = ActivationCheckpoint(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
)
apply_ac(model_with_force_last, ac_config_with_force_last)
flops_with_force_last = get_bw_flops(model_with_force_last)
mem_with_force_last = get_act_mem(model_with_force_last)

# 5. Full AC
model_with_full_ac = TestModule().cuda()
ac_config_full_ac = ActivationCheckpoint(
mode="full",
)
apply_ac(model_with_full_ac, ac_config_full_ac)
flops_full_ac = get_bw_flops(model_with_full_ac)
mem_full_ac = get_act_mem(model_with_full_ac)

self.assertEqual(flops_no_ac, 8.0)
self.assertEqual(flops_selective_ac, 9.0)
self.assertEqual(flops_with_force_first, 10.0)
self.assertEqual(flops_with_force_last, 11.0)
self.assertEqual(flops_full_ac, 12.0)

self.assertEqual(mem_no_ac, 2.0)
self.assertEqual(mem_selective_ac, 3.0)
self.assertEqual(mem_with_force_first, 2.0)
self.assertEqual(mem_with_force_last, 1.0)
self.assertEqual(mem_full_ac, 0.0)
# Note: SAC > no-AC here because it unnecessarily saves "output",
# even though it is not needed for recomputation, and "output" is double
# the size of the other two mms.

def test_correctness(self):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is unavailable")
Contributor comment:
maybe noob question:
Does AC require GPU to run? My intuition was it should be able to run on CPU.


model_no_ac = TestModule().cuda()

model_selective_ac = TestModule().cuda()
model_selective_ac.load_state_dict(model_no_ac.state_dict())
apply_ac(
model_selective_ac,
ActivationCheckpoint(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[],
),
)
model_force_first = TestModule().cuda()
model_force_first.load_state_dict(model_no_ac.state_dict())
apply_ac(
model_force_first,
ActivationCheckpoint(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
),
)

model_force_last = TestModule().cuda()
model_force_last.load_state_dict(model_no_ac.state_dict())
apply_ac(
model_force_last,
ActivationCheckpoint(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
),
)

def run_fwd_bwd(model, batch):
model.zero_grad(set_to_none=True)
xin = batch.clone().detach().requires_grad_(True)
out = model(xin) # scalar
out.backward()

grad_in = xin.grad.detach().clone()
grad_params = [
p.grad.detach().clone() if isinstance(p.grad, torch.Tensor) else None
for p in model.parameters()
]
return out.detach(), grad_in, grad_params

batch = torch.randn(64, 512, device="cuda")

out_ref, gin_ref, gparams_ref = run_fwd_bwd(model_no_ac, batch)
out_sel, gin_sel, gparams_sel = run_fwd_bwd(model_selective_ac, batch)
out_f1, gin_f1, gparams_f1 = run_fwd_bwd(model_force_first, batch)
out_fl, gin_fl, gparams_fl = run_fwd_bwd(model_force_last, batch)

for other_out in (out_sel, out_f1, out_fl):
torch.testing.assert_close(out_ref, other_out)

for other_gin in (gin_sel, gin_f1, gin_fl):
torch.testing.assert_close(gin_ref, other_gin)

for g_ref, g_sel, g_f1, g_fl in zip(
gparams_ref, gparams_sel, gparams_f1, gparams_fl
):
# Skip wrapper / missing grads
if not (
torch.is_tensor(g_ref)
and torch.is_tensor(g_sel)
and torch.is_tensor(g_f1)
and torch.is_tensor(g_fl)
):
continue

torch.testing.assert_close(g_ref, g_sel)
torch.testing.assert_close(g_ref, g_f1)
torch.testing.assert_close(g_ref, g_fl)


if __name__ == "__main__":
unittest.main()
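
The expected values asserted in test_flops can be sanity-checked by hand. The following is a minimal sketch of that arithmetic, assuming FlopCounterMode counts 2*M*N*K FLOPs per mm and that each mm's backward costs two same-sized mms (grad-input and grad-weight):

# Sketch: derive the normalized backward FLOPs asserted in test_flops.
# Unit: one 512 x 512 x 512 mm = 2 * 512**3 FLOPs (the normalization used above).
fwd_mms = {"moe.router.gate": 1, "attention.wq": 1, "output": 2}  # output mm is 512 x 512 x 1024

plain_bwd = 2 * sum(fwd_mms.values())  # grad-input + grad-weight per mm -> 8.0

# Per-op SAC saves every other mm: gate and output are saved, wq is recomputed.
sac = plain_bwd + fwd_mms["attention.wq"]  # -> 9.0

# Forcing "moe.router.gate" forces every mm with rhs shape (512, 512), which also
# catches attention.wq; the output mm is then the only counted mm and is saved.
force_first = plain_bwd + fwd_mms["moe.router.gate"] + fwd_mms["attention.wq"]  # -> 10.0

# Forcing "output" recomputes the (512, 1024) mm; of the remaining mms, gate is
# saved and wq (the second counted mm) is recomputed.
force_last = plain_bwd + fwd_mms["attention.wq"] + fwd_mms["output"]  # -> 11.0

full_ac = plain_bwd + sum(fwd_mms.values())  # recompute the whole forward -> 12.0

# Activation memory in MB of saved fp32 activations, matching get_act_mem:
# gate_out and wq_out are 512 * 512 * 4 bytes = 1 MB each; the output activation is 2 MB.
act_mem = {
    "no_ac": 1 + 1,      # gate_out + wq_out saved by plain autograd -> 2.0
    "selective": 1 + 2,  # gate_out + output saved by the every-other-mm policy -> 3.0
    "force_first": 2,    # only the output activation saved -> 2.0
    "force_last": 1,     # only gate_out saved -> 1.0
    "full": 0,           # nothing saved inside the checkpointed block -> 0.0
}
print(plain_bwd, sac, force_first, force_last, full_ac, act_mem)
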
14 changes: 14 additions & 0 deletions torchtitan/config_manager.py
@@ -487,6 +487,20 @@ class ActivationCheckpoint:
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
"""

per_op_sac_force_recompute_mm_shapes_by_fqns: list[str] = field(
default_factory=lambda: ["moe.router.gate"]
)
"""
When per-op selective AC is used, this list of fully qualified names (relative
to the module at which AC is applied) is used to determine which mm shapes to
force recompute, rather than being considered by the rest of the SAC policy
(e.g. save every other mm). Only nn.Linear modules are supported today.

Note: this config applies to all mms whose shapes match, not only those belonging
to the specified fqns; e.g. if "moe.router.gate", corresponding to Linear(in, out),
is specified, ANY mm with shape (*, in) x (in, out) will be force recomputed.
"""


@dataclass
class Float8:
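
To make the note above concrete: in the test added by this PR, moe.router.gate and attention.wq are both nn.Linear(512, 512), so filtering by the gate's fqn also forces the wq mm to be recomputed. A minimal sketch of the shape matching, using the shapes from the PR's test module:

import torch.nn as nn

# Shapes from the TransformerBlock used in this PR's test: gate and wq share
# (512, 512), while output is (512, 1024).
gate = nn.Linear(512, 512, bias=False)
wq = nn.Linear(512, 512, bias=False)
output = nn.Linear(512, 1024, bias=False)

# apply_ac collects the rhs mm shape (in_features, out_features) of each filtered Linear.
out_f, in_f = gate.weight.shape        # weight is stored as (out, in)
mm_recompute_shapes = {(in_f, out_f)}  # {(512, 512)} when filtering "moe.router.gate"

def rhs_shape(linear: nn.Linear) -> tuple[int, int]:
    # F.linear computes x @ weight.t(), so the rhs of the underlying mm is (in, out).
    return (linear.in_features, linear.out_features)

assert rhs_shape(wq) in mm_recompute_shapes          # same shape, so also force-recomputed
assert rhs_shape(output) not in mm_recompute_shapes  # different shape, unaffected
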
34 changes: 30 additions & 4 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -27,7 +27,7 @@
SequenceParallel,
)

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.config_manager import ActivationCheckpoint, JobConfig, TORCH_DTYPE_MAP
Contributor comment:
maybe ActivationCheckpoint as ACConfig

from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import logger

@@ -237,7 +237,9 @@ def apply_tp(
}


def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
def _apply_ac_to_transformer_block(
module: nn.Module, ac_config: ActivationCheckpoint, base_fqn: str
Contributor comment:
Since in torchtitan we only apply AC at the transformer-block level, I feel the base_fqn arg is less needed, in the sense that it should be rare for a user to apply per-op SAC but want to filter the router.gate matmul only in layer 1 and not in layer 2.

Most use cases would be per_op_sac_force_recompute_mm_shapes_by_fqns = ["moe.router.gate"], and moe.router.gate is already part of module_fqn without base_fqn.

If that's the case, I think it's not necessary to add this field. Let me know if you think otherwise.

):
valid_ac_modes = ("full", "selective")
if ac_config.mode not in valid_ac_modes:
raise ValueError(
@@ -261,11 +263,33 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
create_selective_checkpoint_contexts,
)

mm_recompute_shapes = set()
if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
for module_fqn, submod in module.named_modules():
fqn = f"{base_fqn}.{module_fqn}"
if not any(
filter_fqn in fqn
for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
):
continue
if not isinstance(submod, nn.Linear):
raise ValueError(
"per_op_sac_force_recompute_mm_shapes_by_fqns expected to match "
f"a nn.Linear, but got: {submod}"
)
out_f, in_f = submod.weight.shape
mm_recompute_shapes.add((in_f, out_f))
logger.debug(
f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
)

def _get_custom_policy(meta):
def _custom_policy(ctx, func, *args, **kwargs):
mode = "recompute" if ctx.is_recompute else "forward"
mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
if args[1].shape in mm_recompute_shapes:
return CheckpointPolicy.PREFER_RECOMPUTE
meta[mm_count_key] += 1
# Saves output of all compute ops, except every second mm
to_save = func in _save_list and not (
Expand Down Expand Up @@ -299,10 +323,12 @@ def selective_checkpointing_context_fn():
return module


def apply_ac(model: nn.Module, ac_config):
def apply_ac(model: nn.Module, ac_config: ActivationCheckpoint):
"""Apply activation checkpointing to the model."""
for layer_id, transformer_block in model.layers.named_children():
transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
transformer_block = _apply_ac_to_transformer_block(
transformer_block, ac_config, f"layers.{layer_id}"
)
model.layers.register_module(layer_id, transformer_block)

logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
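
For reference, the policy path added here can be exercised standalone with the same torch.utils.checkpoint APIs the PR builds on. A minimal sketch, assuming a PyTorch version that provides create_selective_checkpoint_contexts and CheckpointPolicy; the weights, shapes, and the save-everything-else fallback are illustrative simplifications, not torchtitan's actual every-other-mm policy:

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# rhs shapes to force-recompute, as apply_ac would collect for "moe.router.gate".
mm_recompute_shapes = {(512, 512)}

def policy(ctx, func, *args, **kwargs):
    # Force-recompute any mm whose rhs matches a filtered shape; save everything
    # else (a simplified stand-in for the every-other-mm policy).
    if func == torch.ops.aten.mm.default and args[1].shape in mm_recompute_shapes:
        return CheckpointPolicy.PREFER_RECOMPUTE
    return CheckpointPolicy.MUST_SAVE

def context_fn():
    return create_selective_checkpoint_contexts(policy)

w_gate = torch.randn(512, 512, requires_grad=True)  # rhs of a (*, 512) x (512, 512) mm
w_out = torch.randn(512, 1024, requires_grad=True)  # rhs of a (*, 512) x (512, 1024) mm

def fwd(x):
    return torch.mm(torch.mm(x, w_gate), w_out).sum()

x = torch.randn(64, 512, requires_grad=True)
out = checkpoint(fwd, x, use_reentrant=False, context_fn=context_fn)
out.backward()  # the (512, 512) mm is recomputed here; the (512, 1024) mm output was saved

Wrapping the backward call in FlopCounterMode should show one extra (512, 512) mm relative to a run where every op is saved, mirroring the FLOP deltas asserted in the test above.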