Skip to content

Commit 9f45df7

Browse files
aliafzal authored and facebook-github-bot committed
Enumerator benchmark enhancement (#3229)
Summary: 1. Rewrote enumerator benchmark from unittest to standalone script with comprehensive matrix testing (world sizes 16-2048 × tables 200-6400). 2. Added sharder type options, millisecond timing, and memory tracking. Reviewed By: SSYernar Differential Revision: D78855814
1 parent 89e7771 commit 9f45df7

File tree

1 file changed

+272
-48
lines changed

1 file changed

+272
-48
lines changed

torchrec/distributed/planner/tests/benchmark.py

Lines changed: 272 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,23 @@
77

88
# pyre-strict
99

10-
"""Stress tests for planner to find problematic scaling behavior."""
10+
"""
11+
Comprehensive benchmarks for planner enumerator to analyze performance and scaling behavior.
1112
12-
import time
13-
import unittest
13+
This module provides benchmarks for the EmbeddingEnumerator component, including:
14+
- Performance with varying table counts
15+
- Performance with varying world sizes
16+
- Memory usage tracking
17+
"""
1418

15-
from typing import List, Tuple
19+
import argparse
20+
import gc
21+
import logging
22+
import resource
23+
import time
24+
from typing import Dict, List, Tuple, Type
1625

1726
from torch import nn
18-
1927
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
2028
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
2129
from torchrec.distributed.planner.constants import BATCH_SIZE
@@ -25,64 +33,280 @@
2533
from torchrec.distributed.types import ModuleSharder, ShardingType
2634
from torchrec.modules.embedding_configs import EmbeddingBagConfig
2735

36+
# Install a root handler that emits INFO-and-above records with a timestamp and
# level prefix, so benchmark progress shows up when the script is run directly.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Module logger; its level is pinned to INFO explicitly rather than left to
# whatever the root logger happens to be configured with.
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
45+
2846

2947
class TWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]):
    """
    Benchmark sharder that offers only table-wise sharding.

    Both overrides ignore their arguments and return a single fixed choice, so
    each query made with this sharder yields exactly one sharding type and one
    compute kernel.
    """

    def sharding_types(self, compute_device_type: str) -> List[str]:
        # The parameter is part of the sharder interface; the answer does not
        # depend on it.
        table_wise = ShardingType.TABLE_WISE.value
        return [table_wise]

    def compute_kernels(
        self, sharding_type: str, compute_device_type: str
    ) -> List[str]:
        # Both parameters are part of the sharder interface; the dense kernel
        # is returned regardless of their values.
        dense_kernel = EmbeddingComputeKernel.DENSE.value
        return [dense_kernel]
61+
62+
63+
class RWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]):
    """
    Benchmark sharder that offers only row-wise sharding.

    Both overrides ignore their arguments and return a single fixed choice, so
    each query made with this sharder yields exactly one sharding type and one
    compute kernel.
    """

    def sharding_types(self, compute_device_type: str) -> List[str]:
        # The parameter is part of the sharder interface; the answer does not
        # depend on it.
        row_wise = ShardingType.ROW_WISE.value
        return [row_wise]

    def compute_kernels(
        self, sharding_type: str, compute_device_type: str
    ) -> List[str]:
        # Both parameters are part of the sharder interface; the dense kernel
        # is returned regardless of their values.
        dense_kernel = EmbeddingComputeKernel.DENSE.value
        return [dense_kernel]
3777

3878

