 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )
 
-from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
-    FbgemmConfig,
+    Float8DynamicActivationFloat8WeightConfig,
+    PerRow,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_90,
 )
 
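+# Two flavors of the same per-row float8 dynamic activation + weight
+# quantization scheme; they differ only in the kernel that backs the
+# quantized matmul ("fbgemm" vs "aten").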
+FBGEMM_CONFIG = Float8DynamicActivationFloat8WeightConfig(
+    granularity=PerRow(), kernel="fbgemm"
+)
+ATEN_CONFIG = Float8DynamicActivationFloat8WeightConfig(
+    granularity=PerRow(), kernel="aten"
+)
+
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
 class TestFbgemmFp8Tensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-            transpose_input=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
-    def test_linear(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_linear(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         quantized = linear(input)
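+        # compute_error returns SQNR in dB; > 20 dB means the quantized
+        # output stays close to the bf16 baseline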
-        self.assertTrue(compute_error(original, quantized) > 20)
+        sqnr = compute_error(original, quantized)
+        self.assertTrue(sqnr > 20, f"sqnr: {sqnr}")
 
-    def test_slice(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_slice(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
@@ -65,7 +66,7 @@ def test_slice(self):
             dummy.weight.narrow(1, 0, 128), requires_grad=False
         )
 
-        quantize_(dummy, self.config)
+        quantize_(dummy, config)
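+        # slicing the quantized weight should slice the underlying
+        # float8_data the same way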
         weight1 = dummy.weight.narrow(0, 0, 64)
         weight2 = dummy.weight.narrow(1, 0, 128)
         self.assertEqual(weight1.float8_data, dummy.weight.float8_data.narrow(0, 0, 64))
@@ -81,20 +82,23 @@ def test_slice(self):
         res_ref = dummy1(input)
         dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
         res = dummy(input)
-        assert compute_error(res, res_ref) > 25
+        sqnr = compute_error(res, res_ref)
+        self.assertTrue(sqnr > 25, f"sqnr: {sqnr}")
 
         input = torch.randn(2, 128, dtype=dtype, device=device)
         res_ref = dummy2(input)
         dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
         res = dummy(input)
-        assert compute_error(res, res_ref) > 15
+        sqnr = compute_error(res, res_ref)
+        self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")
 
-    def test_slice_and_copy_(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_slice_and_copy_(self, config):
         l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
         )
-        quantize_(l, self.config)
+        quantize_(l, config)
         param = l.weight
         param_data = param.data
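+        # narrow gives a view into the quantized tensor; the copy_ later in
+        # this test must be reflected in the original param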
         param_data = param_data.narrow(0, 0, 512)
@@ -104,7 +108,7 @@ def test_slice_and_copy_(self):
 
         # dummy_l has random weights (shouldn't be 0)
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        quantize_(dummy_l, self.config)
+        quantize_(dummy_l, config)
         quantized = dummy_l.weight
         quantized = quantized.narrow(0, 0, 512)
 
@@ -113,7 +117,8 @@ def test_slice_and_copy_(self):
         # making sure param.data is updated
         assert param.data.float8_data[0][0] != orig_value
 
-    def test_bmm(self):
+    @parametrize("config", [FBGEMM_CONFIG])
+    def test_bmm(self, config):
         class M(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
@@ -128,24 +133,80 @@ def forward(self, x):
         weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
         m = M(weight).eval()
         original = m(input)
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        # we need to transpose the weight first for bmm
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
+        quantize_(m, config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
-    def test_to_device(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_to_device(self, config):
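+        # exercise the different Module.to call signatures on a quantized module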
         for device in self.GPU_DEVICES:
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device=device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_cat(self, config):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy1, config)
+
+        quantize_(linear1, config)
+        quantize_(linear2, config)
+
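+        # with per-row scales, concatenating independently quantized weights
+        # along dim 0 should match quantizing the concatenated weight directly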
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertEqual(cat_qweight1.shape, (512, 128))
+        self.assertEqual(dummy1.weight.float8_data, cat_qweight1.float8_data)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+
+        # concat along dim == 1 is not really correct (it just keeps linear1's
+        # scale) and will be fixed later when we support distributed checkpointing
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertEqual(cat_qweight2.shape, (256, 256))
+        ref_float8_data = torch.cat(
+            [linear1.weight.float8_data, linear2.weight.float8_data], dim=1
+        )
+        ref_scale = linear1.weight.scale
+        self.assertEqual(cat_qweight2.float8_data, ref_float8_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+
+    @parametrize("config", [FBGEMM_CONFIG])
+    def test_transpose(self, config):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear1, config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        linear1.bias = torch.nn.Parameter(torch.randn(128, dtype=dtype, device=device))
+        self.assertEqual(linear1.weight.shape, (128, 256))
+
+        input = torch.randn(32, 256, dtype=dtype, device=device)
+        # make sure it runs
+        res = linear1(input)
+        self.assertEqual(res.shape, (32, 128))
+
+
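+# generate a concrete test for each (test, config) pair declared via @parametrize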
+instantiate_parametrized_tests(TestFbgemmFp8Tensor)
+
 
 if __name__ == "__main__":
     run_tests()