@@ -76,6 +76,22 @@ def compute_kernels(
76
76
return [EmbeddingComputeKernel .DENSE .value ]
77
77
78
78
79
+ class CWSharder (EmbeddingBagCollectionSharder , ModuleSharder [nn .Module ]):
80
+ """
81
+ Column-wise sharder for benchmarking.
82
+ """
83
+
84
+ def sharding_types (self , compute_device_type : str ) -> List [str ]:
85
+ # compute_device_type is required by the interface
86
+ return [ShardingType .COLUMN_WISE .value ]
87
+
88
+ def compute_kernels (
89
+ self , sharding_type : str , compute_device_type : str
90
+ ) -> List [str ]:
91
+ # sharding_type and compute_device_type are required by the interface
92
+ return [EmbeddingComputeKernel .DENSE .value ]
93
+
94
+
79
95
def build_model_and_enumerator (
80
96
world_size : int ,
81
97
num_tables : int ,
@@ -292,21 +308,24 @@ def main() -> None:
292
308
parser .add_argument (
293
309
"--sharder" ,
294
310
type = str ,
295
- choices = ["tw" , "rw" , "both " ],
311
+ choices = ["tw" , "rw" , "cw" , "all " ],
296
312
default = "tw" ,
297
- help = "Sharder type to use: table-wise (tw), row-wise (rw), or both " ,
313
+ help = "Sharder type to use: table-wise (tw), row-wise (rw), column-wise (cw), or all " ,
298
314
)
299
315
logger .warning ("Running planner enumerator benchmarks..." )
300
316
301
317
args = parser .parse_args ()
302
318
303
319
# Run benchmark with specified sharder(s)
304
- if args .sharder == "tw" or args .sharder == "both " :
320
+ if args .sharder == "tw" or args .sharder == "all " :
305
321
benchmark_enumerator_comprehensive (TWSharder )
306
322
307
- if args .sharder == "rw" or args .sharder == "both " :
323
+ if args .sharder == "rw" or args .sharder == "all " :
308
324
benchmark_enumerator_comprehensive (RWSharder )
309
325
326
+ if args .sharder == "cw" or args .sharder == "all" :
327
+ benchmark_enumerator_comprehensive (CWSharder )
328
+
310
329
311
330
if __name__ == "__main__" :
312
331
main ()
0 commit comments