8
8
# pyre-strict
9
9
10
10
import copy
11
- import hashlib
12
11
import logging
13
12
import time
14
13
from functools import reduce
@@ -143,33 +142,24 @@ def _merge_plans(best_plans: List[ShardingPlan]) -> ShardingPlan:
143
142
return merged_plan
144
143
145
144
146
- class EmbeddingShardingPlanner (ShardingPlanner ):
145
+ class EmbeddingPlannerBase (ShardingPlanner ):
147
146
"""
148
- Provides an optimized sharding plan for a given module with shardable parameters
149
- according to the provided sharders, topology, and constraints .
147
+ Base class for embedding sharding planners that provides common initialization
148
+ and shared functionality .
150
149
151
150
Args:
152
151
topology (Optional[Topology]): the topology of the current process group.
153
152
batch_size (Optional[int]): the batch size of the model.
154
153
enumerator (Optional[Enumerator]): the enumerator to use
155
154
storage_reservation (Optional[StorageReservation]): the storage reservation to use
156
- proposer (Optional[Union[Proposer, List[Proposer]]]): the proposer(s) to use
157
- partitioner (Optional[Partitioner]): the partitioner to use
158
- performance_model (Optional[PerfModel]): the performance model to use
159
155
stats (Optional[Union[Stats, List[Stats]]]): the stats to use
160
156
constraints (Optional[Dict[str, ParameterConstraints]]): per table constraints
161
157
for sharding.
162
158
debug (bool): whether to print debug information.
163
-
164
- Example::
165
-
166
- ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))
167
- planner = EmbeddingShardingPlanner()
168
- plan = planner.plan(
169
- module=ebc,
170
- sharders=[EmbeddingBagCollectionSharder()],
171
- )
172
-
159
+ callbacks (Optional[List[Callable[[List[ShardingOption]], List[ShardingOption]]]):
160
+ callback functions to apply to plans.
161
+ timeout_seconds (Optional[int]): timeout for planning in seconds.
162
+ heuristical_storage_reservation_percentage (float): percentage of storage to reserve for sparse archs.
173
163
"""
174
164
175
165
def __init__ (
@@ -178,16 +168,14 @@ def __init__(
178
168
batch_size : Optional [int ] = None ,
179
169
enumerator : Optional [Enumerator ] = None ,
180
170
storage_reservation : Optional [StorageReservation ] = None ,
181
- proposer : Optional [Union [Proposer , List [Proposer ]]] = None ,
182
- partitioner : Optional [Partitioner ] = None ,
183
- performance_model : Optional [PerfModel ] = None ,
184
171
stats : Optional [Union [Stats , List [Stats ]]] = None ,
185
172
constraints : Optional [Dict [str , ParameterConstraints ]] = None ,
186
173
debug : bool = True ,
187
174
callbacks : Optional [
188
175
List [Callable [[List [ShardingOption ]], List [ShardingOption ]]]
189
176
] = None ,
190
177
timeout_seconds : Optional [int ] = None ,
178
+ heuristical_storage_reservation_percentage : float = 0.15 ,
191
179
) -> None :
192
180
if topology is None :
193
181
topology = Topology (
@@ -210,7 +198,116 @@ def __init__(
210
198
self ._storage_reservation : StorageReservation = (
211
199
storage_reservation
212
200
if storage_reservation
213
- else HeuristicalStorageReservation (percentage = 0.15 )
201
+ else HeuristicalStorageReservation (
202
+ percentage = heuristical_storage_reservation_percentage
203
+ )
204
+ )
205
+
206
+ if stats is not None :
207
+ self ._stats : List [Stats ] = [stats ] if not isinstance (stats , list ) else stats
208
+ else :
209
+ self ._stats = [EmbeddingStats ()]
210
+
211
+ self ._debug = debug
212
+ self ._callbacks : List [
213
+ Callable [[List [ShardingOption ]], List [ShardingOption ]]
214
+ ] = ([] if callbacks is None else callbacks )
215
+ if timeout_seconds is not None :
216
+ assert timeout_seconds > 0 , "Timeout must be positive"
217
+ self ._timeout_seconds = timeout_seconds
218
+
219
def collective_plan(
    self,
    module: nn.Module,
    sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
    pg: Optional[dist.ProcessGroup] = None,
) -> ShardingPlan:
    """
    Run ``self.plan(...)`` on rank 0 and share the result with the group.

    Args:
        module (nn.Module): the module to shard.
        sharders (Optional[List[ModuleSharder[nn.Module]]]): the sharders to
            use for sharding; when omitted, ``get_default_sharders()`` is used.
        pg (Optional[dist.ProcessGroup]): the process group to use for
            collective operations; when omitted, the default (WORLD) group
            is used and must already be initialized.

    Returns:
        ShardingPlan: the sharding plan for the module.
    """
    # Resolve the process group first: without an explicit group we can only
    # fall back to the default one, which requires torch.distributed to be up.
    if pg is None:
        assert dist.is_initialized(), (
            "The default process group is not yet initialized. "
            "Please call torch.distributed.init_process_group() first before invoking this. "
            "If you are not within a distributed environment, use the single rank version plan() instead."
        )
        pg = none_throws(dist.GroupMember.WORLD)

    # Fall back to the library defaults only when the caller passed nothing.
    effective_sharders = get_default_sharders() if sharders is None else sharders

    # Source rank is 0; the helper is handed self.plan plus its two arguments.
    source_rank = 0
    return invoke_on_rank_and_broadcast_result(
        pg,
        source_rank,
        self.plan,
        module,
        effective_sharders,
    )
253
+
254
+
255
+ class EmbeddingShardingPlanner (EmbeddingPlannerBase ):
256
+ """
257
+ Provides an optimized sharding plan for a given module with shardable parameters
258
+ according to the provided sharders, topology, and constraints.
259
+
260
+ Args:
261
+ topology (Optional[Topology]): the topology of the current process group.
262
+ batch_size (Optional[int]): the batch size of the model.
263
+ enumerator (Optional[Enumerator]): the enumerator to use
264
+ storage_reservation (Optional[StorageReservation]): the storage reservation to use
265
+ proposer (Optional[Union[Proposer, List[Proposer]]]): the proposer(s) to use
266
+ partitioner (Optional[Partitioner]): the partitioner to use
267
+ performance_model (Optional[PerfModel]): the performance model to use
268
+ stats (Optional[Union[Stats, List[Stats]]]): the stats to use
269
+ constraints (Optional[Dict[str, ParameterConstraints]]): per table constraints
270
+ for sharding.
271
+ debug (bool): whether to print debug information.
272
+
273
+ Example::
274
+
275
+ ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))
276
+ planner = EmbeddingShardingPlanner()
277
+ plan = planner.plan(
278
+ module=ebc,
279
+ sharders=[EmbeddingBagCollectionSharder()],
280
+ )
281
+
282
+ """
283
+
284
+ def __init__ (
285
+ self ,
286
+ topology : Optional [Topology ] = None ,
287
+ batch_size : Optional [int ] = None ,
288
+ enumerator : Optional [Enumerator ] = None ,
289
+ storage_reservation : Optional [StorageReservation ] = None ,
290
+ proposer : Optional [Union [Proposer , List [Proposer ]]] = None ,
291
+ partitioner : Optional [Partitioner ] = None ,
292
+ performance_model : Optional [PerfModel ] = None ,
293
+ stats : Optional [Union [Stats , List [Stats ]]] = None ,
294
+ constraints : Optional [Dict [str , ParameterConstraints ]] = None ,
295
+ debug : bool = True ,
296
+ callbacks : Optional [
297
+ List [Callable [[List [ShardingOption ]], List [ShardingOption ]]]
298
+ ] = None ,
299
+ timeout_seconds : Optional [int ] = None ,
300
+ ) -> None :
301
+ super ().__init__ (
302
+ topology = topology ,
303
+ batch_size = batch_size ,
304
+ enumerator = enumerator ,
305
+ storage_reservation = storage_reservation ,
306
+ stats = stats ,
307
+ constraints = constraints ,
308
+ debug = debug ,
309
+ callbacks = callbacks ,
310
+ timeout_seconds = timeout_seconds ,
214
311
)
215
312
self ._partitioner : Partitioner = (
216
313
partitioner if partitioner else GreedyPerfPartitioner ()
@@ -227,24 +324,14 @@ def __init__(
227
324
UniformProposer (),
228
325
]
229
326
self ._perf_model : PerfModel = (
230
- performance_model if performance_model else NoopPerfModel (topology = topology )
327
+ performance_model
328
+ if performance_model
329
+ else NoopPerfModel (topology = self ._topology )
231
330
)
232
331
233
- if stats is not None :
234
- self ._stats : List [Stats ] = [stats ] if not isinstance (stats , list ) else stats
235
- else :
236
- self ._stats = [EmbeddingStats ()]
237
-
238
- self ._debug = debug
239
332
self ._num_proposals : int = 0
240
333
self ._num_plans : int = 0
241
334
self ._best_plan : Optional [List [ShardingOption ]] = None
242
- self ._callbacks : List [
243
- Callable [[List [ShardingOption ]], List [ShardingOption ]]
244
- ] = ([] if callbacks is None else callbacks )
245
- if timeout_seconds is not None :
246
- assert timeout_seconds > 0 , "Timeout must be positive"
247
- self ._timeout_seconds = timeout_seconds
248
335
249
336
def collective_plan (
250
337
self ,
0 commit comments