39-
class TestEnumeratorBenchmark(unittest.TestCase):
40-
@staticmethod
41-
def build(
42-
world_size: int, num_tables: int
43-
) -> Tuple[EmbeddingEnumerator, nn.Module]:
44-
compute_device = "cuda"
45-
topology = Topology(
46-
world_size=world_size, local_world_size=8, compute_device=compute_device
79+
def build_model_and_enumerator(
    world_size: int,
    num_tables: int,
    embedding_dim: int = 128,
    local_world_size: int = 8,
    compute_device: str = "cuda",
) -> Tuple[EmbeddingEnumerator, nn.Module]:
    """
    Create the enumerator/model pair used by a single benchmark run.

    The model is a ``TestSparseNN`` with ``num_tables`` unweighted embedding bag
    tables. Table ``i`` is named ``table_<i>``, has one feature ``feature_<i>``
    and ``100 + i`` rows, so every table has a distinct row count.

    Args:
        world_size: Number of devices in the topology
        num_tables: Number of embedding tables in the model
        embedding_dim: Dimension of each embedding vector
        local_world_size: Number of devices per node
        compute_device: Device type ("cuda" or "cpu")

    Returns:
        Tuple of (enumerator, model)
    """
    benchmark_topology = Topology(
        world_size=world_size,
        local_world_size=local_world_size,
        compute_device=compute_device,
    )

    table_configs = []
    for index in range(num_tables):
        table_configs.append(
            EmbeddingBagConfig(
                num_embeddings=100 + index,
                embedding_dim=embedding_dim,
                name=f"table_{index}",
                feature_names=[f"feature_{index}"],
            )
        )

    sparse_model = TestSparseNN(tables=table_configs, weighted_tables=[])
    embedding_enumerator = EmbeddingEnumerator(
        topology=benchmark_topology, batch_size=BATCH_SIZE
    )
    return embedding_enumerator, sparse_model
116+
117+
118+
def measure_memory_and_time(
    world_size: int,
    num_tables: int,
    embedding_dim: int = 128,
    sharder_class: Type[ModuleSharder[nn.Module]] = TWSharder,
) -> Dict[str, float]:
    """
    Time one ``enumerate`` call and report how far it pushed the peak RSS.

    The memory figure is the growth of the process-wide ``ru_maxrss``
    high-water mark across the call. Because that value never decreases, a run
    that stays below a peak reached earlier in the same process reports 0 MB;
    treat the number as "additional peak memory", not as the footprint of this
    configuration in isolation.

    Args:
        world_size: Number of devices in the topology
        num_tables: Number of embedding tables in the model
        embedding_dim: Dimension of each embedding vector
        sharder_class: The sharder class to use

    Returns:
        Dictionary with ``time_ms`` (elapsed milliseconds), ``memory_mb``
        (peak RSS growth in MB) and ``options_count`` (number of sharding
        options produced).

    Raises:
        AssertionError: If the number of sharding options is not ``num_tables``.
    """
    # Function-scope import: only needed to pick the ru_maxrss unit below.
    import sys

    # Collect before building so garbage from earlier runs is not attributed
    # to this one.
    gc.collect()

    enumerator, model = build_model_and_enumerator(
        world_size=world_size,
        num_tables=num_tables,
        embedding_dim=embedding_dim,
    )

    # Collect again so temporaries from model construction are released before
    # the baseline is taken.
    gc.collect()

    rss_before = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss

    # perf_counter is monotonic and high resolution; time.time() can jump when
    # the system clock is adjusted, which would corrupt short measurements.
    started = time.perf_counter()
    sharding_options = enumerator.enumerate(module=model, sharders=[sharder_class()])
    elapsed_seconds = time.perf_counter() - started

    rss_after = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss

    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS, so the
    # divisor that yields megabytes depends on the platform.
    rss_units_per_mb = 1024 * 1024 if sys.platform == "darwin" else 1024
    memory_growth_mb = (rss_after - rss_before) / rss_units_per_mb

    # Raised explicitly (rather than via ``assert``) so the sanity check still
    # runs under ``python -O``; type and message match the previous assert.
    options_count = len(sharding_options)
    if options_count != num_tables:
        raise AssertionError("Unexpected number of sharding options")

    return {
        "time_ms": elapsed_seconds * 1000,
        "memory_mb": memory_growth_mb,
        "options_count": options_count,
    }
178+
179+
180+
def _log_results_matrix(
    results: Dict[int, Dict[int, Dict[str, float]]],
    errors: Dict[Tuple[int, int], str],
    world_sizes: List[int],
    table_counts: List[int],
    metric: str,
) -> None:
    """Log one row per world size showing ``metric`` for every table count."""
    for world_size in world_sizes:
        cells = [f"{world_size:>10}"]
        for num_tables in table_counts:
            if (world_size, num_tables) in errors:
                cells.append(f"{'ERROR':>8}")
            else:
                cells.append(f"{results[world_size][num_tables][metric]:>8.2f}")
        logger.info(" | ".join(cells))


def benchmark_enumerator_comprehensive(
    sharder_class: Type[ModuleSharder[nn.Module]] = TWSharder,
) -> None:
    """
    Comprehensive benchmark testing all combinations of world sizes and table counts.
    Tests world sizes from 16 to 2048 and table counts from 200 to 6400.

    Each combination is measured once with ``measure_memory_and_time``. A
    failing combination is logged with its traceback and shown as ``ERROR`` in
    the tables; it does not stop the remaining combinations from running.
    Results are only logged, nothing is returned.

    Args:
        sharder_class: The sharder class to use for benchmarking
    """
    world_sizes = [16, 32, 64, 128, 256, 512, 1024, 2048]
    table_counts = [200, 400, 800, 1600, 3200, 6400]
    total_combinations = len(world_sizes) * len(table_counts)

    # Successful measurements, keyed by world size and then table count.
    results: Dict[int, Dict[int, Dict[str, float]]] = {}
    # Failures are tracked by key rather than by message truthiness, so an
    # exception with an empty message is still reported as an error.
    errors: Dict[Tuple[int, int], str] = {}

    sharder_name = sharder_class.__name__
    logger.info(f"Running comprehensive enumerator benchmark with {sharder_name}...")
    logger.info(
        f"Testing {len(world_sizes)} world sizes × {len(table_counts)} table counts = {total_combinations} combinations"
    )

    completed = 0

    for world_size in world_sizes:
        logger.info(f"Starting benchmarks for world_size={world_size}...")
        results[world_size] = {}
        world_size_start_time = time.time()

        for num_tables in table_counts:
            try:
                results[world_size][num_tables] = measure_memory_and_time(
                    world_size=world_size,
                    num_tables=num_tables,
                    sharder_class=sharder_class,
                )
            except Exception as e:
                # Best effort: keep going with the other combinations, but
                # record why this one failed instead of hiding it.
                errors[(world_size, num_tables)] = str(e)
                logger.exception(
                    "Benchmark failed for world_size=%s num_tables=%s",
                    world_size,
                    num_tables,
                )

            completed += 1

        world_size_elapsed = time.time() - world_size_start_time
        logger.info(
            f"Completed world_size={world_size} ({len(table_counts)} table counts) "
            f"in {world_size_elapsed:.2f}s ({completed}/{total_combinations} combinations done)"
        )

        # Intermediate table so partial results are visible during long runs.
        logger.info(f"Results for world_size={world_size}:")
        logger.info(f"{'Table Count':<12} {'Time (ms)':<10} {'Memory (MB)':<12}")
        logger.info("-" * 35)
        for num_tables in table_counts:
            if (world_size, num_tables) in errors:
                logger.info(f"{num_tables:<12} {'ERROR':<10} {'ERROR':<12}")
            else:
                metrics = results[world_size][num_tables]
                logger.info(
                    f"{num_tables:<12} "
                    f"{metrics['time_ms']:<10.2f} "
                    f"{metrics['memory_mb']:<12.2f}"
                )

    # Summary matrices: rows are world sizes, columns are table counts.
    logger.info(f"\nComprehensive Enumerator Benchmark with {sharder_name} - Results:")

    header = "World Size" + "".join(
        f" | {num_tables:>8}" for num_tables in table_counts
    )
    logger.info(header)
    logger.info("-" * len(header))

    logger.info("\nTime (milliseconds):")
    _log_results_matrix(results, errors, world_sizes, table_counts, "time_ms")

    logger.info("\nMemory (MB):")
    _log_results_matrix(results, errors, world_sizes, table_counts, "memory_mb")
79283

80284

81285
def main() -> None:
    """
    Main entry point for the benchmark script.

    Parses ``--sharder`` (``tw``, ``rw`` or ``both``; default ``tw``) and runs
    the comprehensive enumerator benchmark once per selected sharder, table-wise
    first when both are requested.
    """
    parser = argparse.ArgumentParser(description="Run planner enumerator benchmarks")
    parser.add_argument(
        "--sharder",
        type=str,
        choices=["tw", "rw", "both"],
        default="tw",
        help="Sharder type to use: table-wise (tw), row-wise (rw), or both",
    )
    args = parser.parse_args()

    # Announce the run only after the arguments parsed successfully, so
    # ``--help`` or an invalid option does not claim that benchmarks started.
    logger.warning("Running planner enumerator benchmarks...")

    # Map each CLI choice to the sharders it runs, in execution order.
    sharders_by_choice: Dict[str, List[Type[ModuleSharder[nn.Module]]]] = {
        "tw": [TWSharder],
        "rw": [RWSharder],
        "both": [TWSharder, RWSharder],
    }
    for sharder_class in sharders_by_choice[args.sharder]:
        benchmark_enumerator_comprehensive(sharder_class)
83309

84310

85-
# This is structured as a unittest-like file so you can use its built-in command
86-
# line argument parsing to control which benchmarks to run, e.g. "-k Enumerator"
87311
if __name__ == "__main__":
88-
main() # pragma: no cover
312+
main()

0 commit comments

Comments
 (0)