Skip to content

Commit 8aa1630

Browse files
aliafzalfacebook-github-bot
authored andcommitted
Enumerator benchmark enhancement
Summary: 1. Rewrote enumerator benchmark from unittest to standalone script with comprehensive matrix testing (world sizes 16-2048 × tables 200-6400). 2. Added sharder type options, millisecond timing, and memory tracking. Differential Revision: D78855814
1 parent fc08e73 commit 8aa1630

File tree

2 files changed

+270
-48
lines changed

2 files changed

+270
-48
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ torchx
1515
tqdm
1616
usort
1717
parameterized
18+
tracemalloc
1819
PyYAML
1920

2021
# for tests

torchrec/distributed/planner/tests/benchmark.py

Lines changed: 269 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,23 @@
77

88
# pyre-strict
99

10-
"""Stress tests for planner to find problematic scaling behavior."""
10+
"""
11+
Comprehensive benchmarks for planner enumerator to analyze performance and scaling behavior.
1112
12-
import time
13-
import unittest
13+
This module provides benchmarks for the EmbeddingEnumerator component, including:
14+
- Performance with varying table counts
15+
- Performance with varying world sizes
16+
- Memory usage tracking
17+
"""
1418

15-
from typing import List, Tuple
19+
import argparse
20+
import gc
21+
import logging
22+
import time
23+
import tracemalloc
24+
from typing import Dict, List, Tuple, Type
1625

1726
from torch import nn
18-
1927
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
2028
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
2129
from torchrec.distributed.planner.constants import BATCH_SIZE
@@ -25,64 +33,277 @@
2533
from torchrec.distributed.types import ModuleSharder, ShardingType
2634
from torchrec.modules.embedding_configs import EmbeddingBagConfig
2735

36+
logger: logging.Logger = logging.getLogger(__name__)
37+
2838

2939
class TWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]):
40+
"""
41+
Table-wise sharder for benchmarking.
42+
"""
43+
3044
def sharding_types(self, compute_device_type: str) -> List[str]:
45+
# compute_device_type is required by the interface
3146
return [ShardingType.TABLE_WISE.value]
3247

3348
def compute_kernels(
3449
self, sharding_type: str, compute_device_type: str
3550
) -> List[str]:
51+
# sharding_type and compute_device_type are required by the interface
3652
return [EmbeddingComputeKernel.DENSE.value]
3753

3854

