import torch
import torch.nn as nn

+from torchao.float8.float8_utils import is_row_major
from torchao.prototype.mx_formats.config import (
    MXLinearConfig,
    MXLinearRecipeName,
@@ -24,14 +25,14 @@
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
    is_sm_at_least_89,
    is_sm_at_least_100,
)

torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

@@ -169,11 +170,18 @@ def test_activation_checkpointing():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+    is_sm_at_least_100(),
+    reason="triton does not work yet on CUDA capability 10.0",
)
@pytest.mark.parametrize(
    "recipe_name",
-    ["mxfp8_emulated", "mxfp4_emulated", "mxfp8_cutlass", "mxfp4_cutlass"],
+    [
+        "mxfp8_emulated",
+        "mxfp4_emulated",
+        "mxfp8_cublas",
+        "mxfp8_cutlass",
+        "mxfp4_cutlass",
+    ],
)
@pytest.mark.parametrize("bias", [False, True])
# TODO(future PR): figure out why torch.compile does not match eager when
@@ -186,13 +194,13 @@ def test_linear_compile(recipe_name, bias):
    if not is_sm_at_least_89():
        pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

-    if recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
+    if recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
        if not is_sm_at_least_100():
            pytest.skip("CUDA capability >= 10.0 required for MX gemms")

-    if bias and recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
+    if bias and recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
        # TODO(future PR): fix this, things are clearly broken with bias=True
-        pytest.skip("this test is broken for cutlass recipes with bias=True")
+        pytest.skip("this test is broken for non-emulated recipes with bias=True")

    M, K, N = 128, 256, 512
    input_shape = (M, K)
@@ -285,6 +293,61 @@ def test_inference_compile_simple(elem_dtype):
    assert sqnr >= 13.5


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(),
+    reason="triton does not work yet on CUDA capability 10.0",
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MX gemms require CUDA capability 10.0",
+)
+def test_scaled_mm_wrapper():
+    # today, e8m0 isn't supported in torchinductor or triton
+    # for now, work around this by creating a wrapper around torch._scaled_mm
+    # which takes uint8 scales, and reinterprets them as e8m0 inside the wrapper
+    from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales
+
+    M, K, N = 128, 256, 512
+    BLOCK_SIZE = 32
+    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
+    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
+
+    a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
+    b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
+
+    out = torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16)
+
+    def wrapped(a, b, a_scale, b_scale, out_dtype):
+        if is_row_major(b.stride()):
+            b = b.t().contiguous().t()
+        res = _scaled_mm_with_uint8_scales(a, b, a_scale, b_scale, out_dtype=out_dtype)
+        return res
+
+    wrapped = torch.compile(wrapped)
+
+    # correct memory format of `b`
+    out2 = wrapped(
+        a,
+        b.t(),
+        a_scale.view(torch.uint8),
+        b_scale.view(torch.uint8),
+        out_dtype=torch.bfloat16,
+    )
+    torch.testing.assert_close(out, out2, atol=0, rtol=0)
+
+    # incorrect memory format of `b`
+    b_col_major = b.t().contiguous().t()
+    out3 = wrapped(
+        a,
+        b_col_major.t(),
+        a_scale.view(torch.uint8),
+        b_scale.view(torch.uint8),
+        out_dtype=torch.bfloat16,
+    )
+    torch.testing.assert_close(out, out3, atol=0, rtol=0)
+
+
def test_filter_fn():
    m1 = nn.Sequential(
        nn.Linear(32, 32),
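
The new test_scaled_mm_wrapper above hinges on one detail: torchinductor and triton do not yet understand the e8m0 scale dtype, so the wrapper accepts the scales as uint8 and reinterprets the bytes as e8m0 before calling torch._scaled_mm. A minimal sketch of that reinterpretation is below; it assumes a PyTorch build that exposes torch.float8_e8m0fnu (the same assumption the test makes) and uses no torchao-internal helpers.

import torch

# Block scales for a 128x256 float8 tensor with 32-element blocks, stored as
# raw bytes. 127 is the biased exponent that encodes 1.0 in e8m0 (2**(127-127)).
scales_u8 = torch.full((128, 256 // 32), 127, dtype=torch.uint8)

# e8m0 and uint8 are both one byte wide, so .view() reinterprets the same
# storage without copying; this is what lets the scales cross the
# torch.compile boundary as plain uint8.
scales_e8m0 = scales_u8.view(torch.float8_e8m0fnu)
assert scales_e8m0.dtype == torch.float8_e8m0fnu

# Round-tripping back to uint8 leaves the bytes unchanged.
assert torch.equal(scales_e8m0.view(torch.uint8), scales_u8)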