Skip to content

Commit a46e819

Browse files
aliafzalfacebook-github-bot
authored andcommitted
Added support for column_wise benchmark in enumerator benchmark (#3230)
Summary: Added option to include column_wise enumerator benchmark Reviewed By: SSYernar Differential Revision: D78856316
1 parent 780a43d commit a46e819

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

torchrec/distributed/planner/tests/benchmark.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,22 @@ def compute_kernels(
7676
return [EmbeddingComputeKernel.DENSE.value]
7777

7878

79+
class CWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]):
80+
"""
81+
Column-wise sharder for benchmarking.
82+
"""
83+
84+
def sharding_types(self, compute_device_type: str) -> List[str]:
85+
# compute_device_type is required by the interface
86+
return [ShardingType.COLUMN_WISE.value]
87+
88+
def compute_kernels(
89+
self, sharding_type: str, compute_device_type: str
90+
) -> List[str]:
91+
# sharding_type and compute_device_type are required by the interface
92+
return [EmbeddingComputeKernel.DENSE.value]
93+
94+
7995
def build_model_and_enumerator(
8096
world_size: int,
8197
num_tables: int,
@@ -292,21 +308,24 @@ def main() -> None:
292308
parser.add_argument(
293309
"--sharder",
294310
type=str,
295-
choices=["tw", "rw", "both"],
311+
choices=["tw", "rw", "cw", "all"],
296312
default="tw",
297-
help="Sharder type to use: table-wise (tw), row-wise (rw), or both",
313+
help="Sharder type to use: table-wise (tw), row-wise (rw), column-wise (cw), or all",
298314
)
299315
logger.warning("Running planner enumerator benchmarks...")
300316

301317
args = parser.parse_args()
302318

303319
# Run benchmark with specified sharder(s)
304-
if args.sharder == "tw" or args.sharder == "both":
320+
if args.sharder == "tw" or args.sharder == "all":
305321
benchmark_enumerator_comprehensive(TWSharder)
306322

307-
if args.sharder == "rw" or args.sharder == "both":
323+
if args.sharder == "rw" or args.sharder == "all":
308324
benchmark_enumerator_comprehensive(RWSharder)
309325

326+
if args.sharder == "cw" or args.sharder == "all":
327+
benchmark_enumerator_comprehensive(CWSharder)
328+
310329

311330
if __name__ == "__main__":
312331
main()

0 commit comments

Comments
 (0)