Commit b7c7ed7
[SimpleFSDP] Add support for SimpleFSDP DCP (#1273)
As titled, this PR adds support for SimpleFSDP's DCP (Distributed Checkpointing) composability. The code is based on previous implementations from @fmassa, @awgu, and @yf225. In the runs below the losses match perfectly; the checkpoints are loaded from step 110.

(1) [dp:4] --> [dp:4]
![loss curves, dp:4 save and dp:4 load](https://github.com/user-attachments/assets/a0e49c0f-fdbc-4ea8-82bd-573f0c9015a8)

(2) [dp:2, tp:2] --> [dp:2, tp:2] & [dp:2, pp:2]
![loss curves, resharded loads](https://github.com/user-attachments/assets/920a9898-fc75-4dfe-bf28-d597be398226)
1 parent bc5ebb7 commit b7c7ed7
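
For orientation, here is a minimal sketch (not code from this PR) of the save/reload round trip the curves above validate, assuming a recent PyTorch with `torch.distributed.checkpoint`, an initialized process group, and models whose parameters are DTensors; the function names and checkpoint path are illustrative:

```python
import torch.nn as nn
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)


def save_model(model: nn.Module, ckpt_dir: str) -> None:
    # DCP records every DTensor shard along with its global shape and
    # placement, which is what makes resharding on load possible.
    dcp.save({"model": get_model_state_dict(model)}, checkpoint_id=ckpt_dir)


def load_model(model: nn.Module, ckpt_dir: str) -> None:
    # The freshly built (possibly differently parallelized) model defines
    # the target sharding; dcp.load redistributes saved shards onto it.
    state_dict = {"model": get_model_state_dict(model)}
    dcp.load(state_dict, checkpoint_id=ckpt_dir)
    set_model_state_dict(model, state_dict["model"])
```

Saving under [dp:4] and calling `load_model` on a [dp:2, tp:2] instance is exactly the kind of round trip exercised here.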

File tree: 3 files changed (+89, −57 lines)


torchtitan/experiments/simple_fsdp/README.md
Lines changed: 1 addition & 1 deletion

```diff
@@ -24,7 +24,7 @@ Some of the features require the updates from PyTorch, with which we are working
 |Tensor Parallelism| ✅ |
 |Context Parallelism| ✅ |
 |Pipeline Parallelism| ✅ |
-|Distributed Checkpointing| 🚧 |
+|Distributed Checkpointing| ✅ |
 |Float8 Training| 🚧 |
```
torchtitan/experiments/simple_fsdp/simple_fsdp.py
Lines changed: 49 additions & 15 deletions

```diff
@@ -7,7 +7,7 @@
 from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
```
```diff
@@ -124,6 +124,26 @@ def _distribute_dtensor(
     )
 
 
+def _register_parametrization(
+    module: nn.Module, param_names: List[str], parametrization: nn.Module
+):
+    """
+    This works with state_dict without incurring parametrization calls because
+    state_dict accesses parameters directly from self._parameters, not from getters.
+    https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L2141
+    """
+    param_name_to_property = {
+        param_name: property(lambda self, param_name=param_name: parametrization(self._parameters[param_name]))
+        for param_name in param_names
+    }
+    module_cls = type(
+        f"FSDP{module.__class__.__name__}",
+        (module.__class__,),
+        param_name_to_property,
+    )
+    module.__class__ = module_cls
+
+
 def fsdp_policy():
     def _fsdp_recomp_policy():
         def _custom_policy(ctx, func, *args, **kwargs):
```
```diff
@@ -263,18 +283,32 @@ def data_parallel(
                     distribute_tensor_func(p, device_mesh, param_sharding)
                 ),
             )
-            nn.utils.parametrize.register_parametrization(
-                mod,
-                p_name,
-                ReplicateComputation(
-                    device_mesh,
-                    param_sharding,
-                    mode,
-                    regional_ac,
-                    mp_policy=mp_policy,
-                    tp_mesh=tp_mesh,
-                ),
-                unsafe=True,
-            )
+        # to be compatible with DCP, we use a customized _register_parametrization
+        # instead of nn.utils.parametrize.register_parametrization here
+        # nn.utils.parametrize.register_parametrization(
+        #     mod,
+        #     p_name,
+        #     ReplicateComputation(
+        #         device_mesh,
+        #         param_sharding,
+        #         mode,
+        #         regional_ac,
+        #         mp_policy=mp_policy,
+        #         tp_mesh=tp_mesh,
+        #     ),
+        #     unsafe=True,
+        # )
+
+        _register_parametrization(
+            mod,
+            list(params_dict.keys()),
+            ReplicateComputation(
+                device_mesh,
+                param_sharding,
+                mode,
+                regional_ac,
+                mp_policy=mp_policy,
+                tp_mesh=tp_mesh,
+            ),
+        )
     return model
```
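
The docstring above is the crux of DCP compatibility: `nn.utils.parametrize.register_parametrization` moves the original tensor to `module.parametrizations.<name>.original`, which renames state_dict keys, while the class-swap trick leaves `module._parameters` untouched, so `state_dict()` still sees plain, correctly named parameters. Below is a self-contained sketch of the same technique, with a hypothetical `Scale` transform standing in for `ReplicateComputation` (note the `name=name` default argument, which pins the loop variable each property reads):

```python
import torch
import torch.nn as nn


class Scale(nn.Module):
    # Hypothetical stand-in for ReplicateComputation: maps a raw
    # parameter to the tensor actually used in forward.
    def forward(self, p: torch.Tensor) -> torch.Tensor:
        return p * 2.0


def register_property_parametrization(module, param_names, parametrization):
    # One property per parameter name; `name=name` binds the loop variable
    # at definition time so each property reads its own _parameters entry.
    props = {
        name: property(
            lambda self, name=name: parametrization(self._parameters[name])
        )
        for name in param_names
    }
    # Swap in a dynamic subclass carrying the properties. Properties are
    # data descriptors, so `module.weight` resolves through them before
    # nn.Module.__getattr__ ever consults _parameters.
    module.__class__ = type(
        f"Patched{type(module).__name__}", (type(module),), props
    )


lin = nn.Linear(2, 2, bias=False)
raw = lin.weight.detach().clone()
register_property_parametrization(lin, ["weight"], Scale())

assert torch.equal(lin.weight, 2 * raw)  # attribute access is parametrized
assert torch.equal(lin.state_dict()["weight"], raw)  # state_dict stays raw
```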

torchtitan/experiments/simple_fsdp/tests/integration_tests.py
Lines changed: 39 additions & 41 deletions

```diff
@@ -75,20 +75,19 @@ def build_test_list():
         #     "2D async TP",
         #     "2d_asynctp",
         # ),
-        # TODO: Adds back after DCP is supported by SimpleFSDP
-        # OverrideDefinitions(
-        #     [
-        #         [
-        #             "--checkpoint.enable_checkpoint",
-        #         ],
-        #         [
-        #             "--checkpoint.enable_checkpoint",
-        #             "--training.steps 20",
-        #         ],
-        #     ],
-        #     "Checkpoint Integration Test - Save Load Full Checkpoint",
-        #     "full_checkpoint",
-        # ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                ],
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--training.steps 20",
+                ],
+            ],
+            "Checkpoint Integration Test - Save Load Full Checkpoint",
+            "full_checkpoint",
+        ),
         OverrideDefinitions(
             [
                 [
```

```diff
@@ -179,33 +178,32 @@ def build_test_list():
             "fsdp+tp+cp",
             ngpu=8,
         ),
-        # TODO: Adds back after DCP is supported by SimpleFSDP
-        # OverrideDefinitions(
-        #     [
-        #         [
-        #             "--checkpoint.enable_checkpoint",
-        #             "--training.steps 10",
-        #         ],
-        #         # Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be
-        #         # excluded during loading to avoid errors caused by mismatched dp_degree.
-        #         [
-        #             "--checkpoint.enable_checkpoint",
-        #             "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
-        #             "--parallelism.tensor_parallel_degree 2",
-        #             "--training.steps 20",
-        #         ],
-        #         # load at [tp:4].
-        #         [
-        #             "--checkpoint.enable_checkpoint",
-        #             "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
-        #             "--parallelism.tensor_parallel_degree 4",
-        #             "--training.steps 30",
-        #         ],
-        #     ],
-        #     "Optional checkpoint",
-        #     "optional_checkpoint",
-        #     ngpu=4,
-        # ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--training.steps 10",
+                ],
+                # Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be
+                # excluded during loading to avoid errors caused by mismatched dp_degree.
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
+                    "--parallelism.tensor_parallel_degree 2",
+                    "--training.steps 20",
+                ],
+                # load at [tp:4].
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
+                    "--parallelism.tensor_parallel_degree 4",
+                    "--training.steps 30",
+                ],
+            ],
+            "Optional checkpoint",
+            "optional_checkpoint",
+            ngpu=4,
+        ),
     ]
     return integration_tests_flavors
```
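
The `--checkpoint.exclude_from_loading` overrides above are what make the cross-layout reloads safe: components whose state depends on dp_degree (dataloader, optimizer, lr_scheduler) are skipped, and only the resharded model weights are restored. A toy illustration of the key filtering involved, not torchtitan's actual checkpoint manager:

```python
# Saved under [dp:4]; reloading under [dp:2, tp:2] keeps only the
# layout-agnostic component. Toy filter for illustration only.
saved_components = ["model", "optimizer", "lr_scheduler", "dataloader"]
exclude_from_loading = {"lr_scheduler", "dataloader", "optimizer"}

to_load = [k for k in saved_components if k not in exclude_from_loading]
print(to_load)  # ['model']
```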
