
Commit d599766

iamzainhuda authored and facebook-github-bot committed
Dynamic 2D sparse parallel (pytorch#3177)
Summary:
Pull Request resolved: pytorch#3177

We add the ability to set the 2D parallel configuration per module (coined Dynamic 2D parallel). This means an EBC and an EC can be sharded differently along the data parallel dimension: for example, 4 replicas per EBC shard and 2 replicas per EC shard. Previously, all modules were required to use the same replication factor.

To do this, we introduce a lightweight dataclass that provides per-module configuration, allowing very granular control should the user require it:

```python
class DMPCollectionConfig:
    module: nn.Module  # this is expected to be the unsharded module
    plan: "ShardingPlan"  # sub-tree-specific sharding plan
    sharding_group_size: int
    use_inter_host_allreduce: bool = False
```

The dataclass provides the context for each module we are sharding and, if configured, creates separate process groups and sync logic for each of these modules. Usage is as follows; suppose we want to use a different 2D configuration for EmbeddingCollection:

```python
# create plan for model tables over Y world size
# create plan for EmbeddingCollection tables over X world size
ec_config = DMPCollectionConfig(
    EmbeddingCollection, embedding_collection_plan, sharding_group_size
)

model = DMPCollection(
    # pass in default args
    submodule_configs=[ec_config],
)
```

Future work includes:
- making it easier for users to create separate sharding plans per module
- per-table 2D

Reviewed By: liangbeixu

Differential Revision: D76774334

fbshipit-source-id: 27c7e0bc806d8227d784461a197cd8f1c7f6adfc
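To make the replication arithmetic concrete, here is a minimal sketch of how the replication factor falls out of the per-module `sharding_group_size`. The world size and group sizes below are hypothetical, chosen only to reproduce the 4-replica / 2-replica example above; the diff uses the same `world_size // sharding_group_size` step when remapping each module's sharding plan.

```python
# Hypothetical sizes, not taken from this PR.
world_size = 8

ebc_sharding_group_size = 2  # each EBC shard spans 2 ranks
ec_sharding_group_size = 4   # each EC shard spans 4 ranks

# Replicas per shard = world_size // sharding_group_size (the same step the
# new code computes per context when remapping each module's plan).
ebc_replicas = world_size // ebc_sharding_group_size  # 4 replicas per EBC shard
ec_replicas = world_size // ec_sharding_group_size    # 2 replicas per EC shard

print(ebc_replicas, ec_replicas)  # -> 4 2
```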
1 parent d95f247 commit d599766

File tree

11 files changed: +671 -63 lines changed


torchrec/distributed/embedding.py

Lines changed: 4 additions & 0 deletions
@@ -1617,3 +1617,7 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
     @property
     def module_type(self) -> Type[EmbeddingCollection]:
         return EmbeddingCollection
+
+    @property
+    def sharded_module_type(self) -> Type[ShardedEmbeddingCollection]:
+        return ShardedEmbeddingCollection

torchrec/distributed/embeddingbag.py

Lines changed: 4 additions & 0 deletions
@@ -1853,6 +1853,10 @@ def reshard(
     def module_type(self) -> Type[EmbeddingBagCollection]:
         return EmbeddingBagCollection
 
+    @property
+    def sharded_module_type(self) -> Type[ShardedEmbeddingBagCollection]:
+        return ShardedEmbeddingBagCollection
+
 
 class EmbeddingAwaitable(LazyAwaitable[torch.Tensor]):
     def __init__(

torchrec/distributed/fp_embeddingbag.py

Lines changed: 6 additions & 0 deletions
@@ -215,6 +215,12 @@ def shardable_parameters(
     def module_type(self) -> Type[FeatureProcessedEmbeddingBagCollection]:
         return FeatureProcessedEmbeddingBagCollection
 
+    @property
+    def sharded_module_type(
+        self,
+    ) -> Type[ShardedFeatureProcessedEmbeddingBagCollection]:
+        return ShardedFeatureProcessedEmbeddingBagCollection
+
     def sharding_types(self, compute_device_type: str) -> List[str]:
         if compute_device_type in {"mtia"}:
             return [ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value]

torchrec/distributed/itep_embeddingbag.py

Lines changed: 4 additions & 0 deletions
@@ -321,6 +321,10 @@ def shardable_parameters(
     def module_type(self) -> Type[ITEPEmbeddingBagCollection]:
         return ITEPEmbeddingBagCollection
 
+    @property
+    def sharded_module_type(self) -> Type[ShardedITEPEmbeddingBagCollection]:
+        return ShardedITEPEmbeddingBagCollection
+
     def sharding_types(self, compute_device_type: str) -> List[str]:
         types = list(SHARDING_TYPE_TO_GROUP.keys())
         return types

torchrec/distributed/mc_embedding.py

Lines changed: 4 additions & 0 deletions
@@ -144,3 +144,7 @@ def shard(
     @property
     def module_type(self) -> Type[ManagedCollisionEmbeddingCollection]:
         return ManagedCollisionEmbeddingCollection
+
+    @property
+    def sharded_module_type(self) -> Type[ShardedManagedCollisionEmbeddingCollection]:
+        return ShardedManagedCollisionEmbeddingCollection
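All five sharder diffs above follow the same pattern: each sharder now exposes the sharded counterpart of the module type it handles, so DMPCollection can resolve an unsharded module class to its sharded class and later use it as an isinstance filter when grouping embedding kernels. Below is a minimal sketch of that lookup; `MyModule`, `ShardedMyModule`, and `MySharder` are illustrative stand-ins, not TorchRec types, while the real lookup in `model_parallel.py` below goes through `self._sharder_map`.

```python
from typing import Dict, Type

import torch.nn as nn


class MyModule(nn.Module):  # hypothetical unsharded module
    pass


class ShardedMyModule(nn.Module):  # hypothetical sharded counterpart
    pass


class MySharder:
    """Illustrative sharder mirroring the new property pattern."""

    @property
    def module_type(self) -> Type[nn.Module]:
        return MyModule

    @property
    def sharded_module_type(self) -> Type[nn.Module]:
        return ShardedMyModule


# Build a {module type -> sharder} map, then resolve the sharded type so it
# can later be used as an isinstance filter when grouping embedding kernels.
sharder_map: Dict[Type[nn.Module], MySharder] = {
    s.module_type: s for s in [MySharder()]
}
assert sharder_map[MyModule].sharded_module_type is ShardedMyModule
```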

torchrec/distributed/model_parallel.py

Lines changed: 185 additions & 52 deletions
@@ -35,6 +35,8 @@
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.sharding_plan import get_default_sharders
 from torchrec.distributed.types import (
+    DMPCollectionConfig,
+    DMPCollectionContext,
     EnumerableShardingSpec,
     ModuleSharder,
     ParameterSharding,
@@ -404,7 +406,6 @@ def _shard_modules_impl(
         module: nn.Module,
         path: str = "",
     ) -> nn.Module:
-
         # pre-sharded module
         if isinstance(module, ShardedModule):
             return module
@@ -827,53 +828,150 @@ def __init__(
         data_parallel_wrapper: Optional[DataParallelWrapper] = None,
         use_inter_host_allreduce: bool = False,
         custom_all_reduce: Optional[Callable[[List[torch.Tensor]], None]] = None,
+        submodule_configs: Optional[List[DMPCollectionConfig]] = None,
     ) -> None:
         assert (
             device.type == "cuda" or device.type == "mtia"
         ), "DMPCollection only supports CUDA or MTIA"
         self._device = device
         self._pg: dist.ProcessGroup = global_pg
-        self._plan: ShardingPlan = plan
-        self._device_mesh: DeviceMesh = None  # pyre-ignore[8]
-        self._sharding_pg: dist.ProcessGroup = None  # pyre-ignore[8]
-        self._replica_pg: dist.ProcessGroup = None  # pyre-ignore[8]
         self._global_rank: int = dist.get_rank(global_pg)
         self._custom_all_reduce = custom_all_reduce
 
-        self._device_mesh, self._sharding_pg, self._replica_pg = (
-            self._create_process_groups(
+        if sharders is None:
+            sharders = get_default_sharders()
+        self._sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]] = {
+            sharder.module_type: sharder for sharder in sharders
+        }
+
+        # the args provided by the users are used for default modules
+        # default context is index 0, TODO - if cleaner way to distinguish
+        self._ctxs: List[DMPCollectionContext] = [
+            DMPCollectionContext(
+                # default context has module type None
+                module=None,  # pyre-ignore[6]
+                plan=plan,
+                sharding_group_size=sharding_group_size,
+                node_group_size=node_group_size,
+                use_inter_host_allreduce=use_inter_host_allreduce,
+            )
+        ]
+
+        if submodule_configs is not None:
+            for submodule_config in submodule_configs:
+                self._ctxs.append(
+                    DMPCollectionContext(
+                        module=submodule_config.module,
+                        plan=submodule_config.plan,
+                        sharding_group_size=submodule_config.sharding_group_size,
+                        use_inter_host_allreduce=submodule_config.use_inter_host_allreduce,
+                    )
+                )
+
+        # create process groups and remap sharding plans per module context
+        for ctx in self._ctxs:
+            (
+                device_mesh,
+                sharding_pg,
+                replica_pg,
+            ) = self._create_process_groups(
                 global_rank=self._global_rank,
                 world_size=world_size,
-                local_size=sharding_group_size,
-                use_inter_host_allreduce=use_inter_host_allreduce,
+                local_size=ctx.sharding_group_size,
+                use_inter_host_allreduce=ctx.use_inter_host_allreduce,
            )
+
+            ctx.device_mesh = device_mesh
+            ctx.sharding_pg = sharding_pg
+            ctx.replica_pg = replica_pg
+
+            step = world_size // ctx.sharding_group_size
+            self._remap_sharding_plan(
+                plan=ctx.plan,
+                rank=self._global_rank,
+                step=step,
+                sharding_group_size=ctx.sharding_group_size,
+                use_inter_host_allreduce=ctx.use_inter_host_allreduce,
+            )
+
+            if ctx.module:
+                ctx.sharded_module = self._sharder_map[ctx.module].sharded_module_type  # pyre-ignore[16]
+
+        consolidated_plan = copy.deepcopy(self._ctxs[0].plan)
+        for ctx in self._ctxs[1:]:
+            for key, val in ctx.plan.plan.items():
+                consolidated_plan.plan[key] = copy.deepcopy(val)
+
+        logger.info(
+            "[TorchRec 2D Parallel] Consolidated sharding plan:\n%s", consolidated_plan
         )
 
-        self._remap_sharding_plan(
-            plan=plan,
-            rank=self._global_rank,
-            step=world_size // sharding_group_size,
-            sharding_group_size=sharding_group_size,
-            use_inter_host_allreduce=use_inter_host_allreduce,
+        default_env = ShardingEnv2D(
+            global_pg=self._pg,
+            sharding_pg=self._ctxs[0].sharding_pg,
+            device_mesh=self._ctxs[0].device_mesh,
+            node_group_size=node_group_size,
+            use_inter_host_allreduce=self._ctxs[0].use_inter_host_allreduce,
         )
-        super().__init__(
+
+        super().__init__(  # type: ignore[misc]
             module,
-            ShardingEnv2D(
-                global_pg=self._pg,
-                sharding_pg=self._sharding_pg,
-                device_mesh=self._device_mesh,
-                node_group_size=node_group_size,
-                use_inter_host_allreduce=use_inter_host_allreduce,
-            ),
+            default_env,
             device,
-            plan,
+            consolidated_plan,
             sharders,
             init_data_parallel,
             init_parameters,
             data_parallel_wrapper,
         )
-        # post DMP init, we group sharded modules for parameter sync
-        self._modules_to_sync: List[nn.Module] = self._group_sharded_modules()
+
+        # post DMP init, we group sharded modules for parameter sync, stored in the context
+        self._group_sharded_modules(self._ctxs)
+
+    def _shard_modules_impl(
+        self,
+        module: nn.Module,
+        path: str = "",
+    ) -> nn.Module:
+
+        # pre-sharded module
+        if isinstance(module, ShardedModule):
+            return module
+
+        # shardable module
+        module_sharding_plan = self._plan.get_plan_for_module(path)
+        if module_sharding_plan:
+            env = self._env
+            sharder_key = type(module)
+
+            for ctx in self._ctxs[1:]:
+                if ctx.module == sharder_key:
+                    env = ShardingEnv2D(
+                        global_pg=self._pg,
+                        sharding_pg=ctx.sharding_pg,
+                        device_mesh=ctx.device_mesh,
+                        node_group_size=ctx.sharding_group_size,
+                        use_inter_host_allreduce=ctx.use_inter_host_allreduce,
+                    )
+                    break
+
+            module = self._sharder_map[sharder_key].shard(
+                module,
+                module_sharding_plan,
+                env,
+                self.device,
+                path,
+            )
+            return module
+
+        for name, child in module.named_children():
+            child = self._shard_modules_impl(
+                child,
+                path + "." + name if path else name,
+            )
+            setattr(module, name, child)
+
+        return module
 
     def sync(self, include_optimizer_state: bool = True) -> None:
         """
@@ -888,10 +986,24 @@ def sync(self, include_optimizer_state: bool = True) -> None:
         Args:
             include_optimizer_state (bool): Flag to include optimizer state syncing upon call
         """
-        assert self._replica_pg is not None, "replica_pg is not initialized!"
+        # we sync per context to use the right all reduce process group
+        for ctx in self._ctxs:
+            self._sync(
+                ctx.replica_pg,
+                ctx.modules_to_sync,
+                include_optimizer_state,
+            )
+
+    def _sync(
+        self,
+        replica_pg: dist.ProcessGroup,
+        modules_to_sync: List[nn.Module],
+        include_optimizer_state: bool = True,
+    ) -> None:
+        assert replica_pg is not None, "replica_pg is not initialized!"
         all_weights_by_dtype: dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)
 
-        for emb_kernel in self._modules_to_sync:
+        for emb_kernel in modules_to_sync:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             for w in emb_kernel.split_embedding_weights():
                 all_weights_by_dtype[w.dtype].append(w)
@@ -900,25 +1012,31 @@ def sync(self, include_optimizer_state: bool = True) -> None:
         if self._custom_all_reduce is None:
             opts = dist.AllreduceCoalescedOptions()
             opts.reduceOp = dist.ReduceOp.AVG
-            self._allreduce_tensors(all_weights_by_dtype, "## 2d_weight_sync ##", opts)
+            self._allreduce_tensors(
+                replica_pg, all_weights_by_dtype, "## 2d_weight_sync ##", opts
+            )
 
         if include_optimizer_state:
             optimizer_tensors_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = (
                 defaultdict(list)
             )
-            for emb_kernel in self._modules_to_sync:
+            for emb_kernel in modules_to_sync:
                 # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
                 optimizer_states = emb_kernel.get_optimizer_state()
                 for state in optimizer_states:
                     opt_tensor = state["sum"]
                     optimizer_tensors_by_dtype[opt_tensor.dtype].append(opt_tensor)
             if optimizer_tensors_by_dtype:
                 self._allreduce_tensors(
-                    optimizer_tensors_by_dtype, "## 2d_optimizer_sync ##", opts
+                    replica_pg,
+                    optimizer_tensors_by_dtype,
+                    "## 2d_optimizer_sync ##",
+                    opts,
                 )
 
     def _allreduce_tensors(
         self,
+        pg: dist.ProcessGroup,
         tensors_dict: Dict[torch.dtype, List[torch.Tensor]],
         annotation: str,
         opts: Optional[dist.AllreduceCoalescedOptions] = None,
@@ -939,7 +1057,7 @@ def _all_reduce(tensors: List[torch.Tensor]) -> None:
 
         def _all_reduce(tensors: List[torch.Tensor]) -> None:
             with record_function(annotation):
-                self._replica_pg.allreduce_coalesced(tensors, opts=opts).wait()
+                pg.allreduce_coalesced(tensors, opts=opts).wait()
 
         for tensor_list in tensors_dict.values():
             _all_reduce(tensor_list)
@@ -1073,7 +1191,38 @@ def _remap_sharding_plan(
 
     def _group_sharded_modules(
         self,
+        contexts: List[DMPCollectionContext],
+    ) -> None:
+        # Post init DMP, save the embedding kernels, with respect to contexts
+        for context in contexts[1:]:
+            context.modules_to_sync = self._group_sharded_module(context.sharded_module)  # pyre-ignore[6]
+
+        # Group leftover embedding kernels, with respect to default context
+        modules_to_skip: List[nn.Module] = [c.sharded_module for c in contexts[1:]]  # pyre-ignore[9]
+        sharded_modules: List[nn.Module] = []
+
+        def _find_sharded_modules(
+            module: nn.Module,
+        ) -> None:
+            if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen):
+                sharded_modules.append(module)
+            if not isinstance(
+                module, tuple(modules_to_skip)  # pyre-ignore[6]
+            ) and hasattr(module, "_lookups"):
+                for lookup in module._lookups:  # pyre-ignore[29]
+                    _find_sharded_modules(lookup)
+
+            for _, child in module.named_children():
+                _find_sharded_modules(child)
+
+        _find_sharded_modules(self._dmp_wrapped_module)
+        contexts[0].modules_to_sync = sharded_modules
+
+    def _group_sharded_module(
+        self,
+        sharded_module: nn.Module,
     ) -> List[nn.Module]:
+        # Traverse module and find all sharded module kernels matching the sharded module
         # Post init DMP, save the embedding kernels
         sharded_modules: List[nn.Module] = []
 
@@ -1082,36 +1231,20 @@ def _find_sharded_modules(
         ) -> None:
             if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen):
                 sharded_modules.append(module)
-            if hasattr(module, "_lookups"):
-                # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is
-                # not a function.
-                for lookup in module._lookups:
+            if isinstance(module, sharded_module):  # pyre-ignore[6]
+                for lookup in module._lookups:  # pyre-ignore[29]
                     _find_sharded_modules(lookup)
-                return
+
             for _, child in module.named_children():
                 _find_sharded_modules(child)
 
         _find_sharded_modules(self._dmp_wrapped_module)
         return sharded_modules
 
-    @property
-    def sharding_pg(self) -> dist.ProcessGroup:
-        """
-        Returns the process group used for this ranks sharding.
-        """
-        return self._sharding_pg
-
-    @property
-    def replica_pg(self) -> dist.ProcessGroup:
-        """
-        Returns the process group used for this ranks replication.
-        """
-        return self._replica_pg
-
     @property
     def device_mesh(self) -> DeviceMesh:
         """
         Returns the device mesh used for 2D parallelism.
         Contains two dimensions: "replicate" and "shard".
         """
-        return self._device_mesh
+        return self._ctxs[0].device_mesh
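Taken together, a rough usage sketch of the per-module path added here. Only `DMPCollectionConfig`, `submodule_configs`, and `sync()` come from this diff; the helper function below, its parameters, and the choice of keyword arguments (following the names visible in the diff, with the remaining constructor arguments left at their defaults) are illustrative assumptions, not part of the PR.

```python
import torch
import torch.distributed as dist
from torch import nn

from torchrec.distributed.model_parallel import DMPCollection
from torchrec.distributed.types import DMPCollectionConfig, ShardingPlan
from torchrec.modules.embedding_modules import EmbeddingCollection


def build_dynamic_2d_model(
    unsharded_model: nn.Module,
    default_plan: ShardingPlan,             # plan for the remaining (default) modules
    embedding_collection_plan: ShardingPlan,  # plan built over the EC sharding group
    default_group_size: int,
    ec_group_size: int,
    device: torch.device,
) -> DMPCollection:
    """Sketch: shard EmbeddingCollection with its own replication factor."""
    # Assumes the default process group is already initialized and both plans
    # were produced ahead of time (not shown here).
    ec_config = DMPCollectionConfig(
        module=EmbeddingCollection,
        plan=embedding_collection_plan,
        sharding_group_size=ec_group_size,
    )
    return DMPCollection(
        module=unsharded_model,
        device=device,
        plan=default_plan,
        world_size=dist.get_world_size(),
        sharding_group_size=default_group_size,
        global_pg=dist.group.WORLD,
        submodule_configs=[ec_config],
    )


# Each context (default + EmbeddingCollection) keeps its own replica process
# group, so a single call in the training loop all-reduces every group at its
# own replication factor; the cadence is up to the caller:
#     model.sync(include_optimizer_state=True)
```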
