
Commit fdbcbb3

[not for land] Use new AC
ghstack-source-id: dbec8c3
Pull Request resolved: #1294

1 parent: ed2bbc0

File tree: 1 file changed, +62 −11 lines

torchtitan/models/llama3/parallelize_llama.py

Lines changed: 62 additions & 11 deletions
@@ -12,10 +12,35 @@
 import torch
 import torch.nn as nn
 from torch.distributed._composable.replicate import replicate
+
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    checkpoint_wrapper as ptd_checkpoint_wrapper,
+    ActivationWrapper,
+    # checkpoint_wrapper as ptd_checkpoint_wrapper,
 )
 
+
+class CheckpointWrapper(ActivationWrapper):
+    def __init__(self, mod: torch.nn.Module, **kwargs):
+        super().__init__(mod)
+        self._checkpoint_wrapped_module = mod
+        self._policy_fn = kwargs.get("policy_fn", None)
+
+    def forward(self, *args, **kwargs):
+        from ac_experimental import apply_ac_policy
+
+        policy = self._policy_fn if self._policy_fn is not None else "recompute_all"
+
+        with apply_ac_policy(policy_fn=policy):
+            return self._checkpoint_wrapped_module(*args, **kwargs)
+        # return apply_ac_policy_fn(
+        #     self._checkpoint_wrapped_module, *args, **kwargs, policy_fn="recompute_all"
+        # )
+
+
+def ptd_checkpoint_wrapper(mod, **kwargs):
+    return CheckpointWrapper(mod, **kwargs)
+
+
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Replicate, Shard
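
For orientation: the `checkpoint_wrapper` import that this hunk comments out is the stock PyTorch activation-checkpoint wrapper, and the new `CheckpointWrapper`/`ptd_checkpoint_wrapper` shim keeps the same call sites while routing the wrapped module's forward through the experimental `apply_ac_policy` context instead. A minimal sketch of how the stock wrapper is normally used (the `ToyBlock` module and its sizes are made-up stand-ins for a transformer block, not torchtitan code):

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,  # the wrapper whose role the CheckpointWrapper shim takes over
)

class ToyBlock(nn.Module):
    """Hypothetical stand-in for a TransformerBlock."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.w1 = nn.Linear(dim, dim)
        self.w2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))

# Wrap the block so its activations are recomputed during backward instead of stored.
block = checkpoint_wrapper(ToyBlock(), preserve_rng_state=False)

x = torch.randn(8, 64, requires_grad=True)
block(x).sum().backward()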
@@ -251,9 +276,39 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
             create_selective_checkpoint_contexts,
         )
 
-        def _get_custom_policy(meta):
-            def _custom_policy(ctx, func, *args, **kwargs):
-                mode = "recompute" if ctx.is_recompute else "forward"
+        # def _get_custom_policy(meta):
+        #     def _custom_policy(ctx, func, *args, **kwargs):
+        #         mode = "recompute" if ctx.is_recompute else "forward"
+        #         mm_count_key = f"{mode}_mm_count"
+        #         if func == torch.ops.aten.mm.default:
+        #             meta[mm_count_key] += 1
+        #             # Saves output of all compute ops, except every second mm
+        #             # to_save = func in _save_list and not (
+        #             #     func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+        #             # )
+        #         return (
+        #             CheckpointPolicy.MUST_SAVE
+        #             if func in _save_list
+        #             else CheckpointPolicy.PREFER_RECOMPUTE
+        #         )
+
+        #     return _custom_policy
+
+        # def selective_checkpointing_context_fn():
+        #     meta = defaultdict(int)
+        #     return create_selective_checkpoint_contexts(_get_custom_policy(meta))
+
+        # return ptd_checkpoint_wrapper(
+        #     module,
+        #     context_fn=selective_checkpointing_context_fn,
+        #     preserve_rng_state=False,
+        # )
+
+        def _get_custom_policy():
+            meta = defaultdict(int)
+
+            def _custom_policy(ctx, out, func, *args, **kwargs):
+                mode = "forward"  # recompute" if ctx.is_recompute else "forward"
                 mm_count_key = f"{mode}_mm_count"
                 if func == torch.ops.aten.mm.default:
                     meta[mm_count_key] += 1
@@ -263,21 +318,17 @@ def _custom_policy(ctx, func, *args, **kwargs):
                 )
                 return (
                     CheckpointPolicy.MUST_SAVE
-                    if to_save
+                    if func in _save_list
                     else CheckpointPolicy.PREFER_RECOMPUTE
                 )
 
             return _custom_policy
 
-        def selective_checkpointing_context_fn():
-            meta = defaultdict(int)
-            return create_selective_checkpoint_contexts(_get_custom_policy(meta))
-
         return ptd_checkpoint_wrapper(
             module,
-            context_fn=selective_checkpointing_context_fn,
-            preserve_rng_state=False,
+            policy_fn=_get_custom_policy(),
         )
+
     elif use_layer_sac:
         # Checkpoint every `ac_freq` of the modules passed to this function
         ac_freq = int(ac_config.selective_ac_option)
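
For reference, the commented-out baseline in this hunk wires the same save/recompute decision through the stock selective activation checkpointing API in torch.utils.checkpoint (create_selective_checkpoint_contexts plus CheckpointPolicy). A self-contained sketch of that contract, with a tiny illustrative _save_list and a toy function standing in for torchtitan's real ones:

from collections import defaultdict

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative save list: only matmul outputs are kept; torchtitan's _save_list is larger.
_save_list = {torch.ops.aten.mm.default}

def _get_custom_policy(meta):
    def _custom_policy(ctx, func, *args, **kwargs):
        mode = "recompute" if ctx.is_recompute else "forward"
        if func == torch.ops.aten.mm.default:
            meta[f"{mode}_mm_count"] += 1
        # Keep outputs of ops in _save_list; let everything else be recomputed.
        return (
            CheckpointPolicy.MUST_SAVE
            if func in _save_list
            else CheckpointPolicy.PREFER_RECOMPUTE
        )

    return _custom_policy

def selective_checkpointing_context_fn():
    return create_selective_checkpoint_contexts(_get_custom_policy(defaultdict(int)))

def fn(x, w):
    # Toy compute: one matmul (saved by the policy) and one relu (recomputed).
    return torch.relu(x @ w).sum()

x = torch.randn(4, 4, requires_grad=True)
w = torch.randn(4, 4, requires_grad=True)
out = checkpoint(
    fn, x, w, use_reentrant=False, context_fn=selective_checkpointing_context_fn
)
out.backward()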