39-
class TestEnumeratorBenchmark(unittest.TestCase):
40-
@staticmethod
41-
def build(
42-
world_size: int, num_tables: int
43-
) -> Tuple[EmbeddingEnumerator, nn.Module]:
44-
compute_device = "cuda"
45-
topology = Topology(
46-
world_size=world_size, local_world_size=8, compute_device=compute_device
55+
class RWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]):
56+
"""
57+
Row-wise sharder for benchmarking.
58+
"""
59+
60+
def sharding_types(self, compute_device_type: str) -> List[str]:
61+
# compute_device_type is required by the interface
62+
return [ShardingType.ROW_WISE.value]
63+
64+
def compute_kernels(
65+
self, sharding_type: str, compute_device_type: str
66+
) -> List[str]:
67+
# sharding_type and compute_device_type are required by the interface
68+
return [EmbeddingComputeKernel.DENSE.value]
69+
70+
71+
def build_model_and_enumerator(
72+
world_size: int,
73+
num_tables: int,
74+
embedding_dim: int = 128,
75+
local_world_size: int = 8,
76+
compute_device: str = "cuda",
77+
) -> Tuple[EmbeddingEnumerator, nn.Module]:
78+
"""
79+
Build an enumerator and model for benchmarking.
80+
81+
Args:
82+
world_size: Number of devices in the topology
83+
num_tables: Number of embedding tables in the model
84+
embedding_dim: Dimension of each embedding vector
85+
local_world_size: Number of devices per node
86+
compute_device: Device type ("cuda" or "cpu")
87+
88+
Returns:
89+
Tuple of (enumerator, model)
90+
"""
91+
topology = Topology(
92+
world_size=world_size,
93+
local_world_size=local_world_size,
94+
compute_device=compute_device,
95+
)
96+
tables = [
97+
EmbeddingBagConfig(
98+
num_embeddings=100 + i,
99+
embedding_dim=embedding_dim,
100+
name="table_" + str(i),
101+
feature_names=["feature_" + str(i)],
47102
)
48-
tables = [
49-
EmbeddingBagConfig(
50-
num_embeddings=100 + i,
51-
embedding_dim=128,
52-
name="table_" + str(i),
53-
feature_names=["feature_" + str(i)],
54-
)
55-
for i in range(num_tables)
56-
]
57-
model = TestSparseNN(tables=tables, weighted_tables=[])
58-
enumerator = EmbeddingEnumerator(topology=topology, batch_size=BATCH_SIZE)
59-
return enumerator, model
60-
61-
def measure(self, world_size: int, num_tables: int) -> float:
62-
enumerator, model = TestEnumeratorBenchmark.build(world_size, num_tables)
63-
64-
start_time = time.time()
65-
sharding_options = enumerator.enumerate(module=model, sharders=[TWSharder()])
66-
end_time = time.time()
67-
68-
self.assertEqual(len(sharding_options), num_tables)
69-
return end_time - start_time
70-
71-
def test_benchmark(self) -> None:
72-
tests = [(2048, d) for d in [100, 200, 400, 800, 1600, 3200, 6400]]
73-
print("\nEnumerator benchmark:")
74-
for world_size, num_tables in tests:
75-
t = self.measure(world_size, num_tables)
76-
print(
77-
f"world_size={world_size:8} num_tables={num_tables:8} enumerate={t:4.2f}s"
78-
)
103+
for i in range(num_tables)
104+
]
105+
model = TestSparseNN(tables=tables, weighted_tables=[])
106+
enumerator = EmbeddingEnumerator(topology=topology, batch_size=BATCH_SIZE)
107+
return enumerator, model
108+
109+
110+
def measure_memory_and_time(
111+
world_size: int,
112+
num_tables: int,
113+
embedding_dim: int = 128,
114+
sharder_class: Type[ModuleSharder[nn.Module]] = TWSharder,
115+
) -> Dict[str, float]:
116+
"""
117+
Measure both time and memory usage for the enumerate operation.
118+
119+
Args:
120+
world_size: Number of devices in the topology
121+
num_tables: Number of embedding tables in the model
122+
embedding_dim: Dimension of each embedding vector
123+
sharder_class: The sharder class to use
124+
125+
Returns:
126+
Dictionary with time and memory metrics
127+
"""
128+
# Force garbage collection before measurement
129+
gc.collect()
130+
131+
# Build model and enumerator
132+
enumerator, model = build_model_and_enumerator(
133+
world_size=world_size,
134+
num_tables=num_tables,
135+
embedding_dim=embedding_dim,
136+
)
137+
138+
# Force garbage collection again after model building
139+
gc.collect()
140+
141+
# Start memory tracking
142+
tracemalloc.start()
143+
144+
# Measure time
145+
start_time = time.time()
146+
sharding_options = enumerator.enumerate(module=model, sharders=[sharder_class()])
147+
end_time = time.time()
148+
elapsed_time = end_time - start_time
149+
150+
# Get memory usage
151+
current, peak = tracemalloc.get_traced_memory()
152+
# Convert to MB
153+
peak_mb = peak / (1024 * 1024)
154+
155+
# Stop memory tracking
156+
tracemalloc.stop()
157+
158+
# Verify the result
159+
assert len(sharding_options) == num_tables, "Unexpected number of sharding options"
160+
161+
# Convert time to milliseconds
162+
elapsed_time_ms = elapsed_time * 1000
163+
164+
return {
165+
"time_ms": elapsed_time_ms,
166+
"memory_mb": peak_mb,
167+
"options_count": len(sharding_options),
168+
}
169+
170+
171+
def benchmark_enumerator_comprehensive(
172+
sharder_class: Type[ModuleSharder[nn.Module]] = TWSharder,
173+
) -> None:
174+
"""
175+
Comprehensive benchmark testing all combinations of world sizes and table counts.
176+
Tests world sizes from 16 to 2048 and table counts from 200 to 6400.
177+
178+
Args:
179+
sharder_class: The sharder class to use for benchmarking
180+
"""
181+
# Define the ranges for world sizes and table counts
182+
world_sizes = [16, 32, 64, 128, 256, 512, 1024, 2048]
183+
table_counts = [200, 400, 800, 1600, 3200, 6400]
184+
185+
# Create a matrix to store results
186+
results = {}
187+
188+
sharder_name = sharder_class.__name__
189+
logger.info(f"Running comprehensive enumerator benchmark with {sharder_name}...")
190+
logger.info(
191+
f"Testing {len(world_sizes)} world sizes × {len(table_counts)} table counts = {len(world_sizes) * len(table_counts)} combinations"
192+
)
193+
194+
# Track progress
195+
total_combinations = len(world_sizes) * len(table_counts)
196+
completed = 0
197+
198+
# Run benchmarks for all combinations
199+
for world_size in world_sizes:
200+
logger.info(f"Starting benchmarks for world_size={world_size}...")
201+
results[world_size] = {}
202+
world_size_start_time = time.time()
203+
204+
# Run all table counts for this world size
205+
for num_tables in table_counts:
206+
try:
207+
metrics = measure_memory_and_time(
208+
world_size=world_size,
209+
num_tables=num_tables,
210+
sharder_class=sharder_class,
211+
)
212+
results[world_size][num_tables] = metrics
213+
except Exception as e:
214+
results[world_size][num_tables] = {
215+
"time_ms": -1,
216+
"memory_mb": -1,
217+
"options_count": -1,
218+
"error": str(e),
219+
}
220+
221+
completed += 1
222+
223+
# Log completion of all table counts for this world size
224+
world_size_elapsed = time.time() - world_size_start_time
225+
logger.info(
226+
f"Completed world_size={world_size} ({len(table_counts)} table counts) "
227+
f"in {world_size_elapsed:.2f}s ({completed}/{total_combinations} combinations done)"
228+
)
229+
230+
# Print intermediate results for this world size
231+
logger.info(f"Results for world_size={world_size}:")
232+
logger.info(f"{'Table Count':<12} {'Time (ms)':<10} {'Memory (MB)':<12}")
233+
logger.info("-" * 35)
234+
for num_tables in table_counts:
235+
if results[world_size][num_tables].get("error"):
236+
logger.info(f"{num_tables:<12} {'ERROR':<10} {'ERROR':<12}")
237+
else:
238+
logger.info(
239+
f"{num_tables:<12} "
240+
f"{results[world_size][num_tables]['time_ms']:<10.2f} "
241+
f"{results[world_size][num_tables]['memory_mb']:<12.2f}"
242+
)
243+
244+
# Print summary table after all tests are complete
245+
logger.info(f"\nComprehensive Enumerator Benchmark with {sharder_name} - Results:")
246+
247+
# Print header row with table counts
248+
header = "World Size"
249+
for num_tables in table_counts:
250+
header += f" | {num_tables:>8}"
251+
logger.info(header)
252+
logger.info("-" * len(header))
253+
254+
# Print time results
255+
logger.info("\nTime (milliseconds):")
256+
for world_size in world_sizes:
257+
row = f"{world_size:>10}"
258+
for num_tables in table_counts:
259+
if results[world_size][num_tables].get("error"):
260+
row += f" | {'ERROR':>8}"
261+
else:
262+
row += f" | {results[world_size][num_tables]['time_ms']:>8.2f}"
263+
logger.info(row)
264+
265+
# Print memory results
266+
logger.info("\nMemory (MB):")
267+
for world_size in world_sizes:
268+
row = f"{world_size:>10}"
269+
for num_tables in table_counts:
270+
if results[world_size][num_tables].get("error"):
271+
row += f" | {'ERROR':>8}"
272+
else:
273+
row += f" | {results[world_size][num_tables]['memory_mb']:>8.2f}"
274+
logger.info(row)
79275

