@@ -334,3 +334,55 @@ def test_to_mx_inductor_single_kernel():
334
334
to_mx_c = torch .compile (MXTensor .to_mx , fullgraph = True )
335
335
out , code = run_and_get_code (to_mx_c , x , torch .float8_e4m3fn , block_size )
336
336
FileCheck ().check ("def call(" ).check_count (".run(" , 1 , exactly = True ).run (code [0 ])
337
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not is_sm_at_least_89(),
    reason="float8 in triton requires CUDA capability 8.9 or greater",
)
def test_cast_to_float8_e4m3fn_saturation_behavior():
    """Verify eager vs. compiled float8_e4m3fn cast overflow behavior.

    Eager-mode ``Tensor.to(torch.float8_e4m3fn)`` is unsaturated: values
    outside the representable range become NaN. The triton kernel produced
    by ``torch.compile`` saturates instead, clamping overflow to +/- max.
    """
    # TODO(#1912): make the saturated cast work in eager mode and remove this
    # test
    f8_max = torch.finfo(torch.float8_e4m3fn).max

    # values exactly at the edges of the float8_e4m3fn representable range
    in_range_bf16 = torch.tensor(
        [f8_max, -f8_max], dtype=torch.bfloat16, device="cuda"
    )
    # values at twice the representable maximum, i.e. guaranteed overflow
    out_of_range_bf16 = torch.tensor(
        [2 * f8_max, -2 * f8_max], dtype=torch.bfloat16, device="cuda"
    )

    # eager-mode PyTorch cast to float8 is unsaturated: overflow -> NaN
    assert not torch.isnan(in_range_bf16.to(torch.float8_e4m3fn)).any()
    assert torch.isnan(out_of_range_bf16.to(torch.float8_e4m3fn)).all()

    # in triton, the cast is saturated; for simplicity, generate the triton
    # code via torch.compile rather than writing a kernel by hand
    def cast_to_f8(t):
        return t.to(torch.float8_e4m3fn)

    compiled_cast = torch.compile(cast_to_f8)
    in_range_f8 = compiled_cast(in_range_bf16)
    out_of_range_f8 = compiled_cast(out_of_range_bf16)
    assert not torch.isnan(in_range_f8).any()
    assert not torch.isnan(out_of_range_f8).any()
    # saturation clamps the overflowed values to the same +/- max as the
    # in-range inputs, so both results must match exactly
    torch.testing.assert_close(in_range_f8, out_of_range_f8, atol=0, rtol=0)
0 commit comments