@@ -58,7 +58,7 @@ def _distribute_dtensor(
     Below are experimental enhancements to distribute a DTensor.
     This helps enable Simple FSDP + TP, in which
     inner spec/mesh is TP spec/mesh
-    outer spec/mesh is FSDP spec/mesh
+    outer spec/mesh is FSDP/DDP/HSDP spec/mesh
     The logic follows
     https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fsdp/_fsdp_param.py#L261
     """
@@ -78,24 +78,40 @@ def _distribute_dtensor(
     submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
     spanned_mesh = outer_global_mesh[submesh_names]
 
-    if placements[0].is_shard():
-        # for FSDP + TP dtensor placement
-        shard_dim = placements[0].dim
+    if len(placements) == 1:
+        assert placements[0].is_replicate() or placements[0].is_shard()
+        if placements[0].is_shard():
+            # For FSDP + TP dtensor placement
+            shard_dim = placements[0].dim
+            split_factor = inner_spec.num_shards_map[shard_dim]
+            tensor_placement = (
+                (
+                    _StridedShard(shard_dim, split_factor=split_factor)
+                    if split_factor > 1
+                    else placements[0]
+                ),
+                inner_spec.placements[0],
+            )
+        else:
+            # For DDP + TP dtensor placement
+            tensor_placement = (placements[0], inner_spec.placements[0])
+    elif len(placements) == 2:
+        assert placements[0].is_replicate() and placements[1].is_shard()
+        # For HSDP + TP dtensor placement
+        shard_dim = placements[1].dim
         split_factor = inner_spec.num_shards_map[shard_dim]
         tensor_placement = (
+            placements[0],
             (
                 _StridedShard(shard_dim, split_factor=split_factor)
                 if split_factor > 1
-                else placements[0]
+                else placements[1]
             ),
             inner_spec.placements[0],
         )
-    elif placements[0].is_replicate():
-        # for DDP + TP dtensor placement
-        tensor_placement = (placements[0], inner_spec.placements[0])
     else:
         raise ValueError(
-            f"Unsupported placement {placements[0]} for distributing DTensor {tensor}"
+            f"Unsupported placement {placements} for distributing DTensor {tensor}"
         )
 
     current_spec = DTensorSpec(
@@ -105,7 +121,7 @@ def _distribute_dtensor(
     )
     target_spec = DTensorSpec(
         mesh=outer_mesh,
-        placements=(placements[0],),
+        placements=(placements[-1],),
         tensor_meta=inner_spec.tensor_meta,
     )
     result_tensor = redistribute_local_tensor(
@@ -188,9 +204,9 @@ def replicate_compute(self, x):
         # the gradients are partial tensors that needs to perform reduction
         # (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)
 
-        # support for FSDP/DDP + TP (assuming TP shards the inner-most dim)
+        # support for FSDP/DDP/HSDP + TP (assuming TP shards the inner-most dim)
         if x._spec.mesh.mesh_dim_names[-1] == "tp":
-            dp_placement, tp_placement = x._spec.placements
+            tp_placement = x._spec.placements[-1]
             # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"]
             # after DeviceMesh supports slicing a non-root mesh
             # dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"]
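For readability, below is a minimal, self-contained sketch (not part of the diff) of the placement-combination rule the new branches implement: a 1-element DP placement maps to FSDP + TP or DDP + TP, and a 2-element `(Replicate, Shard)` placement maps to HSDP + TP, with the DP shard upgraded to a strided shard when TP already shards the same tensor dim. The `Replicate`/`Shard`/`StridedShard` classes, the `combine_placements` helper, and the dict form of `num_shards_map` are illustrative stand-ins, not the torch placement types used in the actual code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Replicate:
    """Stand-in for torch.distributed.tensor.Replicate."""


@dataclass(frozen=True)
class Shard:
    """Stand-in for torch.distributed.tensor.Shard."""
    dim: int


@dataclass(frozen=True)
class StridedShard(Shard):
    """Stand-in for _StridedShard: a Shard whose pieces interleave with TP shards."""
    split_factor: int = 1


def combine_placements(dp_placements, tp_placement, num_shards_map):
    """Mirror the FSDP/DDP/HSDP + TP branching in _distribute_dtensor (sketch only).

    dp_placements: (Shard(d),) for FSDP, (Replicate(),) for DDP,
                   (Replicate(), Shard(d)) for HSDP.
    tp_placement:  the inner (TP) placement, e.g. Shard(0).
    num_shards_map: shard count per tensor dim on the inner (TP) mesh,
                    standing in for inner_spec.num_shards_map.
    """

    def maybe_strided(shard):
        # When TP already shards the same tensor dim, the outer DP shard
        # becomes a strided shard over the TP pieces.
        split_factor = num_shards_map[shard.dim]
        return StridedShard(shard.dim, split_factor) if split_factor > 1 else shard

    if len(dp_placements) == 1:
        dp = dp_placements[0]
        if isinstance(dp, Shard):                      # FSDP + TP
            return (maybe_strided(dp), tp_placement)
        return (dp, tp_placement)                      # DDP + TP
    elif len(dp_placements) == 2:                      # HSDP + TP
        replicate, shard = dp_placements
        return (replicate, maybe_strided(shard), tp_placement)
    raise ValueError(f"Unsupported placements {dp_placements}")


# HSDP + TP, with TP sharding tensor dim 0 into 4 pieces:
print(combine_placements((Replicate(), Shard(0)), Shard(0), {0: 4}))
# (Replicate(), StridedShard(dim=0, split_factor=4), Shard(dim=0))
```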