
Commit a886a27

Update
[ghstack-poisoned]
1 parent 65875d1 commit a886a27

File tree

3 files changed: 110 additions, 64 deletions


benchmarks/microbenchmarks/test/test_benchmark_runner.py

Lines changed: 41 additions & 25 deletions
@@ -63,40 +63,56 @@ def test_get_shapes_for_config(self):
         )
         self.assertEqual(len(shapes), 1)
         self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024]))
-
+
         # Test llama shapes
-        llama_shapes = get_shapes_for_config([
-            {"name": "llama"}
-        ])
+        llama_shapes = get_shapes_for_config([{"name": "llama"}])
         self.assertEqual(len(llama_shapes), 4)  # 4 LLaMa shapes
-        self.assertTrue(any(name.startswith("llama_attn.wqkv") for name, _ in llama_shapes))
-        self.assertTrue(any(name.startswith("llama_attn.w0") for name, _ in llama_shapes))
-        self.assertTrue(any(name.startswith("llama_ffn.w13") for name, _ in llama_shapes))
-        self.assertTrue(any(name.startswith("llama_ffn.w2") for name, _ in llama_shapes))
-
+        self.assertTrue(
+            any(name.startswith("llama_attn.wqkv") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_attn.w0") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_ffn.w13") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_ffn.w2") for name, _ in llama_shapes)
+        )
+
         # Test pow2 shapes
-        pow2_shapes = get_shapes_for_config([
-            {"name": "pow2", "min_power": 10, "max_power": 12}
-        ])
+        pow2_shapes = get_shapes_for_config(
+            [{"name": "pow2", "min_power": 10, "max_power": 12}]
+        )
         self.assertEqual(len(pow2_shapes), 3)  # 3 powers of 2 (10, 11, 12)
         self.assertEqual(pow2_shapes[0], ("pow2_0", [1024, 1024, 1024]))  # 2^10
         self.assertEqual(pow2_shapes[1], ("pow2_1", [2048, 2048, 2048]))  # 2^11
         self.assertEqual(pow2_shapes[2], ("pow2_2", [4096, 4096, 4096]))  # 2^12
-
+
         # Test pow2_extended shapes
-        pow2_extended_shapes = get_shapes_for_config([
-            {"name": "pow2_extended", "min_power": 10, "max_power": 11}
-        ])
-        self.assertEqual(len(pow2_extended_shapes), 4)  # 2 powers of 2, each with 2 variants
-        self.assertEqual(pow2_extended_shapes[0], ("pow2_extended_0", [1024, 1024, 1024]))  # 2^10
-        self.assertEqual(pow2_extended_shapes[1], ("pow2_extended_1", [1536, 1536, 1536]))  # 2^10 + 2^9
-        self.assertEqual(pow2_extended_shapes[2], ("pow2_extended_2", [2048, 2048, 2048]))  # 2^11
-        self.assertEqual(pow2_extended_shapes[3], ("pow2_extended_3", [3072, 3072, 3072]))  # 2^11 + 2^10
-
+        pow2_extended_shapes = get_shapes_for_config(
+            [{"name": "pow2_extended", "min_power": 10, "max_power": 11}]
+        )
+        self.assertEqual(
+            len(pow2_extended_shapes), 4
+        )  # 2 powers of 2, each with 2 variants
+        self.assertEqual(
+            pow2_extended_shapes[0], ("pow2_extended_0", [1024, 1024, 1024])
+        )  # 2^10
+        self.assertEqual(
+            pow2_extended_shapes[1], ("pow2_extended_1", [1536, 1536, 1536])
+        )  # 2^10 + 2^9
+        self.assertEqual(
+            pow2_extended_shapes[2], ("pow2_extended_2", [2048, 2048, 2048])
+        )  # 2^11
+        self.assertEqual(
+            pow2_extended_shapes[3], ("pow2_extended_3", [3072, 3072, 3072])
+        )  # 2^11 + 2^10
+
         # Test sweep shapes (limited to a small range for testing)
-        sweep_shapes = get_shapes_for_config([
-            {"name": "sweep", "min_power": 8, "max_power": 9}
-        ])
+        sweep_shapes = get_shapes_for_config(
+            [{"name": "sweep", "min_power": 8, "max_power": 9}]
+        )
         # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations)
         self.assertEqual(len(sweep_shapes), 8)
         # Check that all shapes have the expected format
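
For reference, the reformatted assertions above pin down how get_shapes_for_config expands a "pow2_extended" entry: each power p in [min_power, max_power] contributes both 2^p and the midpoint 2^p + 2^(p-1). Below is a minimal standalone sketch of that rule, matching the expected values in the test; it is illustrative only, not the benchmark runner's implementation.

def pow2_extended_shapes(min_power, max_power):
    # Each power p yields 2^p and 2^p + 2^(p-1), which is where the
    # expected (1024, 1536, 2048, 3072) sequence above comes from.
    dims = []
    for p in range(min_power, max_power + 1):
        dims.extend([2**p, 2**p + 2 ** (p - 1)])
    return [(f"pow2_extended_{i}", [d, d, d]) for i, d in enumerate(dims)]


assert pow2_extended_shapes(10, 11) == [
    ("pow2_extended_0", [1024, 1024, 1024]),
    ("pow2_extended_1", [1536, 1536, 1536]),
    ("pow2_extended_2", [2048, 2048, 2048]),
    ("pow2_extended_3", [3072, 3072, 3072]),
]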

benchmarks/microbenchmarks/test/test_utils.py

Lines changed: 42 additions & 20 deletions
@@ -171,7 +171,7 @@ def test_rms_norm(self):
         x = torch.randn(16, 64)
         out = rms_norm(x)
         self.assertEqual(out.shape, (16, 64))
-
+
         # Test with different eps
         rms_norm = RMSNorm(dim=64, eps=1e-5)
         out = rms_norm(x)
@@ -184,38 +184,50 @@ def test_rms_norm_linear_activation(self):
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
         self.assertEqual(out.dtype, torch.float32)
-
+
         # Test with ReLU activation
-        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu")
+        model = RMSNormLinearActivation(
+            fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu"
+        )
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
         self.assertTrue(torch.all(out >= 0))  # Check ReLU output range
-
+
         # Test with SiLU activation
-        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu")
+        model = RMSNormLinearActivation(
+            fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu"
+        )
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
-
+
         # Test with invalid activation
         with self.assertRaises(ValueError):
-            RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid")
+            RMSNormLinearActivation(
+                fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid"
+            )

     def test_transformer_block(self):
         # Test with default parameters
-        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32
+        )
         x = torch.randn(16, 16, 64)  # [batch_size, seq_len, hidden_dim]
         out = model(x)
         self.assertEqual(out.shape, (16, 16, 64))
         self.assertEqual(out.dtype, torch.float32)
-
+
         # Test with different parameters
-        model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32
+        )
         x = torch.randn(8, 32, 128)
         out = model(x)
         self.assertEqual(out.shape, (8, 32, 128))
-
+
         # Test with different head dimensions
-        model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32
+        )
         x = torch.randn(4, 8, 96)
         out = model(x)
         self.assertEqual(out.shape, (4, 8, 96))
@@ -255,7 +267,7 @@ def test_create_model_and_input(self):
         )
         self.assertIsInstance(model, RMSNormLinearActivation)
         self.assertEqual(input_data.shape, (m, k))
-
+
         # Test TransformerBlock
         model, input_data = create_model_and_input(
             model_type="transformer_block",
@@ -266,40 +278,50 @@ def test_create_model_and_input(self):
             device="cpu",
         )
         self.assertIsInstance(model, TransformerBlock)
-        self.assertEqual(input_data.shape, (m, 16, k))  # [batch_size, seq_len, hidden_dim]
+        self.assertEqual(
+            input_data.shape, (m, 16, k)
+        )  # [batch_size, seq_len, hidden_dim]

     def test_quantization_on_models(self):
         # Test quantization on RMSNormLinearActivation
         model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
         x = torch.randn(16, 64)
-
+
         # Test with Int8WeightOnlyConfig
         config = string_to_config(quantization="int8wo", sparsity=None)
         if config is not None:
             # Skip quantization test if torchao.quantization.quantize is not available
             try:
                 from torchao.quantization import quantize
+
                 quantized_model = quantize(model, config)
                 out = quantized_model(x)
                 self.assertEqual(out.shape, (16, 32))
             except ImportError:
-                print("Skipping quantization test: torchao.quantization.quantize not available")
-
+                print(
+                    "Skipping quantization test: torchao.quantization.quantize not available"
+                )
+
         # Test quantization on TransformerBlock
-        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32
+        )
         x = torch.randn(16, 16, 64)
-
+
         # Test with Int8WeightOnlyConfig
        config = string_to_config(quantization="int8wo", sparsity=None)
         if config is not None:
             # Skip quantization test if torchao.quantization.quantize is not available
             try:
                 from torchao.quantization import quantize
+
                 quantized_model = quantize(model, config)
                 out = quantized_model(x)
                 self.assertEqual(out.shape, (16, 16, 64))
             except ImportError:
-                print("Skipping quantization test: torchao.quantization.quantize not available")
+                print(
+                    "Skipping quantization test: torchao.quantization.quantize not available"
+                )

     def test_generate_results_csv(self):
         results = [
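
The RMSNormLinearActivation checks above exercise three activation names (the default gelu, plus relu and silu) and expect a ValueError for anything else. A small illustrative sketch of that contract, separate from the repo's class and using a hypothetical helper name:

import torch

_ACTIVATIONS = {
    "gelu": torch.nn.GELU,
    "relu": torch.nn.ReLU,
    "silu": torch.nn.SiLU,
}


def make_activation(name):
    # Known names map to modules; unknown names raise ValueError,
    # mirroring what test_rms_norm_linear_activation asserts.
    try:
        return _ACTIVATIONS[name]()
    except KeyError:
        raise ValueError(f"Unsupported activation: {name}") from None


make_activation("relu")       # fine
# make_activation("invalid")  # raises ValueError, as the test expects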

benchmarks/microbenchmarks/utils.py

Lines changed: 27 additions & 19 deletions
@@ -399,7 +399,7 @@ def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="gelu"):
         super().__init__()
         self.rms_norm = RMSNorm(fc_dim1, dtype=dtype)
         self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype)
-
+
         if activation == "gelu":
             self.activation = torch.nn.GELU()
         elif activation == "relu":
@@ -422,66 +422,72 @@ def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16):
         self.hidden_dim = hidden_dim
         self.num_heads = num_heads
         self.head_dim = hidden_dim // num_heads
-
+
         # Self-attention
         self.qkv = torch.nn.Linear(hidden_dim, 3 * hidden_dim, bias=False).to(dtype)
         self.proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False).to(dtype)
-
+
         # MLP
         self.mlp_ratio = mlp_ratio
         self.mlp_hidden_dim = int(hidden_dim * mlp_ratio)
-        self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to(dtype)
-        self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, hidden_dim, bias=False).to(dtype)
-
+        self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to(
+            dtype
+        )
+        self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, hidden_dim, bias=False).to(
+            dtype
+        )
+
         # Layer norms
         self.norm1 = RMSNorm(hidden_dim, dtype=dtype)
         self.norm2 = RMSNorm(hidden_dim, dtype=dtype)
-
+
         # Activation
         self.activation = torch.nn.GELU()

     def forward(self, x):
         batch_size, seq_len, _ = x.shape
-
+
         # Self-attention
         residual = x
         x = self.norm1(x)
-
+
         # Reshape qkv projection for better memory layout
         qkv = self.qkv(x)  # [batch_size, seq_len, 3 * hidden_dim]
         qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
-        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch_size, num_heads, seq_len, head_dim]
+        qkv = qkv.permute(
+            2, 0, 3, 1, 4
+        )  # [3, batch_size, num_heads, seq_len, head_dim]
         q, k, v = qkv  # Each has shape [batch_size, num_heads, seq_len, head_dim]
-
+
         # Scaled dot-product attention with proper reshaping
         # Reshape for better memory layout and avoid broadcasting issues
         q = q.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
         k = k.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
         v = v.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
-
+
         # Compute attention scores
-        attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim ** 0.5))
+        attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim**0.5))
         attn = torch.softmax(attn, dim=-1)
-
+
         # Apply attention to values
         x = attn @ v  # [batch_size * num_heads, seq_len, head_dim]
-
+
         # Reshape back to original dimensions
         x = x.reshape(batch_size, self.num_heads, seq_len, self.head_dim)
         x = x.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim)
-
+
         # Project back to hidden dimension
         x = self.proj(x)
         x = residual + x
-
+
         # MLP
         residual = x
         x = self.norm2(x)
         x = self.mlp_fc1(x)
         x = self.activation(x)
         x = self.mlp_fc2(x)
         x = residual + x
-
+
         return x


@@ -683,7 +689,9 @@ def create_model_and_input(
         input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
     elif model_type == "transformer_block":
         # For transformer block, k is the hidden dimension
-        model = TransformerBlock(k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype).to(device)
+        model = TransformerBlock(
+            k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype
+        ).to(device)
         # Input shape for transformer is [batch_size, seq_len, hidden_dim]
         input_data = torch.randn(m, 16, k, device=device, dtype=high_precision_dtype)
     else:
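
The forward pass in the hunk above changes only formatting, not the math. As a sanity check, here is a minimal standalone sketch of the same reshape-then-attend pattern with the dimensions the tests use (hidden_dim=64, num_heads=8); it is illustrative only, not the repo's module:

import torch

B, S, H, D = 16, 16, 8, 8        # batch, seq_len, num_heads, head_dim
hidden = H * D                   # 64, matching test_transformer_block
x = torch.randn(B, S, hidden)

qkv = torch.nn.Linear(hidden, 3 * hidden, bias=False)(x)
qkv = qkv.reshape(B, S, 3, H, D).permute(2, 0, 3, 1, 4)  # [3, B, H, S, D]
q, k, v = (t.reshape(B * H, S, D) for t in qkv)          # per-head batches

attn = torch.softmax((q @ k.transpose(-2, -1)) / D**0.5, dim=-1)
out = (attn @ v).reshape(B, H, S, D).transpose(1, 2).reshape(B, S, hidden)
assert out.shape == (16, 16, 64)  # same output shape the test asserts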
