 
 # pyre-strict
 
-"""Stress tests for planner to find problematic scaling behavior."""
+"""
+Comprehensive benchmarks for the planner enumerator to analyze performance and scaling behavior.
 
-import time
-import unittest
+This module provides benchmarks for the EmbeddingEnumerator component, including:
+- Performance with varying table counts
+- Performance with varying world sizes
+- Memory usage tracking
+"""
 
-from typing import List, Tuple
+import argparse
+import gc
+import logging
+import resource
+import time
+from typing import Dict, List, Tuple, Type
 
 from torch import nn
-
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.planner.constants import BATCH_SIZE
 from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
 from torchrec.distributed.planner.types import Topology
 from torchrec.distributed.test_utils.test_model import TestSparseNN
 from torchrec.distributed.types import ModuleSharder, ShardingType
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 
+# Configure logging to ensure visibility
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+logger: logging.Logger = logging.getLogger(__name__)
+# Force the logger to use the configured level
+logger.setLevel(logging.INFO)
+
 
 class TWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]):
+    """
+    Table-wise sharder for benchmarking.
+    """
+
     def sharding_types(self, compute_device_type: str) -> List[str]:
+        # compute_device_type is required by the interface
         return [ShardingType.TABLE_WISE.value]
 
     def compute_kernels(
         self, sharding_type: str, compute_device_type: str
     ) -> List[str]:
+        # sharding_type and compute_device_type are required by the interface
+        return [EmbeddingComputeKernel.DENSE.value]
+
+
+class RWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]):
+    """
+    Row-wise sharder for benchmarking.
+    """
+
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        # compute_device_type is required by the interface
+        return [ShardingType.ROW_WISE.value]
+
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
+        # sharding_type and compute_device_type are required by the interface
         return [EmbeddingComputeKernel.DENSE.value]
 
 
-class TestEnumeratorBenchmark(unittest.TestCase):
-    @staticmethod
-    def build(
-        world_size: int, num_tables: int
-    ) -> Tuple[EmbeddingEnumerator, nn.Module]:
-        compute_device = "cuda"
-        topology = Topology(
-            world_size=world_size, local_world_size=8, compute_device=compute_device
+def build_model_and_enumerator(
+    world_size: int,
+    num_tables: int,
+    embedding_dim: int = 128,
+    local_world_size: int = 8,
+    compute_device: str = "cuda",
+) -> Tuple[EmbeddingEnumerator, nn.Module]:
+    """
+    Build an enumerator and model for benchmarking.
+
+    Args:
+        world_size: Number of devices in the topology
+        num_tables: Number of embedding tables in the model
+        embedding_dim: Dimension of each embedding vector
+        local_world_size: Number of devices per node
+        compute_device: Device type ("cuda" or "cpu")
+
+    Returns:
+        Tuple of (enumerator, model)
+    """
+    topology = Topology(
+        world_size=world_size,
+        local_world_size=local_world_size,
+        compute_device=compute_device,
+    )
+    tables = [
+        EmbeddingBagConfig(
+            num_embeddings=100 + i,
+            embedding_dim=embedding_dim,
+            name="table_" + str(i),
+            feature_names=["feature_" + str(i)],
+        )
+        for i in range(num_tables)
+    ]
+    model = TestSparseNN(tables=tables, weighted_tables=[])
+    enumerator = EmbeddingEnumerator(topology=topology, batch_size=BATCH_SIZE)
+    return enumerator, model
+
+
+def measure_memory_and_time(
+    world_size: int,
+    num_tables: int,
+    embedding_dim: int = 128,
+    sharder_class: Type[ModuleSharder[nn.Module]] = TWSharder,
+) -> Dict[str, float]:
+    """
+    Measure both time and memory usage for the enumerate operation.
+
+    Args:
+        world_size: Number of devices in the topology
+        num_tables: Number of embedding tables in the model
+        embedding_dim: Dimension of each embedding vector
+        sharder_class: The sharder class to use
+
+    Returns:
+        Dictionary with time and memory metrics
+    """
+    # Force garbage collection before measurement
+    gc.collect()
+
+    # Build model and enumerator
+    enumerator, model = build_model_and_enumerator(
+        world_size=world_size,
+        num_tables=num_tables,
+        embedding_dim=embedding_dim,
+    )
+
+    # Force garbage collection again after model building
+    gc.collect()
+
+    # Get initial memory usage
+    initial_memory = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+
+    # Measure time
+    start_time = time.time()
+    sharding_options = enumerator.enumerate(module=model, sharders=[sharder_class()])
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+
+    # Get peak memory usage
+    peak_memory = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+    # Calculate memory used during operation
+    memory_used = peak_memory - initial_memory
+
+    # Convert to MB (note: ru_maxrss is in KB on Linux, bytes on macOS)
+    # We'll assume Linux here, so divide by 1024 to get MB
+    peak_mb = memory_used / 1024
+
+    # Verify the result
+    assert len(sharding_options) == num_tables, "Unexpected number of sharding options"
+
+    # Convert time to milliseconds
+    elapsed_time_ms = elapsed_time * 1000
+
+    return {
+        "time_ms": elapsed_time_ms,
+        "memory_mb": peak_mb,
+        "options_count": len(sharding_options),
+    }
+
+
+def benchmark_enumerator_comprehensive(
+    sharder_class: Type[ModuleSharder[nn.Module]] = TWSharder,
+) -> None:
+    """
+    Comprehensive benchmark testing all combinations of world sizes and table counts.
+    Tests world sizes from 16 to 2048 and table counts from 200 to 6400.
+
+    Args:
+        sharder_class: The sharder class to use for benchmarking
+    """
+    # Define the ranges for world sizes and table counts
+    world_sizes = [16, 32, 64, 128, 256, 512, 1024, 2048]
+    table_counts = [200, 400, 800, 1600, 3200, 6400]
+    # Create a matrix to store results
+    results = {}
+
+    sharder_name = sharder_class.__name__
+    logger.info(f"Running comprehensive enumerator benchmark with {sharder_name}...")
+    logger.info(
+        f"Testing {len(world_sizes)} world sizes × {len(table_counts)} table counts = {len(world_sizes) * len(table_counts)} combinations"
+    )
+
+    # Track progress
+    total_combinations = len(world_sizes) * len(table_counts)
+    completed = 0
+
+    # Run benchmarks for all combinations
+    for world_size in world_sizes:
+        logger.info(f"Starting benchmarks for world_size={world_size}...")
+        results[world_size] = {}
+        world_size_start_time = time.time()
+
+        # Run all table counts for this world size
+        for num_tables in table_counts:
+            try:
+                metrics = measure_memory_and_time(
+                    world_size=world_size,
+                    num_tables=num_tables,
+                    sharder_class=sharder_class,
+                )
+                results[world_size][num_tables] = metrics
+            except Exception as e:
+                results[world_size][num_tables] = {
+                    "time_ms": -1,
+                    "memory_mb": -1,
+                    "options_count": -1,
+                    "error": str(e),
+                }
+
+            completed += 1
+
+        # Log completion of all table counts for this world size
+        world_size_elapsed = time.time() - world_size_start_time
+        logger.info(
+            f"Completed world_size={world_size} ({len(table_counts)} table counts) "
+            f"in {world_size_elapsed:.2f}s ({completed}/{total_combinations} combinations done)"
         )
-        tables = [
-            EmbeddingBagConfig(
-                num_embeddings=100 + i,
-                embedding_dim=128,
-                name="table_" + str(i),
-                feature_names=["feature_" + str(i)],
-            )
-            for i in range(num_tables)
-        ]
-        model = TestSparseNN(tables=tables, weighted_tables=[])
-        enumerator = EmbeddingEnumerator(topology=topology, batch_size=BATCH_SIZE)
-        return enumerator, model
-
-    def measure(self, world_size: int, num_tables: int) -> float:
-        enumerator, model = TestEnumeratorBenchmark.build(world_size, num_tables)
-
-        start_time = time.time()
-        sharding_options = enumerator.enumerate(module=model, sharders=[TWSharder()])
-        end_time = time.time()
-
-        self.assertEqual(len(sharding_options), num_tables)
-        return end_time - start_time
-
-    def test_benchmark(self) -> None:
-        tests = [(2048, d) for d in [100, 200, 400, 800, 1600, 3200, 6400]]
-        print("\nEnumerator benchmark:")
-        for world_size, num_tables in tests:
-            t = self.measure(world_size, num_tables)
-            print(
-                f"world_size={world_size:8} num_tables={num_tables:8} enumerate={t:4.2f}s"
-            )
+
+        # Print intermediate results for this world size
+        logger.info(f"Results for world_size={world_size}:")
+        logger.info(f"{'Table Count':<12} {'Time (ms)':<10} {'Memory (MB)':<12}")
+        logger.info("-" * 35)
+        for num_tables in table_counts:
+            if results[world_size][num_tables].get("error"):
+                logger.info(f"{num_tables:<12} {'ERROR':<10} {'ERROR':<12}")
+            else:
+                logger.info(
+                    f"{num_tables:<12} "
+                    f"{results[world_size][num_tables]['time_ms']:<10.2f} "
+                    f"{results[world_size][num_tables]['memory_mb']:<12.2f}"
+                )
+
+    # Print summary table after all tests are complete
+    logger.info(f"\nComprehensive Enumerator Benchmark with {sharder_name} - Results:")
+
+    # Print header row with table counts
+    header = "World Size"
+    for num_tables in table_counts:
+        header += f" | {num_tables:>8}"
+    logger.info(header)
+    logger.info("-" * len(header))
+
+    # Print time results
+    logger.info("\nTime (milliseconds):")
+    for world_size in world_sizes:
+        row = f"{world_size:>10}"
+        for num_tables in table_counts:
+            if results[world_size][num_tables].get("error"):
+                row += f" | {'ERROR':>8}"
+            else:
+                row += f" | {results[world_size][num_tables]['time_ms']:>8.2f}"
+        logger.info(row)
+
+    # Print memory results
+    logger.info("\nMemory (MB):")
+    for world_size in world_sizes:
+        row = f"{world_size:>10}"
+        for num_tables in table_counts:
+            if results[world_size][num_tables].get("error"):
+                row += f" | {'ERROR':>8}"
+            else:
+                row += f" | {results[world_size][num_tables]['memory_mb']:>8.2f}"
+        logger.info(row)
 
 
 def main() -> None:
-    unittest.main()
+    """
+    Main entry point for the benchmark script.
+
+    Provides a command-line interface to run specific benchmarks.
+    """
+    parser = argparse.ArgumentParser(description="Run planner enumerator benchmarks")
+    parser.add_argument(
+        "--sharder",
+        type=str,
+        choices=["tw", "rw", "both"],
+        default="tw",
+        help="Sharder type to use: table-wise (tw), row-wise (rw), or both",
+    )
+    logger.warning("Running planner enumerator benchmarks...")
+
+    args = parser.parse_args()
+
+    # Run benchmark with specified sharder(s)
+    if args.sharder == "tw" or args.sharder == "both":
+        benchmark_enumerator_comprehensive(TWSharder)
+
+    if args.sharder == "rw" or args.sharder == "both":
+        benchmark_enumerator_comprehensive(RWSharder)
 
 
-# This is structured as a unitttest like file so you can use its built-in command
-# line argument parsing to control which benchmarks to run, e.g. "-k Enumerator"
 if __name__ == "__main__":
-    main()  # pragma: no cover
+    main()
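
A note on the memory metric in measure_memory_and_time: ru_maxrss is a process-wide high-water mark, so the before/after difference only captures growth beyond the earlier peak, and its unit differs by platform. A minimal sketch of a platform-aware conversion, using an illustrative helper name max_rss_mb that is not part of the patch:

import resource
import sys


def max_rss_mb() -> float:
    """Return the peak resident set size of the current process in MB."""
    # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return rss / (1024 * 1024) if sys.platform == "darwin" else rss / 1024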
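
As a rough usage sketch, assuming the file is importable as benchmark_enumerator (a hypothetical module name, since the diff does not show the file path), the new entry points can be driven programmatically as well as through the --sharder flag added in main():

from benchmark_enumerator import (
    RWSharder,
    TWSharder,
    benchmark_enumerator_comprehensive,
    measure_memory_and_time,
)

# Time and memory for a single world size / table count combination.
metrics = measure_memory_and_time(world_size=64, num_tables=400, sharder_class=RWSharder)
print(f"enumerate: {metrics['time_ms']:.2f} ms, ~{metrics['memory_mb']:.2f} MB")

# Full sweep with the table-wise sharder (equivalent to running the script with --sharder tw).
benchmark_enumerator_comprehensive(TWSharder)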