Skip to content

Commit e91d9aa

Browse files
aliafzalfacebook-github-bot
authored andcommitted
Added support for column_wise benchmark in enumerator benchmark (pytorch#3230)
Summary: Added option to include column_wise enumerator benchmark Reviewed By: SSYernar Differential Revision: D78856316
1 parent ac4407e commit e91d9aa

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

torchrec/distributed/planner/tests/benchmark.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,22 @@ def compute_kernels(
6868
return [EmbeddingComputeKernel.DENSE.value]
6969

7070

71+
class CWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]):
72+
"""
73+
Column-wise sharder for benchmarking.
74+
"""
75+
76+
def sharding_types(self, compute_device_type: str) -> List[str]:
77+
# compute_device_type is required by the interface
78+
return [ShardingType.COLUMN_WISE.value]
79+
80+
def compute_kernels(
81+
self, sharding_type: str, compute_device_type: str
82+
) -> List[str]:
83+
# sharding_type and compute_device_type are required by the interface
84+
return [EmbeddingComputeKernel.DENSE.value]
85+
86+
7187
def build_model_and_enumerator(
7288
world_size: int,
7389
num_tables: int,
@@ -290,20 +306,23 @@ def main() -> None:
290306
parser.add_argument(
291307
"--sharder",
292308
type=str,
293-
choices=["tw", "rw", "both"],
309+
choices=["tw", "rw", "cw", "all"],
294310
default="tw",
295-
help="Sharder type to use: table-wise (tw), row-wise (rw), or both",
311+
help="Sharder type to use: table-wise (tw), row-wise (rw), column-wise (cw), or all",
296312
)
297313

298314
args = parser.parse_args()
299315

300316
# Run benchmark with specified sharder(s)
301-
if args.sharder == "tw" or args.sharder == "both":
317+
if args.sharder == "tw" or args.sharder == "all":
302318
benchmark_enumerator_comprehensive(TWSharder)
303319

304-
if args.sharder == "rw" or args.sharder == "both":
320+
if args.sharder == "rw" or args.sharder == "all":
305321
benchmark_enumerator_comprehensive(RWSharder)
306322

323+
if args.sharder == "cw" or args.sharder == "all":
324+
benchmark_enumerator_comprehensive(CWSharder)
325+
307326

308327
if __name__ == "__main__":
309328
main()

0 commit comments

Comments
 (0)