from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.sharding_plan import get_default_sharders
from torchrec.distributed.types import (
+    DMPCollectionConfig,
+    DMPCollectionContext,
    EnumerableShardingSpec,
    ModuleSharder,
    ParameterSharding,
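The two new types imported here carry the per-submodule 2D configuration. Their real definitions live in torchrec.distributed.types; the sketch below is only an illustration of their shape, with fields inferred from how DMPCollection.__init__ uses them further down in this diff.

# Illustrative sketch only: the actual dataclasses ship in torchrec.distributed.types.
# Field names are inferred from their use in DMPCollection.__init__ below.
from dataclasses import dataclass, field
from typing import List, Optional, Type

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torchrec.distributed.types import ShardingPlan


@dataclass
class DMPCollectionConfig:
    module: Type[nn.Module]            # unsharded module type this config applies to
    plan: ShardingPlan                 # sharding plan scoped to that module
    sharding_group_size: int           # model-parallel group size for this module
    use_inter_host_allreduce: bool = False


@dataclass
class DMPCollectionContext(DMPCollectionConfig):
    node_group_size: Optional[int] = None
    # populated by DMPCollection.__init__ after process-group creation
    device_mesh: Optional[DeviceMesh] = None
    sharding_pg: Optional[dist.ProcessGroup] = None
    replica_pg: Optional[dist.ProcessGroup] = None
    sharded_module: Optional[Type[nn.Module]] = None
    modules_to_sync: List[nn.Module] = field(default_factory=list)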
@@ -404,7 +406,6 @@ def _shard_modules_impl(
        module: nn.Module,
        path: str = "",
    ) -> nn.Module:
-
        # pre-sharded module
        if isinstance(module, ShardedModule):
            return module
@@ -827,53 +828,150 @@ def __init__(
        data_parallel_wrapper: Optional[DataParallelWrapper] = None,
        use_inter_host_allreduce: bool = False,
        custom_all_reduce: Optional[Callable[[List[torch.Tensor]], None]] = None,
+        submodule_configs: Optional[List[DMPCollectionConfig]] = None,
    ) -> None:
        assert (
            device.type == "cuda" or device.type == "mtia"
        ), "DMPCollection only supports CUDA or MTIA"
        self._device = device
        self._pg: dist.ProcessGroup = global_pg
-        self._plan: ShardingPlan = plan
-        self._device_mesh: DeviceMesh = None  # pyre-ignore[8]
-        self._sharding_pg: dist.ProcessGroup = None  # pyre-ignore[8]
-        self._replica_pg: dist.ProcessGroup = None  # pyre-ignore[8]
        self._global_rank: int = dist.get_rank(global_pg)
        self._custom_all_reduce = custom_all_reduce

-        self._device_mesh, self._sharding_pg, self._replica_pg = (
-            self._create_process_groups(
+        if sharders is None:
+            sharders = get_default_sharders()
+        self._sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]] = {
+            sharder.module_type: sharder for sharder in sharders
+        }
+
+        # the args provided by the users are used for default modules
+        # default context is index 0, TODO - if cleaner way to distinguish
+        self._ctxs: List[DMPCollectionContext] = [
+            DMPCollectionContext(
+                # default context has module type None
+                module=None,  # pyre-ignore[6]
+                plan=plan,
+                sharding_group_size=sharding_group_size,
+                node_group_size=node_group_size,
+                use_inter_host_allreduce=use_inter_host_allreduce,
+            )
+        ]
+
+        if submodule_configs is not None:
+            for submodule_config in submodule_configs:
+                self._ctxs.append(
+                    DMPCollectionContext(
+                        module=submodule_config.module,
+                        plan=submodule_config.plan,
+                        sharding_group_size=submodule_config.sharding_group_size,
+                        use_inter_host_allreduce=submodule_config.use_inter_host_allreduce,
+                    )
+                )
+
+        # create process groups and remap sharding plans per module context
+        for ctx in self._ctxs:
+            (
+                device_mesh,
+                sharding_pg,
+                replica_pg,
+            ) = self._create_process_groups(
                global_rank=self._global_rank,
                world_size=world_size,
-                local_size=sharding_group_size,
-                use_inter_host_allreduce=use_inter_host_allreduce,
+                local_size=ctx.sharding_group_size,
+                use_inter_host_allreduce=ctx.use_inter_host_allreduce,
            )
+
+            ctx.device_mesh = device_mesh
+            ctx.sharding_pg = sharding_pg
+            ctx.replica_pg = replica_pg
+
+            step = world_size // ctx.sharding_group_size
+            self._remap_sharding_plan(
+                plan=ctx.plan,
+                rank=self._global_rank,
+                step=step,
+                sharding_group_size=ctx.sharding_group_size,
+                use_inter_host_allreduce=ctx.use_inter_host_allreduce,
+            )
+
+            if ctx.module:
+                ctx.sharded_module = self._sharder_map[ctx.module].sharded_module_type  # pyre-ignore[16]
+
+        consolidated_plan = copy.deepcopy(self._ctxs[0].plan)
+        for ctx in self._ctxs[1:]:
+            for key, val in ctx.plan.plan.items():
+                consolidated_plan.plan[key] = copy.deepcopy(val)
+
+        logger.info(
+            "[TorchRec 2D Parallel] Consolidated sharding plan:\n %s", consolidated_plan
        )

-        self._remap_sharding_plan(
-            plan=plan,
-            rank=self._global_rank,
-            step=world_size // sharding_group_size,
-            sharding_group_size=sharding_group_size,
-            use_inter_host_allreduce=use_inter_host_allreduce,
+        default_env = ShardingEnv2D(
+            global_pg=self._pg,
+            sharding_pg=self._ctxs[0].sharding_pg,
+            device_mesh=self._ctxs[0].device_mesh,
+            node_group_size=node_group_size,
+            use_inter_host_allreduce=self._ctxs[0].use_inter_host_allreduce,
        )
-        super().__init__(
+
+        super().__init__(  # type: ignore[misc]
            module,
-            ShardingEnv2D(
-                global_pg=self._pg,
-                sharding_pg=self._sharding_pg,
-                device_mesh=self._device_mesh,
-                node_group_size=node_group_size,
-                use_inter_host_allreduce=use_inter_host_allreduce,
-            ),
+            default_env,
            device,
-            plan,
+            consolidated_plan,
            sharders,
            init_data_parallel,
            init_parameters,
            data_parallel_wrapper,
        )
-        # post DMP init, we group sharded modules for parameter sync
-        self._modules_to_sync: List[nn.Module] = self._group_sharded_modules()
+
+        # post DMP init, we group sharded modules for parameter sync, stored in the context
+        self._group_sharded_modules(self._ctxs)
+
+    def _shard_modules_impl(
+        self,
+        module: nn.Module,
+        path: str = "",
+    ) -> nn.Module:
+
+        # pre-sharded module
+        if isinstance(module, ShardedModule):
+            return module
+
+        # shardable module
+        module_sharding_plan = self._plan.get_plan_for_module(path)
+        if module_sharding_plan:
+            env = self._env
+            sharder_key = type(module)
+
+            for ctx in self._ctxs[1:]:
+                if ctx.module == sharder_key:
+                    env = ShardingEnv2D(
+                        global_pg=self._pg,
+                        sharding_pg=ctx.sharding_pg,
+                        device_mesh=ctx.device_mesh,
+                        node_group_size=ctx.sharding_group_size,
+                        use_inter_host_allreduce=ctx.use_inter_host_allreduce,
+                    )
+                    break
+
+            module = self._sharder_map[sharder_key].shard(
+                module,
+                module_sharding_plan,
+                env,
+                self.device,
+                path,
+            )
+            return module
+
+        for name, child in module.named_children():
+            child = self._shard_modules_impl(
+                child,
+                path + "." + name if path else name,
+            )
+            setattr(module, name, child)
+
+        return module

    def sync(self, include_optimizer_state: bool = True) -> None:
        """
@@ -888,10 +986,24 @@ def sync(self, include_optimizer_state: bool = True) -> None:
        Args:
            include_optimizer_state (bool): Flag to include optimizer state syncing upon call
        """
-        assert self._replica_pg is not None, "replica_pg is not initialized!"
+        # we sync per context to use the right all reduce process group
+        for ctx in self._ctxs:
+            self._sync(
+                ctx.replica_pg,
+                ctx.modules_to_sync,
+                include_optimizer_state,
+            )
+
+    def _sync(
+        self,
+        replica_pg: dist.ProcessGroup,
+        modules_to_sync: List[nn.Module],
+        include_optimizer_state: bool = True,
+    ) -> None:
+        assert replica_pg is not None, "replica_pg is not initialized!"
        all_weights_by_dtype: dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)

-        for emb_kernel in self._modules_to_sync:
+        for emb_kernel in modules_to_sync:
            # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
            for w in emb_kernel.split_embedding_weights():
                all_weights_by_dtype[w.dtype].append(w)
@@ -900,25 +1012,31 @@ def sync(self, include_optimizer_state: bool = True) -> None:
        if self._custom_all_reduce is None:
            opts = dist.AllreduceCoalescedOptions()
            opts.reduceOp = dist.ReduceOp.AVG
-            self._allreduce_tensors(all_weights_by_dtype, "## 2d_weight_sync ##", opts)
+            self._allreduce_tensors(
+                replica_pg, all_weights_by_dtype, "## 2d_weight_sync ##", opts
+            )

        if include_optimizer_state:
            optimizer_tensors_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = (
                defaultdict(list)
            )
-            for emb_kernel in self._modules_to_sync:
+            for emb_kernel in modules_to_sync:
                # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
                optimizer_states = emb_kernel.get_optimizer_state()
                for state in optimizer_states:
                    opt_tensor = state["sum"]
                    optimizer_tensors_by_dtype[opt_tensor.dtype].append(opt_tensor)
            if optimizer_tensors_by_dtype:
                self._allreduce_tensors(
-                    optimizer_tensors_by_dtype, "## 2d_optimizer_sync ##", opts
+                    replica_pg,
+                    optimizer_tensors_by_dtype,
+                    "## 2d_optimizer_sync ##",
+                    opts,
                )

    def _allreduce_tensors(
        self,
+        pg: dist.ProcessGroup,
        tensors_dict: Dict[torch.dtype, List[torch.Tensor]],
        annotation: str,
        opts: Optional[dist.AllreduceCoalescedOptions] = None,
@@ -939,7 +1057,7 @@ def _all_reduce(tensors: List[torch.Tensor]) -> None:

        def _all_reduce(tensors: List[torch.Tensor]) -> None:
            with record_function(annotation):
-                self._replica_pg.allreduce_coalesced(tensors, opts=opts).wait()
+                pg.allreduce_coalesced(tensors, opts=opts).wait()

        for tensor_list in tensors_dict.values():
            _all_reduce(tensor_list)
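The averaging machinery above is driven by the public sync() call. A hedged sketch of where it typically sits in a training loop follows; `model_2d`, `optimizer`, and `batches` are placeholders, and the sync cadence is an application choice.

# Hedged sketch (not part of this diff): periodically average embedding weights
# and optimizer state across replica groups, one coalesced all-reduce per context.
SYNC_EVERY_N_STEPS = 100  # assumption: cadence chosen by the application

for step, batch in enumerate(batches):
    loss = model_2d(batch).sum()      # placeholder forward / loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if step % SYNC_EVERY_N_STEPS == 0:
        model_2d.sync(include_optimizer_state=True)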
@@ -1073,7 +1191,38 @@ def _remap_sharding_plan(

    def _group_sharded_modules(
        self,
+        contexts: List[DMPCollectionContext],
+    ) -> None:
+        # Post init DMP, save the embedding kernels, with respect to contexts
+        for context in contexts[1:]:
+            context.modules_to_sync = self._group_sharded_module(context.sharded_module)  # pyre-ignore[6]
+
+        # Group leftover embedding kernels, with respect to default context
+        modules_to_skip: List[nn.Module] = [c.sharded_module for c in contexts[1:]]  # pyre-ignore[9]
+        sharded_modules: List[nn.Module] = []
+
+        def _find_sharded_modules(
+            module: nn.Module,
+        ) -> None:
+            if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen):
+                sharded_modules.append(module)
+            if not isinstance(
+                module, tuple(modules_to_skip)  # pyre-ignore[6]
+            ) and hasattr(module, "_lookups"):
+                for lookup in module._lookups:  # pyre-ignore[29]
+                    _find_sharded_modules(lookup)
+
+            for _, child in module.named_children():
+                _find_sharded_modules(child)
+
+        _find_sharded_modules(self._dmp_wrapped_module)
+        contexts[0].modules_to_sync = sharded_modules
+
+    def _group_sharded_module(
+        self,
+        sharded_module: nn.Module,
    ) -> List[nn.Module]:
+        # Traverse module and find all sharded module kernels matching the sharded module
        # Post init DMP, save the embedding kernels
        sharded_modules: List[nn.Module] = []
@@ -1082,36 +1231,20 @@ def _find_sharded_modules(
        ) -> None:
            if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen):
                sharded_modules.append(module)
-            if hasattr(module, "_lookups"):
-                # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is
-                # not a function.
-                for lookup in module._lookups:
+            if isinstance(module, sharded_module):  # pyre-ignore[6]
+                for lookup in module._lookups:  # pyre-ignore[29]
                    _find_sharded_modules(lookup)
-                return
+
            for _, child in module.named_children():
                _find_sharded_modules(child)

        _find_sharded_modules(self._dmp_wrapped_module)
        return sharded_modules

-    @property
-    def sharding_pg(self) -> dist.ProcessGroup:
-        """
-        Returns the process group used for this ranks sharding.
-        """
-        return self._sharding_pg
-
-    @property
-    def replica_pg(self) -> dist.ProcessGroup:
-        """
-        Returns the process group used for this ranks replication.
-        """
-        return self._replica_pg
-
    @property
    def device_mesh(self) -> DeviceMesh:
        """
        Returns the device mesh used for 2D parallelism.
        Contains two dimensions: "replicate" and "shard".
        """
-        return self._device_mesh
+        return self._ctxs[0].device_mesh
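A small hedged example of reading the default context's mesh through this property; `model_2d` is the wrapper built earlier, and the dimension names come from the docstring above.

# Hedged sketch (not part of this diff): inspect this rank's position in the 2D mesh.
mesh = model_2d.device_mesh
print(mesh.mesh_dim_names)                       # expected: ("replicate", "shard")
shard_idx = mesh.get_local_rank("shard")         # position within this rank's sharding group
replica_idx = mesh.get_local_rank("replicate")   # which replica group this rank sits in
print(f"replica={replica_idx}, shard={shard_idx}")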