
Commit 0a39b97

Author: pytorchbot
Message: 2025-07-24 nightly release (145441b)
Parent: 3270b93

File tree: 4 files changed (+121 −42 lines)


torchrec/distributed/planner/planners.py

Lines changed: 121 additions & 34 deletions
@@ -8,7 +8,6 @@
 # pyre-strict
 
 import copy
-import hashlib
 import logging
 import time
 from functools import reduce
@@ -143,33 +142,24 @@ def _merge_plans(best_plans: List[ShardingPlan]) -> ShardingPlan:
     return merged_plan
 
 
-class EmbeddingShardingPlanner(ShardingPlanner):
+class EmbeddingPlannerBase(ShardingPlanner):
     """
-    Provides an optimized sharding plan for a given module with shardable parameters
-    according to the provided sharders, topology, and constraints.
+    Base class for embedding sharding planners that provides common initialization
+    and shared functionality.
 
     Args:
         topology (Optional[Topology]): the topology of the current process group.
         batch_size (Optional[int]): the batch size of the model.
         enumerator (Optional[Enumerator]): the enumerator to use
         storage_reservation (Optional[StorageReservation]): the storage reservation to use
-        proposer (Optional[Union[Proposer, List[Proposer]]]): the proposer(s) to use
-        partitioner (Optional[Partitioner]): the partitioner to use
-        performance_model (Optional[PerfModel]): the performance model to use
         stats (Optional[Union[Stats, List[Stats]]]): the stats to use
         constraints (Optional[Dict[str, ParameterConstraints]]): per table constraints
            for sharding.
         debug (bool): whether to print debug information.
-
-    Example::
-
-        ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))
-        planner = EmbeddingShardingPlanner()
-        plan = planner.plan(
-            module=ebc,
-            sharders=[EmbeddingBagCollectionSharder()],
-        )
-
+        callbacks (Optional[List[Callable[[List[ShardingOption]], List[ShardingOption]]]):
+            callback functions to apply to plans.
+        timeout_seconds (Optional[int]): timeout for planning in seconds.
+        heuristical_storage_reservation_percentage (float): percentage of storage to reserve for sparse archs.
     """
 
     def __init__(
@@ -178,16 +168,14 @@ def __init__(
         batch_size: Optional[int] = None,
         enumerator: Optional[Enumerator] = None,
         storage_reservation: Optional[StorageReservation] = None,
-        proposer: Optional[Union[Proposer, List[Proposer]]] = None,
-        partitioner: Optional[Partitioner] = None,
-        performance_model: Optional[PerfModel] = None,
         stats: Optional[Union[Stats, List[Stats]]] = None,
         constraints: Optional[Dict[str, ParameterConstraints]] = None,
         debug: bool = True,
         callbacks: Optional[
             List[Callable[[List[ShardingOption]], List[ShardingOption]]]
         ] = None,
         timeout_seconds: Optional[int] = None,
+        heuristical_storage_reservation_percentage: float = 0.15,
     ) -> None:
         if topology is None:
             topology = Topology(
@@ -210,7 +198,116 @@ def __init__(
         self._storage_reservation: StorageReservation = (
             storage_reservation
             if storage_reservation
-            else HeuristicalStorageReservation(percentage=0.15)
+            else HeuristicalStorageReservation(
+                percentage=heuristical_storage_reservation_percentage
+            )
+        )
+
+        if stats is not None:
+            self._stats: List[Stats] = [stats] if not isinstance(stats, list) else stats
+        else:
+            self._stats = [EmbeddingStats()]
+
+        self._debug = debug
+        self._callbacks: List[
+            Callable[[List[ShardingOption]], List[ShardingOption]]
+        ] = ([] if callbacks is None else callbacks)
+        if timeout_seconds is not None:
+            assert timeout_seconds > 0, "Timeout must be positive"
+        self._timeout_seconds = timeout_seconds
+
+    def collective_plan(
+        self,
+        module: nn.Module,
+        sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
+        pg: Optional[dist.ProcessGroup] = None,
+    ) -> ShardingPlan:
+        """
+        Call self.plan(...) on rank 0 and broadcast
+
+        Args:
+            module (nn.Module): the module to shard.
+            sharders (Optional[List[ModuleSharder[nn.Module]]]): the sharders to use for sharding
+            pg (Optional[dist.ProcessGroup]): the process group to use for collective operations
+
+        Returns:
+            ShardingPlan: the sharding plan for the module.
+        """
+        if pg is None:
+            assert dist.is_initialized(), (
+                "The default process group is not yet initialized. "
+                "Please call torch.distributed.init_process_group() first before invoking this. "
+                "If you are not within a distributed environment, use the single rank version plan() instead."
+            )
+            pg = none_throws(dist.GroupMember.WORLD)
+
+        if sharders is None:
+            sharders = get_default_sharders()
+        return invoke_on_rank_and_broadcast_result(
+            pg,
+            0,
+            self.plan,
+            module,
+            sharders,
+        )
+
+
+class EmbeddingShardingPlanner(EmbeddingPlannerBase):
+    """
+    Provides an optimized sharding plan for a given module with shardable parameters
+    according to the provided sharders, topology, and constraints.
+
+    Args:
+        topology (Optional[Topology]): the topology of the current process group.
+        batch_size (Optional[int]): the batch size of the model.
+        enumerator (Optional[Enumerator]): the enumerator to use
+        storage_reservation (Optional[StorageReservation]): the storage reservation to use
+        proposer (Optional[Union[Proposer, List[Proposer]]]): the proposer(s) to use
+        partitioner (Optional[Partitioner]): the partitioner to use
+        performance_model (Optional[PerfModel]): the performance model to use
+        stats (Optional[Union[Stats, List[Stats]]]): the stats to use
+        constraints (Optional[Dict[str, ParameterConstraints]]): per table constraints
+            for sharding.
+        debug (bool): whether to print debug information.
+
+    Example::
+
+        ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))
+        planner = EmbeddingShardingPlanner()
+        plan = planner.plan(
+            module=ebc,
+            sharders=[EmbeddingBagCollectionSharder()],
+        )
+
+    """
+
+    def __init__(
+        self,
+        topology: Optional[Topology] = None,
+        batch_size: Optional[int] = None,
+        enumerator: Optional[Enumerator] = None,
+        storage_reservation: Optional[StorageReservation] = None,
+        proposer: Optional[Union[Proposer, List[Proposer]]] = None,
+        partitioner: Optional[Partitioner] = None,
+        performance_model: Optional[PerfModel] = None,
+        stats: Optional[Union[Stats, List[Stats]]] = None,
+        constraints: Optional[Dict[str, ParameterConstraints]] = None,
+        debug: bool = True,
+        callbacks: Optional[
+            List[Callable[[List[ShardingOption]], List[ShardingOption]]]
+        ] = None,
+        timeout_seconds: Optional[int] = None,
+    ) -> None:
+        super().__init__(
+            topology=topology,
+            batch_size=batch_size,
+            enumerator=enumerator,
+            storage_reservation=storage_reservation,
+            stats=stats,
+            constraints=constraints,
+            debug=debug,
+            callbacks=callbacks,
+            timeout_seconds=timeout_seconds,
         )
         self._partitioner: Partitioner = (
             partitioner if partitioner else GreedyPerfPartitioner()
@@ -227,24 +324,14 @@ def __init__(
             UniformProposer(),
         ]
         self._perf_model: PerfModel = (
-            performance_model if performance_model else NoopPerfModel(topology=topology)
+            performance_model
+            if performance_model
+            else NoopPerfModel(topology=self._topology)
         )
 
-        if stats is not None:
-            self._stats: List[Stats] = [stats] if not isinstance(stats, list) else stats
-        else:
-            self._stats = [EmbeddingStats()]
-
-        self._debug = debug
         self._num_proposals: int = 0
         self._num_plans: int = 0
         self._best_plan: Optional[List[ShardingOption]] = None
-        self._callbacks: List[
-            Callable[[List[ShardingOption]], List[ShardingOption]]
-        ] = ([] if callbacks is None else callbacks)
-        if timeout_seconds is not None:
-            assert timeout_seconds > 0, "Timeout must be positive"
-        self._timeout_seconds = timeout_seconds
 
     def collective_plan(
         self,
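Taken together, the planners.py hunks above factor the shared constructor logic and collective_plan() out of EmbeddingShardingPlanner into the new EmbeddingPlannerBase, expose the previously hard-coded HeuristicalStorageReservation percentage as a heuristical_storage_reservation_percentage argument on the base class, and switch NoopPerfModel to use self._topology. A minimal usage sketch follows; it is not taken from this commit, assumes a working torchrec install with the import paths shown, and eb_configs is a hypothetical list of table configs (as in the docstring example).

# Minimal sketch (not from this commit): driving the planner through
# collective_plan(), which now lives on EmbeddingPlannerBase and is shared
# by all planner subclasses. Assumes a distributed environment; eb_configs
# is a hypothetical list of EmbeddingBagConfig objects.
import torch
import torch.distributed as dist

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# collective_plan() asserts that the default process group is initialized and
# falls back to dist.GroupMember.WORLD when no pg argument is given.
dist.init_process_group(backend="nccl")

ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))
planner = EmbeddingShardingPlanner()

# plan() runs on rank 0 only; the resulting ShardingPlan is broadcast to all
# ranks via invoke_on_rank_and_broadcast_result.
plan = planner.collective_plan(
    module=ebc,
    sharders=[EmbeddingBagCollectionSharder()],
)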
torchrec/inference/inference_legacy/src/BatchingQueue.cpp

Lines changed: 0 additions & 2 deletions
@@ -18,11 +18,9 @@
 
 #include <ATen/Functions.h> // @manual
 #include <ATen/core/Dict.h>
-#include <ATen/core/interned_strings.h>
 #include <ATen/record_function.h> // @manual
 #include <c10/core/Device.h>
 #include <c10/core/DeviceType.h>
-#include <c10/cuda/CUDAFunctions.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>
 #include <fmt/format.h>

torchrec/inference/inference_legacy/src/GPUExecutor.cpp

Lines changed: 0 additions & 4 deletions
@@ -16,11 +16,8 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <fmt/format.h>
 #include <folly/MPMCQueue.h>
-#include <folly/ScopeGuard.h>
-#include <folly/Synchronized.h>
 #include <folly/executors/CPUThreadPoolExecutor.h>
 #include <folly/futures/Future.h>
-#include <folly/io/IOBuf.h>
 #include <folly/io/async/Request.h>
 #include <folly/stop_watch.h>
 #include <gflags/gflags.h>
@@ -35,7 +32,6 @@
 #endif
 
 #include "ATen/cuda/CUDAEvent.h"
-#include "torchrec/inference/BatchingQueue.h"
 #include "torchrec/inference/ExceptionHandler.h"
 #include "torchrec/inference/Observer.h"
 #include "torchrec/inference/Types.h"

torchrec/inference/inference_legacy/src/ResultSplit.cpp

Lines changed: 0 additions & 2 deletions
@@ -8,12 +8,10 @@
 
 #include "torchrec/inference/ResultSplit.h"
 
-#include <c10/core/ScalarType.h>
 #include <folly/Range.h>
 #include <folly/container/Enumerate.h>
 #include <folly/io/Cursor.h>
 
-#include "ATen/Functions.h"
 #include "torchrec/inference/Types.h"
 
 namespace torchrec {
