@@ -68,6 +68,22 @@ def compute_kernels(
68
68
return [EmbeddingComputeKernel .DENSE .value ]
69
69
70
70
71
+ class CWSharder (EmbeddingBagCollectionSharder , ModuleSharder [nn .Module ]):
72
+ """
73
+ Column-wise sharder for benchmarking.
74
+ """
75
+
76
+ def sharding_types (self , compute_device_type : str ) -> List [str ]:
77
+ # compute_device_type is required by the interface
78
+ return [ShardingType .COLUMN_WISE .value ]
79
+
80
+ def compute_kernels (
81
+ self , sharding_type : str , compute_device_type : str
82
+ ) -> List [str ]:
83
+ # sharding_type and compute_device_type are required by the interface
84
+ return [EmbeddingComputeKernel .DENSE .value ]
85
+
86
+
71
87
def build_model_and_enumerator (
72
88
world_size : int ,
73
89
num_tables : int ,
@@ -290,20 +306,23 @@ def main() -> None:
290
306
parser .add_argument (
291
307
"--sharder" ,
292
308
type = str ,
293
- choices = ["tw" , "rw" , "both " ],
309
+ choices = ["tw" , "rw" , "cw" , "all " ],
294
310
default = "tw" ,
295
- help = "Sharder type to use: table-wise (tw), row-wise (rw), or both " ,
311
+ help = "Sharder type to use: table-wise (tw), row-wise (rw), column-wise (cw), or all " ,
296
312
)
297
313
298
314
args = parser .parse_args ()
299
315
300
316
# Run benchmark with specified sharder(s)
301
- if args .sharder == "tw" or args .sharder == "both " :
317
+ if args .sharder == "tw" or args .sharder == "all " :
302
318
benchmark_enumerator_comprehensive (TWSharder )
303
319
304
- if args .sharder == "rw" or args .sharder == "both " :
320
+ if args .sharder == "rw" or args .sharder == "all " :
305
321
benchmark_enumerator_comprehensive (RWSharder )
306
322
323
+ if args .sharder == "cw" or args .sharder == "all" :
324
+ benchmark_enumerator_comprehensive (CWSharder )
325
+
307
326
308
327
if __name__ == "__main__" :
309
328
main ()
0 commit comments