 from torch.testing._internal.common_utils import (
     TestCase,
     run_tests,
+    parametrize,
+    instantiate_parametrized_tests,
 )

-from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
-    FbgemmConfig,
+    Float8DynamicActivationFloat8WeightConfig,
+    PerRow,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_90,
 )


+
+FBGEMM_CONFIG = Float8DynamicActivationFloat8WeightConfig(
+    granularity=PerRow(),
+    kernel="fbgemm",
+)
+ATEN_CONFIG = Float8DynamicActivationFloat8WeightConfig(
+    granularity=PerRow(),
+    kernel="aten",
+)
+
+
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
 class TestFbgemmFp8Tensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-            transpose_input=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

-    def test_linear(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_linear(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         quantized = linear(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
+        sqnr = compute_error(original, quantized)
+        self.assertTrue(sqnr > 20, f"sqnr: {sqnr}")

-    def test_slice(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_slice(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
@@ -65,7 +70,7 @@ def test_slice(self):
             dummy.weight.narrow(1, 0, 128), requires_grad=False
         )

-        quantize_(dummy, self.config)
+        quantize_(dummy, config)
         weight1 = dummy.weight.narrow(0, 0, 64)
         weight2 = dummy.weight.narrow(1, 0, 128)
         self.assertEqual(weight1.float8_data, dummy.weight.float8_data.narrow(0, 0, 64))
@@ -81,20 +86,23 @@ def test_slice(self):
         res_ref = dummy1(input)
         dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
         res = dummy(input)
-        assert compute_error(res, res_ref) > 25
+        sqnr = compute_error(res, res_ref)
+        self.assertTrue(sqnr > 25, f"sqnr: {sqnr}")

         input = torch.randn(2, 128, dtype=dtype, device=device)
         res_ref = dummy2(input)
         dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
         res = dummy(input)
-        assert compute_error(res, res_ref) > 15
+        sqnr = compute_error(res, res_ref)
+        self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")

-    def test_slice_and_copy_(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_slice_and_copy_(self, config):
         l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
         )
-        quantize_(l, self.config)
+        quantize_(l, config)
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(0, 0, 512)
@@ -104,7 +112,7 @@ def test_slice_and_copy_(self):

         # dummy_l has random weights (shouldn't be 0)
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        quantize_(dummy_l, self.config)
+        quantize_(dummy_l, config)
         quantized = dummy_l.weight
         quantized = quantized.narrow(0, 0, 512)

@@ -113,7 +121,8 @@ def test_slice_and_copy_(self):
         # making sure param.data is updated
         assert param.data.float8_data[0][0] != orig_value

-    def test_bmm(self):
+    @parametrize("config", [FBGEMM_CONFIG])
+    def test_bmm(self, config):
         class M(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
@@ -128,24 +137,80 @@ def forward(self, x):
         weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
         m = M(weight).eval()
         original = m(input)
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        # we need to transpose the weight first for bmm
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
+        quantize_(m, config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 20)

-    def test_to_device(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_to_device(self, config):
         for device in self.GPU_DEVICES:
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)

             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device=device)

             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)

+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_cat(self, config):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy1, config)
+
+        quantize_(linear1, config)
+        quantize_(linear2, config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertEqual(cat_qweight1.shape, (512, 128))
+        self.assertEqual(dummy1.weight.float8_data, cat_qweight1.float8_data)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+
+        # concat along dim == 1 is not really correct and will be fixed later,
+        # when we support distributed checkpointing
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertEqual(cat_qweight2.shape, (256, 256))
+        ref_float8_data = torch.cat(
+            [linear1.weight.float8_data, linear2.weight.float8_data], dim=1
+        )
+        ref_scale = linear1.weight.scale
+        self.assertEqual(cat_qweight2.float8_data, ref_float8_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+
+    @parametrize("config", [FBGEMM_CONFIG])
+    def test_transpose(self, config):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear1, config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        linear1.bias = torch.nn.Parameter(torch.randn(128, dtype=dtype, device=device))
+        self.assertEqual(linear1.weight.shape, (128, 256))
+
+        input = torch.randn(32, 256, dtype=dtype, device=device)
+        # make sure it runs
+        res = linear1(input)
+        self.assertEqual(res.shape, (32, 128))
+
+
+instantiate_parametrized_tests(TestFbgemmFp8Tensor)
+

 if __name__ == "__main__":
     run_tests()
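
For readers following the migration: the diff replaces the removed `FbgemmConfig` with `Float8DynamicActivationFloat8WeightConfig`, selecting the backend through its `kernel` argument. Below is a minimal usage sketch outside the test harness. It assumes the API exactly as imported in this diff (including the `kernel` argument introduced here) and an sm90+ CUDA GPU with fbgemm kernels available, matching the skip conditions above; this is an illustration, not part of the PR.

```python
# Minimal sketch, assuming the torchao APIs used in this diff:
# Float8DynamicActivationFloat8WeightConfig with per-row granularity
# and a `kernel` argument ("fbgemm" or "aten"), on an sm90+ CUDA GPU.
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)
from torchao.quantization.utils import compute_error

config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(),  # per-row scales, as in FBGEMM_CONFIG/ATEN_CONFIG
    kernel="fbgemm",       # or "aten", as in ATEN_CONFIG above
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
x = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
ref = linear(x)

quantize_(linear, config)  # swaps the weight for an fp8 tensor subclass in place
out = linear(x)

# SQNR against the bf16 reference; the tests above require > 20 dB
print(f"sqnr: {compute_error(ref, out)}")
```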