from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
+    parametrize,
+    instantiate_parametrized_tests,
)

-from torchao.float8.config import e4m3_dtype
from torchao.quantization import (
-    FbgemmConfig,
+    Float8DynamicActivationFloat8WeightConfig,
+    PerRow,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
    is_sm_at_least_90,
)


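+# Shared configs for the parametrized tests below: both apply float8 dynamic
+# activation + float8 weight quantization with per-row scales, and differ only
+# in which kernel backend ("fbgemm" vs "aten") serves the quantized matmul.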
+FBGEMM_CONFIG = Float8DynamicActivationFloat8WeightConfig(
+    granularity=PerRow(), kernel="fbgemm"
+)
+ATEN_CONFIG = Float8DynamicActivationFloat8WeightConfig(
+    granularity=PerRow(), kernel="aten"
+)
+
+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestFbgemmFp8Tensor(TestCase):
    def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-            transpose_input=True,
-        )
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

-    def test_linear(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_linear(self, config):
        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
        quantized = linear(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
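+        # compute_error returns the SQNR between the two outputs; higher
+        # means the quantized result is closer to the bf16 reference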
+        sqnr = compute_error(original, quantized)
+        self.assertTrue(sqnr > 20, f"sqnr: {sqnr}")

-    def test_slice(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_slice(self, config):
        dtype = torch.bfloat16
        device = "cuda"
        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
@@ -65,7 +67,7 @@ def test_slice(self):
            dummy.weight.narrow(1, 0, 128), requires_grad=False
        )

-        quantize_(dummy, self.config)
+        quantize_(dummy, config)
        weight1 = dummy.weight.narrow(0, 0, 64)
        weight2 = dummy.weight.narrow(1, 0, 128)
        self.assertEqual(weight1.float8_data, dummy.weight.float8_data.narrow(0, 0, 64))
@@ -81,20 +83,23 @@ def test_slice(self):
        res_ref = dummy1(input)
        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
        res = dummy(input)
-        assert compute_error(res, res_ref) > 25
+        sqnr = compute_error(res, res_ref)
+        self.assertTrue(sqnr > 25, f"sqnr: {sqnr}")

        input = torch.randn(2, 128, dtype=dtype, device=device)
        res_ref = dummy2(input)
        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
        res = dummy(input)
-        assert compute_error(res, res_ref) > 15
+        sqnr = compute_error(res, res_ref)
+        self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")

-    def test_slice_and_copy_(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_slice_and_copy_(self, config):
        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
        l.weight = torch.nn.Parameter(
            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
        )
-        quantize_(l, self.config)
+        quantize_(l, config)
        param = l.weight
        param_data = param.data
        param_data = param_data.narrow(0, 0, 512)
@@ -104,7 +109,7 @@ def test_slice_and_copy_(self):

        # dummy_l has random weights (shouldn't be all zeros)
        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        quantize_(dummy_l, self.config)
+        quantize_(dummy_l, config)
        quantized = dummy_l.weight
        quantized = quantized.narrow(0, 0, 512)

@@ -113,7 +118,8 @@ def test_slice_and_copy_(self):
        # making sure param.data is updated
        assert param.data.float8_data[0][0] != orig_value

-    def test_bmm(self):
+    @parametrize("config", [FBGEMM_CONFIG])
+    def test_bmm(self, config):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
@@ -128,24 +134,80 @@ def forward(self, x):
        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
        m = M(weight).eval()
        original = m(input)
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        # we need to transpose the weight first for bmm
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
+        quantize_(m, config, filter_fn=lambda x, fqn: True)
        quantized = m(input)
        self.assertTrue(compute_error(original, quantized) > 20)

-    def test_to_device(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_to_device(self, config):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
            linear.to(device)

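+    # check that cat of two quantized weights matches quantizing the
+    # concatenated high-precision weight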
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_cat(self, config):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy1, config)
+
+        quantize_(linear1, config)
+        quantize_(linear2, config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertEqual(cat_qweight1.shape, (512, 128))
+        self.assertEqual(dummy1.weight.float8_data, cat_qweight1.float8_data)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+
+        # concat along dim == 1 is not really correct and will be fixed later
+        # when we support distributed checkpointing
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertEqual(cat_qweight2.shape, (256, 256))
+        ref_float8_data = torch.cat(
+            [linear1.weight.float8_data, linear2.weight.float8_data], dim=1
+        )
+        ref_scale = linear1.weight.scale
+        self.assertEqual(cat_qweight2.float8_data, ref_float8_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+
+    @parametrize("config", [FBGEMM_CONFIG])
+    def test_transpose(self, config):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear1, config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        linear1.bias = torch.nn.Parameter(torch.randn(128, dtype=dtype, device=device))
+        self.assertEqual(linear1.weight.shape, (128, 256))
+
+        input = torch.randn(32, 256, dtype=dtype, device=device)
+        # make sure it runs
+        res = linear1(input)
+        self.assertEqual(res.shape, (32, 128))
+
+
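+# generate the per-config test variants declared with @parametrize above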
+instantiate_parametrized_tests(TestFbgemmFp8Tensor)
+

if __name__ == "__main__":
    run_tests()