import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    checkpoint_wrapper as ptd_checkpoint_wrapper,
+    ActivationWrapper,
+    # checkpoint_wrapper as ptd_checkpoint_wrapper,
)

+
+class CheckpointWrapper(ActivationWrapper):
+    def __init__(self, mod: torch.nn.Module, **kwargs):
+        super().__init__(mod)
+        self._checkpoint_wrapped_module = mod
+        self._policy_fn = kwargs.get("policy_fn", None)
+
+    def forward(self, *args, **kwargs):
+        from ac_experimental import apply_ac_policy
+
+        policy = self._policy_fn if self._policy_fn is not None else "recompute_all"
+
+        with apply_ac_policy(policy_fn=policy):
+            return self._checkpoint_wrapped_module(*args, **kwargs)
+        # return apply_ac_policy_fn(
+        #     self._checkpoint_wrapped_module, *args, **kwargs, policy_fn="recompute_all"
+        # )
+
+
+def ptd_checkpoint_wrapper(mod, **kwargs):
+    return CheckpointWrapper(mod, **kwargs)
+
+
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
from torch.distributed.tensor import Replicate, Shard
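The hunk above swaps the stock checkpoint_wrapper for a local CheckpointWrapper shim that routes checkpointing through ac_experimental's apply_ac_policy context manager while keeping the existing ptd_checkpoint_wrapper call sites. A minimal usage sketch, assuming ac_experimental is importable; the toy module, shapes, and variable names are illustrative and not part of the patch:

# Hypothetical driver code exercising the shim defined in the hunk above.
import torch
import torch.nn as nn

block = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
wrapped = ptd_checkpoint_wrapper(block)  # no policy_fn -> falls back to "recompute_all"

x = torch.randn(8, 64, requires_grad=True)
out = wrapped(x)      # forward runs under apply_ac_policy(policy_fn="recompute_all")
out.sum().backward()  # saved activations are recomputed during the backward pass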
@@ -251,9 +276,39 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
        create_selective_checkpoint_contexts,
    )

-        def _get_custom_policy(meta):
-            def _custom_policy(ctx, func, *args, **kwargs):
-                mode = "recompute" if ctx.is_recompute else "forward"
+        # def _get_custom_policy(meta):
+        #     def _custom_policy(ctx, func, *args, **kwargs):
+        #         mode = "recompute" if ctx.is_recompute else "forward"
+        #         mm_count_key = f"{mode}_mm_count"
+        #         if func == torch.ops.aten.mm.default:
+        #             meta[mm_count_key] += 1
+        #         # Saves output of all compute ops, except every second mm
+        #         # to_save = func in _save_list and not (
+        #         #     func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+        #         # )
+        #         return (
+        #             CheckpointPolicy.MUST_SAVE
+        #             if func in _save_list
+        #             else CheckpointPolicy.PREFER_RECOMPUTE
+        #         )
+
+        #     return _custom_policy
+
+        # def selective_checkpointing_context_fn():
+        #     meta = defaultdict(int)
+        #     return create_selective_checkpoint_contexts(_get_custom_policy(meta))
+
+        # return ptd_checkpoint_wrapper(
+        #     module,
+        #     context_fn=selective_checkpointing_context_fn,
+        #     preserve_rng_state=False,
+        # )
+
+        def _get_custom_policy():
+            meta = defaultdict(int)
+
+            def _custom_policy(ctx, out, func, *args, **kwargs):
+                mode = "forward"  # "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
@@ -263,21 +318,17 @@ def _custom_policy(ctx, func, *args, **kwargs):
                )
                return (
                    CheckpointPolicy.MUST_SAVE
-                    if to_save
+                    if func in _save_list
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

-        def selective_checkpointing_context_fn():
-            meta = defaultdict(int)
-            return create_selective_checkpoint_contexts(_get_custom_policy(meta))
-
        return ptd_checkpoint_wrapper(
            module,
-            context_fn=selective_checkpointing_context_fn,
-            preserve_rng_state=False,
+            policy_fn=_get_custom_policy(),
        )
+
    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the modules passed to this function
        ac_freq = int(ac_config.selective_ac_option)
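The rewritten _custom_policy above saves aten.mm outputs and marks everything else for recomputation, and is handed to the shim via policy_fn. For comparison, the same decision can be expressed through the stock torch.utils.checkpoint selective-checkpointing API, which is the path this diff comments out. A self-contained sketch under that assumption; the save set and toy function are illustrative, not torchtitan's actual _save_list:

# Stand-alone sketch of the same save/recompute policy via the upstream
# torch.utils.checkpoint API (the code path commented out in the diff).
import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

_save_list = {torch.ops.aten.mm.default}  # illustrative save set

def _policy(ctx, func, *args, **kwargs):
    # Save matmul outputs during forward; recompute every other op in backward.
    return (
        CheckpointPolicy.MUST_SAVE
        if func in _save_list
        else CheckpointPolicy.PREFER_RECOMPUTE
    )

def _context_fn():
    return create_selective_checkpoint_contexts(_policy)

def fn(x, w):
    return torch.relu(x @ w) @ w

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
out = checkpoint(fn, x, w, use_reentrant=False, context_fn=_context_fn)
out.sum().backward()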