|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import unittest |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn as nn |
| 11 | +from torch.utils.flop_counter import FlopCounterMode |
| 12 | + |
| 13 | +from torchtitan.config_manager import ActivationCheckpoint as ACConfig |
| 14 | +from torchtitan.models.llama3.infra.parallelize import apply_ac |
| 15 | + |
| 16 | + |
| 17 | +class TestModule(nn.Module): |
| 18 | + def __init__(self): |
| 19 | + super().__init__() |
| 20 | + self.layers = nn.ModuleDict({"0": TransformerBlock()}) |
| 21 | + |
| 22 | + def forward(self, x): |
| 23 | + return self.layers["0"](x) |
| 24 | + |
| 25 | + |
| 26 | +class TransformerBlock(nn.Module): |
| 27 | + def __init__(self): |
| 28 | + super().__init__() |
| 29 | + self.moe = nn.Module() |
| 30 | + self.moe.router = nn.Module() |
| 31 | + self.moe.router.gate = nn.Linear(512, 512, bias=False) |
| 32 | + self.attention = nn.Module() |
| 33 | + self.attention.wq = nn.Linear(512, 512, bias=False) |
| 34 | + self.output = nn.Linear(512, 1024, bias=False) |
| 35 | + |
| 36 | + def forward(self, x): |
| 37 | + gate_out = self.moe.router.gate(x) |
| 38 | + wq_out = self.attention.wq(gate_out) |
| 39 | + final_out = self.output(wq_out) |
| 40 | + return final_out.sum() |
| 41 | + |
| 42 | + |
| 43 | +class TestApplyAC(unittest.TestCase): |
| 44 | + def test_flops(self): |
| 45 | + def get_bw_flops(model_fn): |
| 46 | + x = torch.randn(512, 512, requires_grad=True) |
| 47 | + with torch.utils.checkpoint.set_checkpoint_early_stop(False): |
| 48 | + out = model_fn(x) |
| 49 | + out.backward() |
| 50 | + |
| 51 | + x = torch.randn(512, 512, requires_grad=True) |
| 52 | + with torch.utils.checkpoint.set_checkpoint_early_stop(False): |
| 53 | + out = model_fn(x) |
| 54 | + with FlopCounterMode(display=False) as mode: |
| 55 | + out.backward() |
| 56 | + return mode.get_total_flops() / (512**3 * 2) |
| 57 | + |
| 58 | + # 1. No AC |
| 59 | + model_no_ac = TestModule() |
| 60 | + flops_no_ac = get_bw_flops(model_no_ac) |
| 61 | + |
| 62 | + # 2. SAC |
| 63 | + # Per-op SAC's policy is to save every other mm |
| 64 | + model_selective_ac = TestModule() |
| 65 | + ac_config_no_force = ACConfig( |
| 66 | + mode="selective", |
| 67 | + selective_ac_option="op", |
| 68 | + per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list |
| 69 | + ) |
| 70 | + apply_ac(model_selective_ac, ac_config_no_force) |
| 71 | + flops_selective_ac = get_bw_flops(model_selective_ac) |
| 72 | + |
| 73 | + # 3. Per-op SAC with force recompute "moe.router.gate" |
| 74 | + # This leads to two mms being recomputed since they share the same shape! |
| 75 | + model_with_force_first = TestModule() |
| 76 | + ac_config_with_force_first = ACConfig( |
| 77 | + mode="selective", |
| 78 | + selective_ac_option="op", |
| 79 | + per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], |
| 80 | + ) |
| 81 | + apply_ac(model_with_force_first, ac_config_with_force_first) |
| 82 | + flops_with_force_first = get_bw_flops(model_with_force_first) |
| 83 | + |
| 84 | + # 4. Per-op SAC with force recompute "output" |
| 85 | + model_with_force_last = TestModule() |
| 86 | + ac_config_with_force_last = ACConfig( |
| 87 | + mode="selective", |
| 88 | + selective_ac_option="op", |
| 89 | + per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], |
| 90 | + ) |
| 91 | + apply_ac(model_with_force_last, ac_config_with_force_last) |
| 92 | + flops_with_force_last = get_bw_flops(model_with_force_last) |
| 93 | + |
| 94 | + # 5. Full AC |
| 95 | + model_with_full_ac = TestModule() |
| 96 | + ac_config_full_ac = ACConfig( |
| 97 | + mode="full", |
| 98 | + ) |
| 99 | + apply_ac(model_with_full_ac, ac_config_full_ac) |
| 100 | + flops_full_ac = get_bw_flops(model_with_full_ac) |
| 101 | + |
| 102 | + self.assertEqual(flops_no_ac, 8.0) |
| 103 | + self.assertEqual(flops_selective_ac, 9.0) |
| 104 | + self.assertEqual(flops_with_force_first, 10.0) |
| 105 | + self.assertEqual(flops_with_force_last, 11.0) |
| 106 | + self.assertEqual(flops_full_ac, 12.0) |
| 107 | + |
| 108 | + def test_mem(self): |
| 109 | + if not torch.cuda.is_available(): |
| 110 | + raise unittest.SkipTest("CUDA is unavailable") |
| 111 | + |
| 112 | + def get_act_mem(model_fn): |
| 113 | + x = torch.randn(512, 512, requires_grad=True, device="cuda") |
| 114 | + out = model_fn(x) |
| 115 | + out.backward() |
| 116 | + start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] |
| 117 | + |
| 118 | + out = model_fn(x) |
| 119 | + cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] |
| 120 | + act_mem = (cur_mem - start_mem) / (1024 * 1024) # → MB |
| 121 | + out.backward() |
| 122 | + return act_mem |
| 123 | + |
| 124 | + # 1. No AC |
| 125 | + model_no_ac = TestModule().cuda() |
| 126 | + mem_no_ac = get_act_mem(model_no_ac) |
| 127 | + |
| 128 | + # 2. SAC |
| 129 | + # Per-op SAC's policy is to save every other mm |
| 130 | + model_selective_ac = TestModule().cuda() |
| 131 | + ac_config_no_force = ACConfig( |
| 132 | + mode="selective", |
| 133 | + selective_ac_option="op", |
| 134 | + per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list |
| 135 | + ) |
| 136 | + apply_ac(model_selective_ac, ac_config_no_force) |
| 137 | + mem_selective_ac = get_act_mem(model_selective_ac) |
| 138 | + |
| 139 | + # 3. Per-op SAC with force recompute "moe.router.gate" |
| 140 | + # This leads to two mms being recomputed since they share the same shape! |
| 141 | + model_with_force_first = TestModule().cuda() |
| 142 | + ac_config_with_force_first = ACConfig( |
| 143 | + mode="selective", |
| 144 | + selective_ac_option="op", |
| 145 | + per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], |
| 146 | + ) |
| 147 | + apply_ac(model_with_force_first, ac_config_with_force_first) |
| 148 | + mem_with_force_first = get_act_mem(model_with_force_first) |
| 149 | + |
| 150 | + # 4. Per-op SAC with force recompute "output" |
| 151 | + model_with_force_last = TestModule().cuda() |
| 152 | + ac_config_with_force_last = ACConfig( |
| 153 | + mode="selective", |
| 154 | + selective_ac_option="op", |
| 155 | + per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], |
| 156 | + ) |
| 157 | + apply_ac(model_with_force_last, ac_config_with_force_last) |
| 158 | + mem_with_force_last = get_act_mem(model_with_force_last) |
| 159 | + |
| 160 | + # 5. Full AC |
| 161 | + model_with_full_ac = TestModule().cuda() |
| 162 | + ac_config_full_ac = ACConfig( |
| 163 | + mode="full", |
| 164 | + ) |
| 165 | + apply_ac(model_with_full_ac, ac_config_full_ac) |
| 166 | + mem_full_ac = get_act_mem(model_with_full_ac) |
| 167 | + |
| 168 | + self.assertEqual(mem_no_ac, 2.0) |
| 169 | + self.assertEqual(mem_selective_ac, 3.0) |
| 170 | + self.assertEqual(mem_with_force_first, 2.0) |
| 171 | + self.assertEqual(mem_with_force_last, 1.0) |
| 172 | + self.assertEqual(mem_full_ac, 0.0) |
| 173 | + # Note: SAC > no-AC here because it unnecessarily saves "output" |
| 174 | + # even that is not needed for recomputaion and output is double |
| 175 | + # the size of the other two mms. |
| 176 | + |
| 177 | + def test_correctness(self): |
| 178 | + model_no_ac = TestModule() |
| 179 | + |
| 180 | + model_selective_ac = TestModule() |
| 181 | + model_selective_ac.load_state_dict(model_no_ac.state_dict()) |
| 182 | + apply_ac( |
| 183 | + model_selective_ac, |
| 184 | + ACConfig( |
| 185 | + mode="selective", |
| 186 | + selective_ac_option="op", |
| 187 | + per_op_sac_force_recompute_mm_shapes_by_fqns=[], |
| 188 | + ), |
| 189 | + ) |
| 190 | + model_force_first = TestModule() |
| 191 | + model_force_first.load_state_dict(model_no_ac.state_dict()) |
| 192 | + apply_ac( |
| 193 | + model_force_first, |
| 194 | + ACConfig( |
| 195 | + mode="selective", |
| 196 | + selective_ac_option="op", |
| 197 | + per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], |
| 198 | + ), |
| 199 | + ) |
| 200 | + |
| 201 | + model_force_last = TestModule() |
| 202 | + model_force_last.load_state_dict(model_no_ac.state_dict()) |
| 203 | + apply_ac( |
| 204 | + model_force_last, |
| 205 | + ACConfig( |
| 206 | + mode="selective", |
| 207 | + selective_ac_option="op", |
| 208 | + per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], |
| 209 | + ), |
| 210 | + ) |
| 211 | + |
| 212 | + def run_fwd_bwd(model, batch): |
| 213 | + model.zero_grad(set_to_none=True) |
| 214 | + xin = batch.clone().detach().requires_grad_(True) |
| 215 | + out = model(xin) # scalar |
| 216 | + out.backward() |
| 217 | + |
| 218 | + grad_in = xin.grad.detach().clone() |
| 219 | + grad_params = [ |
| 220 | + p.grad.detach().clone() if isinstance(p.grad, torch.Tensor) else None |
| 221 | + for p in model.parameters() |
| 222 | + ] |
| 223 | + return out.detach(), grad_in, grad_params |
| 224 | + |
| 225 | + batch = torch.randn(64, 512) |
| 226 | + |
| 227 | + out_ref, gin_ref, gparams_ref = run_fwd_bwd(model_no_ac, batch) |
| 228 | + out_sel, gin_sel, gparams_sel = run_fwd_bwd(model_selective_ac, batch) |
| 229 | + out_f1, gin_f1, gparams_f1 = run_fwd_bwd(model_force_first, batch) |
| 230 | + out_fl, gin_fl, gparams_fl = run_fwd_bwd(model_force_last, batch) |
| 231 | + |
| 232 | + for other_out in (out_sel, out_f1, out_fl): |
| 233 | + torch.testing.assert_close(out_ref, other_out) |
| 234 | + |
| 235 | + for other_gin in (gin_sel, gin_f1, gin_fl): |
| 236 | + torch.testing.assert_close(gin_ref, other_gin) |
| 237 | + |
| 238 | + for g_ref, g_sel, g_f1, g_fl in zip( |
| 239 | + gparams_ref, gparams_sel, gparams_f1, gparams_fl |
| 240 | + ): |
| 241 | + # Skip wrapper / missing grads |
| 242 | + if not ( |
| 243 | + torch.is_tensor(g_ref) |
| 244 | + and torch.is_tensor(g_sel) |
| 245 | + and torch.is_tensor(g_f1) |
| 246 | + and torch.is_tensor(g_fl) |
| 247 | + ): |
| 248 | + continue |
| 249 | + |
| 250 | + torch.testing.assert_close(g_ref, g_sel) |
| 251 | + torch.testing.assert_close(g_ref, g_f1) |
| 252 | + torch.testing.assert_close(g_ref, g_fl) |
| 253 | + |
| 254 | + |
| 255 | +if __name__ == "__main__": |
| 256 | + unittest.main() |
0 commit comments