80276

81277
def main() -> None:
82-
unittest.main()
278+
"""
279+
Main entry point for the benchmark script.
280+
281+
Provides a command-line interface to run specific benchmarks.
282+
"""
283+
# Configure logging
284+
logging.basicConfig(
285+
level=logging.INFO,
286+
format="%(asctime)s [%(levelname)s] %(message)s",
287+
datefmt="%Y-%m-%d %H:%M:%S",
288+
)
289+
parser = argparse.ArgumentParser(description="Run planner enumerator benchmarks")
290+
parser.add_argument(
291+
"--sharder",
292+
type=str,
293+
choices=["tw", "rw", "both"],
294+
default="tw",
295+
help="Sharder type to use: table-wise (tw), row-wise (rw), or both",
296+
)
297+
298+
args = parser.parse_args()
299+
300+
# Run benchmark with specified sharder(s)
301+
if args.sharder == "tw" or args.sharder == "both":
302+
benchmark_enumerator_comprehensive(TWSharder)
303+
304+
if args.sharder == "rw" or args.sharder == "both":
305+
benchmark_enumerator_comprehensive(RWSharder)
83306

84307

85-
# This is structured as a unitttest like file so you can use its built-in command
86-
# line argument parsing to control which benchmarks to run, e.g. "-k Enumerator"
87308
if __name__ == "__main__":
88-
main() # pragma: no cover
309+
main()

0 commit comments

Comments
 (0)