Add option for selective op AC to filter mm shapes based on fqn #1380
@@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn
from torch.utils.flop_counter import FlopCounterMode

from torchtitan.config_manager import ActivationCheckpoint
from torchtitan.models.llama3.infra.parallelize import apply_ac


class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({"0": TransformerBlock()})

    def forward(self, x):
        return self.layers["0"](x)


class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = nn.Module()
        self.moe.router = nn.Module()
        self.moe.router.gate = nn.Linear(512, 512, bias=False)
        self.attention = nn.Module()
        self.attention.wq = nn.Linear(512, 512, bias=False)
        self.output = nn.Linear(512, 1024, bias=False)

    def forward(self, x):
        gate_out = self.moe.router.gate(x)
        wq_out = self.attention.wq(gate_out)
        final_out = self.output(wq_out)
        return final_out.sum()


class TestApplyAC(unittest.TestCase):
    def test_flops(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is unavailable")

        def get_bw_flops(model_fn):
            x = torch.randn(512, 512, requires_grad=True, device="cuda")
            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
                out = model_fn(x)
                out.backward()

            x = torch.randn(512, 512, requires_grad=True, device="cuda")
            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
                out = model_fn(x)
                with FlopCounterMode(display=False) as mode:
                    out.backward()
            return mode.get_total_flops() / (512**3 * 2)

        def get_act_mem(model_fn):
            x = torch.randn(512, 512, requires_grad=True, device="cuda")
            out = model_fn(x)
            out.backward()
            start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]

            out = model_fn(x)
            cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
            act_mem = (cur_mem - start_mem) / (1024 * 1024)  # → MB
            out.backward()
            return act_mem

        # 1. No AC
        model_no_ac = TestModule().cuda()
        flops_no_ac = get_bw_flops(model_no_ac)
        mem_no_ac = get_act_mem(model_no_ac)

        # 2. SAC
        # Per-op SAC's policy is to save every other mm
        model_selective_ac = TestModule().cuda()
        ac_config_no_force = ActivationCheckpoint(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=[],  # Empty list
        )
        apply_ac(model_selective_ac, ac_config_no_force)
        flops_selective_ac = get_bw_flops(model_selective_ac)
        mem_selective_ac = get_act_mem(model_selective_ac)

        # 3. Per-op SAC with force recompute "moe.router.gate"
        # This leads to two mms being recomputed since they share the same shape!
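        # ("moe.router.gate" and "attention.wq" are both nn.Linear(512, 512), so the
        # right-hand sides of their mms share the shape (512, 512) and both end up
        # force-recomputed.)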
        model_with_force_first = TestModule().cuda()
        ac_config_with_force_first = ActivationCheckpoint(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
        )
        apply_ac(model_with_force_first, ac_config_with_force_first)
        flops_with_force_first = get_bw_flops(model_with_force_first)
        mem_with_force_first = get_act_mem(model_with_force_first)

        # 4. Per-op SAC with force recompute "output"
        model_with_force_last = TestModule().cuda()
        ac_config_with_force_last = ActivationCheckpoint(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
        )
        apply_ac(model_with_force_last, ac_config_with_force_last)
        flops_with_force_last = get_bw_flops(model_with_force_last)
        mem_with_force_last = get_act_mem(model_with_force_last)

        # 5. Full AC
        model_with_full_ac = TestModule().cuda()
        ac_config_full_ac = ActivationCheckpoint(
            mode="full",
        )
        apply_ac(model_with_full_ac, ac_config_full_ac)
        flops_full_ac = get_bw_flops(model_with_full_ac)
        mem_full_ac = get_act_mem(model_with_full_ac)

        self.assertEqual(flops_no_ac, 8.0)
        self.assertEqual(flops_selective_ac, 9.0)
        self.assertEqual(flops_with_force_first, 10.0)
        self.assertEqual(flops_with_force_last, 11.0)
        self.assertEqual(flops_full_ac, 12.0)
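        # Rough accounting behind the expected values (1 flop unit = 2 * 512**3, i.e.
        # one 512x512 @ 512x512 mm): the forward pass costs gate (1) + wq (1) +
        # output (2) = 4 units, so the backward alone is 8. Per-op SAC recomputes every
        # second mm (wq): 8 + 1 = 9. Forcing "moe.router.gate" also catches wq, which
        # shares its shape: 8 + 2 = 10. Forcing "output" (2 units) still leaves the
        # every-other-mm policy recomputing wq: 8 + 2 + 1 = 11. Full AC reruns the
        # whole forward: 8 + 4 = 12.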

        self.assertEqual(mem_no_ac, 2.0)
        self.assertEqual(mem_selective_ac, 3.0)
        self.assertEqual(mem_with_force_first, 2.0)
        self.assertEqual(mem_with_force_last, 1.0)
        self.assertEqual(mem_full_ac, 0.0)
        # Note: SAC > no-AC here because it unnecessarily saves "output"
        # even though it is not needed for recomputation, and "output" is
        # double the size of the other two mms.
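        # Activation-memory accounting (fp32): gate_out and wq_out are 1 MB each, the
        # "output" result is 2 MB. No AC keeps gate_out + wq_out = 2 MB; per-op SAC
        # keeps gate_out + the 2 MB "output" result = 3 MB; forcing "moe.router.gate"
        # keeps only the 2 MB "output" result; forcing "output" keeps only gate_out
        # = 1 MB; full AC keeps nothing.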

    def test_correctness(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is unavailable")
Review comment: maybe noob question: …

        model_no_ac = TestModule().cuda()

        model_selective_ac = TestModule().cuda()
        model_selective_ac.load_state_dict(model_no_ac.state_dict())
        apply_ac(
            model_selective_ac,
            ActivationCheckpoint(
                mode="selective",
                selective_ac_option="op",
                per_op_sac_force_recompute_mm_shapes_by_fqns=[],
            ),
        )
        model_force_first = TestModule().cuda()
        model_force_first.load_state_dict(model_no_ac.state_dict())
        apply_ac(
            model_force_first,
            ActivationCheckpoint(
                mode="selective",
                selective_ac_option="op",
                per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
            ),
        )

        model_force_last = TestModule().cuda()
        model_force_last.load_state_dict(model_no_ac.state_dict())
        apply_ac(
            model_force_last,
            ActivationCheckpoint(
                mode="selective",
                selective_ac_option="op",
                per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
            ),
        )

        def run_fwd_bwd(model, batch):
            model.zero_grad(set_to_none=True)
            xin = batch.clone().detach().requires_grad_(True)
            out = model(xin)  # scalar
            out.backward()

            grad_in = xin.grad.detach().clone()
            grad_params = [
                p.grad.detach().clone() if isinstance(p.grad, torch.Tensor) else None
                for p in model.parameters()
            ]
            return out.detach(), grad_in, grad_params

        batch = torch.randn(64, 512, device="cuda")

        out_ref, gin_ref, gparams_ref = run_fwd_bwd(model_no_ac, batch)
        out_sel, gin_sel, gparams_sel = run_fwd_bwd(model_selective_ac, batch)
        out_f1, gin_f1, gparams_f1 = run_fwd_bwd(model_force_first, batch)
        out_fl, gin_fl, gparams_fl = run_fwd_bwd(model_force_last, batch)

        for other_out in (out_sel, out_f1, out_fl):
            torch.testing.assert_close(out_ref, other_out)

        for other_gin in (gin_sel, gin_f1, gin_fl):
            torch.testing.assert_close(gin_ref, other_gin)

        for g_ref, g_sel, g_f1, g_fl in zip(
            gparams_ref, gparams_sel, gparams_f1, gparams_fl
        ):
            # Skip wrapper / missing grads
            if not (
                torch.is_tensor(g_ref)
                and torch.is_tensor(g_sel)
                and torch.is_tensor(g_f1)
                and torch.is_tensor(g_fl)
            ):
                continue

            torch.testing.assert_close(g_ref, g_sel)
            torch.testing.assert_close(g_ref, g_f1)
            torch.testing.assert_close(g_ref, g_fl)


if __name__ == "__main__":
    unittest.main()
@@ -27,7 +27,7 @@
     SequenceParallel,
 )

-from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
+from torchtitan.config_manager import ActivationCheckpoint, JobConfig, TORCH_DTYPE_MAP
Review comment: maybe …

 from torchtitan.distributed import ParallelDims
 from torchtitan.tools.logging import logger
@@ -237,7 +237,9 @@ def apply_tp(
 }


-def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
+def _apply_ac_to_transformer_block(
+    module: nn.Module, ac_config: ActivationCheckpoint, base_fqn: str
Review comment: Since in torchtitan we only apply AC at the transformer block level, I feel the arg … Most use cases would be … If that's the case, I think it's not necessary to add this field. Let me know if you think otherwise.
+):
     valid_ac_modes = ("full", "selective")
     if ac_config.mode not in valid_ac_modes:
         raise ValueError(
@@ -261,11 +263,33 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
             create_selective_checkpoint_contexts,
         )

+        mm_recompute_shapes = set()
+        if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
+            for module_fqn, submod in module.named_modules():
+                fqn = f"{base_fqn}.{module_fqn}"
+                if not any(
+                    filter_fqn in fqn
+                    for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
+                ):
+                    continue
+                if not isinstance(submod, nn.Linear):
+                    raise ValueError(
+                        "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match "
+                        f"a nn.Linear, but got: {submod}"
+                    )
+                out_f, in_f = submod.weight.shape
+                mm_recompute_shapes.add((in_f, out_f))
+            logger.debug(
+                f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
+            )
+
         def _get_custom_policy(meta):
             def _custom_policy(ctx, func, *args, **kwargs):
                 mode = "recompute" if ctx.is_recompute else "forward"
                 mm_count_key = f"{mode}_mm_count"
                 if func == torch.ops.aten.mm.default:
+                    if args[1].shape in mm_recompute_shapes:
+                        return CheckpointPolicy.PREFER_RECOMPUTE
                     meta[mm_count_key] += 1
                 # Saves output of all compute ops, except every second mm
                 to_save = func in _save_list and not (
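For readers unfamiliar with per-op selective activation checkpointing, here is a minimal, self-contained sketch of the same idea outside torchtitan: a policy that force-recomputes any aten.mm whose right-hand side matches a given shape, wired into torch.utils.checkpoint via create_selective_checkpoint_contexts. The weights, shapes, and function names below are illustrative, not part of this PR, and assume a recent PyTorch that ships CheckpointPolicy and create_selective_checkpoint_contexts.

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Force-recompute any mm whose rhs shape is (512, 512); prefer saving everything else.
recompute_shapes = {torch.Size([512, 512])}

def policy(ctx, func, *args, **kwargs):
    if func == torch.ops.aten.mm.default and args[1].shape in recompute_shapes:
        return CheckpointPolicy.PREFER_RECOMPUTE
    return CheckpointPolicy.PREFER_SAVE

def context_fn():
    return create_selective_checkpoint_contexts(policy)

w1 = torch.randn(512, 512, requires_grad=True)   # rhs shape (512, 512) -> recomputed
w2 = torch.randn(512, 1024, requires_grad=True)  # rhs shape (512, 1024) -> saved

def fn(x):
    return ((x @ w1) @ w2).sum()

x = torch.randn(512, 512, requires_grad=True)
out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
out.backward()

The PR does the same thing, except the shape set is derived automatically at wrap time from module FQNs (e.g. "moe.router.gate") rather than hard-coded, and the policy falls back to the existing every-other-mm counting for non-matching shapes.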
@@ -299,10 +323,12 @@ def selective_checkpointing_context_fn():
     return module


-def apply_ac(model: nn.Module, ac_config):
+def apply_ac(model: nn.Module, ac_config: ActivationCheckpoint):
     """Apply activation checkpointing to the model."""
     for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
+        transformer_block = _apply_ac_to_transformer_block(
+            transformer_block, ac_config, f"layers.{layer_id}"
+        )
         model.layers.register_module(layer_id, transformer_block)

     logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
Review comment: similar question: Does AC / FlopCounterMode require GPU to run?
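For what it's worth, a quick CPU-only sketch (illustrative, not part of this PR) suggesting that FlopCounterMode and checkpointing do not themselves need a GPU; the CUDA guard in the tests above appears to matter mainly for get_act_mem, which reads torch.cuda.memory_stats().

import torch
from torch.utils.flop_counter import FlopCounterMode

# Standalone example on CPU tensors/modules.
lin = torch.nn.Linear(512, 512, bias=False)
x = torch.randn(512, 512, requires_grad=True)

with FlopCounterMode(display=False) as mode:
    lin(x).sum().backward()

print(mode.get_total_flops())  # mm flops are counted on CPU as well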