Commit f0ce21b
[SimpleFSDP] Add support for hsdp+tp (#1343)
As titled, this PR adds support for HSDP + TP in SimpleFSDP.

The profiler trace below shows the three streams for FSDP's all-gather/reduce-scatter, DDP's all-reduce, and TP's communications.

![Profiler trace showing the FSDP all-gather/reduce-scatter, DDP all-reduce, and TP communication streams](https://github.com/user-attachments/assets/4d9a7561-8895-48e6-b728-bcc1f95d8f2c)

The loss curves below show that SimpleFSDP's and FSDP2's losses match under HSDP + TP (seed=42).

![Loss curves showing SimpleFSDP and FSDP2 matching under HSDP + TP](https://github.com/user-attachments/assets/b497c1f9-a766-4a00-9e42-cbc8ce1ca1b8)
1 parent acd5ba8 commit f0ce21b
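For context, here is a minimal sketch (not part of this commit) of the 3-D device mesh that HSDP + TP composes: an outer replicate dim, a sharded dim, and an inner TP dim. The 2×2×2 shape matches the 8-GPU test re-enabled below; the mesh dim names are illustrative rather than torchtitan's exact internal names.

```python
# Minimal sketch of an HSDP + TP device mesh; run under torchrun with 8 ranks.
# The dim names below are illustrative.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh(
    "cuda",
    (2, 2, 2),  # (replicate degree, shard degree, TP degree) -> 8 GPUs
    mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
)

# SimpleFSDP's replicate_compute checks that the inner-most mesh dim is "tp".
tp_mesh = mesh["tp"]
dp_mesh = mesh["dp_replicate", "dp_shard"]
print(tp_mesh, dp_mesh)
```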

File tree: 2 files changed (+40, −25 lines)

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 28 additions & 12 deletions
@@ -58,7 +58,7 @@ def _distribute_dtensor(
     Below are experimental enhancements to distribute a DTensor.
     This helps enable Simple FSDP + TP, in which
         inner spec/mesh is TP spec/mesh
-        outer spec/mesh is FSDP spec/mesh
+        outer spec/mesh is FSDP/DDP/HSDP spec/mesh
     The logic follows
         https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fsdp/_fsdp_param.py#L261
     """
@@ -78,24 +78,40 @@ def _distribute_dtensor(
     submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
     spanned_mesh = outer_global_mesh[submesh_names]
 
-    if placements[0].is_shard():
-        # for FSDP + TP dtensor placement
-        shard_dim = placements[0].dim
+    if len(placements) == 1:
+        assert placements[0].is_replicate() or placements[0].is_shard()
+        if placements[0].is_shard():
+            # For FSDP + TP dtensor placement
+            shard_dim = placements[0].dim
+            split_factor = inner_spec.num_shards_map[shard_dim]
+            tensor_placement = (
+                (
+                    _StridedShard(shard_dim, split_factor=split_factor)
+                    if split_factor > 1
+                    else placements[0]
+                ),
+                inner_spec.placements[0],
+            )
+        else:
+            # For DDP + TP dtensor placement
+            tensor_placement = (placements[0], inner_spec.placements[0])
+    elif len(placements) == 2:
+        assert placements[0].is_replicate() and placements[1].is_shard()
+        # For HSDP + TP dtensor placement
+        shard_dim = placements[1].dim
         split_factor = inner_spec.num_shards_map[shard_dim]
         tensor_placement = (
+            placements[0],
             (
                 _StridedShard(shard_dim, split_factor=split_factor)
                 if split_factor > 1
-                else placements[0]
+                else placements[1]
             ),
             inner_spec.placements[0],
         )
-    elif placements[0].is_replicate():
-        # for DDP + TP dtensor placement
-        tensor_placement = (placements[0], inner_spec.placements[0])
     else:
         raise ValueError(
-            f"Unsupported placement {placements[0]} for distributing DTensor {tensor}"
+            f"Unsupported placement {placements} for distributing DTensor {tensor}"
         )
 
     current_spec = DTensorSpec(
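To make the branch structure above concrete, here is a small sketch (not from this commit) of the placement tuples each case produces, built from the public DTensor placement types. The shard dims are illustrative, and the `_StridedShard` variant is omitted since it only applies when `split_factor > 1`.

```python
# Sketch of the tensor_placement tuples produced by the three branches above.
# Shard dims are illustrative; the _StridedShard case (split_factor > 1) is omitted.
from torch.distributed.tensor import Replicate, Shard

tp_placement = Shard(0)  # stands in for inner_spec.placements[0]

# len(placements) == 1 and Shard     -> FSDP + TP: (outer shard, TP placement)
fsdp_tp = (Shard(0), tp_placement)

# len(placements) == 1 and Replicate -> DDP + TP: (outer replicate, TP placement)
ddp_tp = (Replicate(), tp_placement)

# len(placements) == 2, (Replicate, Shard) -> HSDP + TP: spans three mesh dims
hsdp_tp = (Replicate(), Shard(0), tp_placement)

print(fsdp_tp, ddp_tp, hsdp_tp, sep="\n")
```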
@@ -105,7 +121,7 @@ def _distribute_dtensor(
     )
     target_spec = DTensorSpec(
         mesh=outer_mesh,
-        placements=(placements[0],),
+        placements=(placements[-1],),
         tensor_meta=inner_spec.tensor_meta,
     )
     result_tensor = redistribute_local_tensor(
@@ -188,9 +204,9 @@ def replicate_compute(self, x):
         # the gradients are partial tensors that needs to perform reduction
         # (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)
 
-        # support for FSDP/DDP + TP (assuming TP shards the inner-most dim)
+        # support for FSDP/DDP/HSDP + TP (assuming TP shards the inner-most dim)
         if x._spec.mesh.mesh_dim_names[-1] == "tp":
-            dp_placement, tp_placement = x._spec.placements
+            tp_placement = x._spec.placements[-1]
             # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"]
             # after DeviceMesh supports slicing a non-root mesh
             # dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"]
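A short sketch (also not from this commit) of why `replicate_compute` switches from the two-way unpack to `placements[-1]`: under HSDP + TP the spec carries three placements, one per mesh dim, so only the inner-most (TP) entry can be indexed uniformly across FSDP/DDP/HSDP + TP. The concrete shard dims are illustrative.

```python
# One placement per mesh dim, ordered (dp_replicate, dp_shard, tp); dims illustrative.
from torch.distributed.tensor import Replicate, Shard

hsdp_tp_placements = (Replicate(), Shard(0), Shard(1))

try:
    dp_placement, tp_placement = hsdp_tp_placements  # old path: assumes exactly 2 dims
except ValueError as err:
    print(f"two-way unpack fails under HSDP + TP: {err}")

tp_placement = hsdp_tp_placements[-1]  # new path: TP is always the inner-most dim
print(tp_placement)
```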

torchtitan/experiments/simple_fsdp/tests/integration_tests.py

Lines changed: 12 additions & 13 deletions
@@ -130,19 +130,18 @@ def build_test_list():
             "hsdp",
             ngpu=4,
         ),
-        # TODO: Adds back after HSDP+TP is supported by SimpleFSDP
-        # OverrideDefinitions(
-        #     [
-        #         [
-        #             "--parallelism.data_parallel_shard_degree=2",
-        #             "--parallelism.data_parallel_replicate_degree=2",
-        #             "--parallelism.tensor_parallel_degree=2",
-        #         ]
-        #     ],
-        #     "HSDP+TP",
-        #     "hsdp+tp",
-        #     ngpu=8,
-        # ),
+        OverrideDefinitions(
+            [
+                [
+                    "--parallelism.data_parallel_shard_degree=2",
+                    "--parallelism.data_parallel_replicate_degree=2",
+                    "--parallelism.tensor_parallel_degree=2",
+                ]
+            ],
+            "HSDP+TP",
+            "hsdp+tp",
+            ngpu=8,
+        ),
         OverrideDefinitions(
             [
                 [
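As a quick sanity check on the re-enabled test's GPU requirement (a trivial sketch, assuming the world size equals the product of the parallel degrees):

```python
# Degrees overridden by the new HSDP+TP test entry above.
dp_shard, dp_replicate, tp = 2, 2, 2
assert dp_shard * dp_replicate * tp == 8  # hence ngpu=8
```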

0 commit comments