     run_tests,
 )
 
-from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
     FbgemmConfig,
     quantize_,
@@ -29,17 +28,12 @@
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
 class TestFbgemmFp8Tensor(TestCase):
     def setUp(self):
+        self.e4m3_dtype = torch.float8_e4m3fn
         self.config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
+            input_dtype=self.e4m3_dtype,
+            weight_dtype=self.e4m3_dtype,
             output_dtype=torch.bfloat16,
         )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-            transpose_input=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
     def test_linear(self):
@@ -128,7 +122,9 @@ def forward(self, x):
         weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
         m = M(weight).eval()
         original = m(input)
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        # we need to transpose the weight first for bmm
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
+        quantize_(m, self.config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
@@ -146,6 +142,54 @@ def test_to_device(self):
             quantize_(linear, self.config)
             linear.to(device)
 
+    def test_cat(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy1, self.config)
+
+        quantize_(linear1, self.config)
+        quantize_(linear2, self.config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertTrue(cat_qweight1.shape, (512, 128))
+        self.assertEqual(dummy1.weight.float8_data, cat_qweight1.float8_data)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+
+        # concat with dim == 1 is not really correct and will be fixed later
+        # when we support distributed checkpointing
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertTrue(cat_qweight2.shape, (256, 256))
+        ref_float8_data = torch.cat(
+            [linear1.weight.float8_data, linear2.weight.float8_data], dim=1
+        )
+        ref_scale = linear1.weight.scale
+        self.assertEqual(cat_qweight2.float8_data, ref_float8_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+
+    def test_transpose(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear1, self.config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        linear1.bias = torch.nn.Parameter(torch.randn(128, dtype=dtype, device=device))
+        self.assertTrue(linear1.weight.shape, (128, 256))
+
+        input = torch.randn(32, 256, dtype=dtype, device=device)
+        # make sure it runs
+        res = linear1(input)
+        self.assertTrue(res.shape, (32, 128))
+
 
 if __name__ == "__main__":
     run_tests()
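
For context, a minimal sketch (not part of the diff) of the flow these tests now exercise: a single FbgemmConfig covers both the linear and bmm cases, and bmm weights are transposed before quantize_ instead of using the removed bmm_config. This assumes a CUDA SM90+ device and a torchao build that exposes FbgemmConfig; the layer sizes are illustrative.

import torch
from torchao.quantization import FbgemmConfig, quantize_

# One config; the separate bmm_config with transpose_input=True is gone.
config = FbgemmConfig(
    input_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    output_dtype=torch.bfloat16,
)

# Linear path: quantize in place, then call as usual.
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, config)
out = linear(torch.randn(32, 128, dtype=torch.bfloat16, device="cuda"))

# bmm path (see the updated test_bmm above): a (B, K, N) weight is
# transposed to (B, N, K) and made contiguous before quantization.
# m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
# quantize_(m, config, filter_fn=lambda x, fqn: True)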