    e5m2_dtype,
)
from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_linear_utils import (
-    convert_to_float8_training,
-)
+from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_ops import addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import (
    get_maybe_axiswise_dim,
@@ -379,12 +377,16 @@ def test_linear_from_config_params(
    )
    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
    @pytest.mark.parametrize("linear_bias", [True, False])
+    @pytest.mark.parametrize(
+        "linear_dtype", [torch.bfloat16, torch.float16, torch.float32]
+    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skip_if_rocm("ROCm enablement in progress")
    def test_linear_from_recipe(
        self,
        recipe_name,
        x_shape,
+        linear_dtype: torch.dtype,
        linear_bias: bool,
    ):
        if torch.cuda.get_device_capability() < (9, 0):
@@ -393,7 +395,6 @@ def test_linear_from_recipe(
            )
            pytest.skip()

-        linear_dtype = torch.bfloat16
        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
        config = Float8LinearConfig.from_recipe_name(recipe_name)
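For reference, the parametrized path can be exercised outside pytest with a standalone sketch like the one below. This is a minimal, hedged example rather than the actual test body: the "rowwise" recipe name, the Float8LinearConfig import path, the config= keyword of convert_to_float8_training, and the loose tolerances are assumptions made for illustration.

# Standalone sketch of what the dtype-parametrized test exercises.
# Requires a CUDA GPU with float8 support (compute capability >= 9.0),
# matching the skip condition in the test above.
import copy

import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training

linear_dtype = torch.float16  # any of the parametrized dtypes
x = torch.randn(2, 16, 16, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)

# "rowwise" is an assumed recipe name for illustration
config = Float8LinearConfig.from_recipe_name("rowwise")
m_fp8 = convert_to_float8_training(copy.deepcopy(m_ref), config=config)

y_ref = m_ref(x)
y_fp8 = m_fp8(x)
# float8 execution should track the high-precision reference within a loose tolerance
torch.testing.assert_close(y_fp8, y_ref, atol=1e-1, rtol=1e-1)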