class ToyLinearModel(torch.nn.Module):
-    def __init__(self, in_features, out_features):
+    def __init__(self, in_features, out_features, bias):
        super().__init__()
-        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
-        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)
+        self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)
+        self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)

    def forward(self, x):
        x = self.linear1(x)
@@ -81,6 +81,8 @@ def setUp(self):
            ((32, 128), 256, 512),
        ],
    )
+    @common_utils.parametrize("bias", [False, True])
+    @torch.no_grad()
    def test_fp8_linear_variants(
        self,
        dtype: torch.dtype,
@@ -89,6 +91,7 @@ def test_fp8_linear_variants(
        granularity,
        kernel_preference: KernelPreference,
        sizes: Tuple,
+        bias: bool,
    ):
        if isinstance(granularity, PerTensor):
            if kernel_preference is KernelPreference.FBGEMM:
@@ -106,6 +109,16 @@ def test_fp8_linear_variants(
            elif kernel_preference is KernelPreference.FBGEMM:
                return unittest.skip("unimplemented")

+            if bias is True:
+                sizes_to_keep = ((128,), 256, 128)
+                if (
+                    sizes != sizes_to_keep
+                    or kernel_preference is not KernelPreference.TORCH
+                ):
+                    return unittest.skip(
+                        "cut down on number of options to save test time"
+                    )
+
        error_message = None
        if isinstance(granularity, PerRow):
            if mode == "dynamic" and dtype != torch.bfloat16:
@@ -134,7 +147,7 @@ def test_fp8_linear_variants(
            input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

            # Create a linear layer with bfloat16 dtype
-            model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+            model = ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda")

            quantized_model = copy.deepcopy(model)

@@ -257,7 +270,7 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
        dtype = torch.bfloat16
        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
        # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to("cuda")

        # reference kernel preference and results
        # we are using KernelPreference.TORCH as the reference
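For reference, here is a minimal sketch (not part of the diff) of what the new bias=True path exercises: build the bfloat16 ToyLinearModel defined above with bias enabled, quantize a deep copy, and compare outputs. The quantize_ call, Float8DynamicActivationFloat8WeightConfig, and PerRow granularity are assumptions about the torchao APIs this test file uses elsewhere.

# Sketch only; assumes the torchao float8 quantization APIs named above and
# the ToyLinearModel class defined in the diff.
import copy

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow

K, N, bias = 128, 256, True
model = ToyLinearModel(K, N, bias).eval().to(torch.bfloat16).to("cuda")

# Quantize a copy so the bfloat16 model can serve as the numerical reference.
quantized_model = copy.deepcopy(model)
quantize_(quantized_model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

with torch.no_grad():
    x = torch.randn(32, K, dtype=torch.bfloat16, device="cuda")
    # Loose tolerances: float8 dynamic quantization is lossy relative to bfloat16.
    torch.testing.assert_close(quantized_model(x), model(x), atol=0.2, rtol=0.2)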