@@ -14,6 +14,10 @@
     float8_weight_only,
     float8_dynamic_activation_float8_weight,
 )
+from torchao.quantization.quant_api import (
+    float8_static_activation_float8_weight,
+)
+from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
 from torchao.quantization.observer import PerTensor, PerRow
 from torchao.float8.float8_utils import compute_error
 import torch
@@ -50,7 +54,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize(
         "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()]
@@ -60,45 +64,57 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
         "sizes",
         [
             ((128,), 256, 128),
-            ((256,), 512, 256),
-            ((64,), 128, 64),
             ((32, 128), 64, 256),
-            ((64, 256), 512, 128),
         ],
     )
     def test_fp8_linear_variants(
         self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity
     ):
-        raises = (
-            isinstance(granularity, PerRow)
-            and mode == "dynamic"
-            and dtype != torch.bfloat16
-        )
-        context = (
-            nullcontext()
-            if not raises
-            else pytest.raises(
-                AssertionError,
-                match="PerRow quantization only works for bfloat16 precision",
-            )
+        error_message = None
+        if isinstance(granularity, PerRow):
+            if mode == "dynamic" and dtype != torch.bfloat16:
+                error_message = "PerRow quantization only works for bfloat16 precision"
+            elif mode == "static":
+                error_message = (
+                    "Static quantization only supports PerTensor granularity"
+                )
+
+        error_context = (
+            pytest.raises(AssertionError, match=error_message)
+            if error_message
+            else nullcontext()
         )
-        with context:
+
+        with error_context:
             M, N, K = sizes
             input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
-
+            # Get a "reasonable" scale for the input tensor even though
+            # we use the same scale for multiple activations
+            scale, _ = choose_qparams_affine(
+                input_tensor,
+                MappingType.SYMMETRIC,
+                input_tensor.shape,
+                torch.float8_e4m3fn,
+                scale_dtype=torch.float32,
+            )
             mode_map = {
                 "dynamic": partial(
                     float8_dynamic_activation_float8_weight, granularity=granularity
                 ),
                 "weight-only": float8_weight_only,
+                "static": partial(
+                    float8_static_activation_float8_weight,
+                    scale=scale,
+                    granularity=granularity,
+                ),
             }

             # Create a linear layer with bfloat16 dtype
             model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")

             quantized_model = copy.deepcopy(model)
             factory = mode_map[mode]()
-            quantize_(model, factory)
+            quantize_(quantized_model, factory)

             if compile:
                 quantized_model = torch.compile(quantized_model, fullgraph=True)
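For reference, the static path this hunk adds reduces to the following standalone flow. This is a minimal sketch assuming only the torchao APIs imported in this diff (plus `quantize_`, imported earlier in the file); the `nn.Sequential`/`nn.Linear` model is a stand-in for the test's `ToyLinearModel`:

```python
import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.quantization.quant_api import float8_static_activation_float8_weight
from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
from torchao.quantization.observer import PerTensor

# Calibration: derive an fp8 scale from a representative activation.
calibration_input = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
scale, _ = choose_qparams_affine(
    calibration_input,
    MappingType.SYMMETRIC,
    calibration_input.shape,  # one block spanning the whole tensor -> per-tensor scale
    torch.float8_e4m3fn,
    scale_dtype=torch.float32,
)

# Statically quantize, reusing the precomputed scale for all activations.
model = nn.Sequential(nn.Linear(128, 64)).eval().to(torch.bfloat16).to("cuda")
quantize_(
    model,
    float8_static_activation_float8_weight(scale=scale, granularity=PerTensor()),
)

out = model(calibration_input)
```

Unlike the dynamic mode, which recomputes activation scales per forward pass, the static factory bakes the calibrated scale in, which is why the test derives `scale` from the input tensor before building `mode_map`.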
@@ -145,14 +161,23 @@ def test_per_row_with_float32(self):

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
-    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
         model = ToyLinearModel(16, 32).to(device="cuda")
-        if mode == "dynamic":
-            factory = float8_dynamic_activation_float8_weight()
-        else:
-            factory = float8_weight_only()
+
+        mode_map = {
+            "dynamic": partial(
+                float8_dynamic_activation_float8_weight, granularity=PerTensor()
+            ),
+            "weight-only": float8_weight_only,
+            "static": partial(
+                float8_static_activation_float8_weight,
+                scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
+                granularity=PerTensor(),
+            ),
+        }
+        factory = mode_map[mode]()
         quantize_(model, factory)

         # Save the state dict to an in-memory buffer
@@ -163,46 +188,50 @@ def test_serialization(self, mode: str):
         buffer.seek(0)

         # Load the state dict from the buffer
-        loaded_state_dict = torch.load(buffer)
+        weights_only_load = True
+        if mode == "dynamic":
+            # TODO will fix in followup
+            weights_only_load = False
+
+        loaded_state_dict = torch.load(buffer, weights_only=weights_only_load)

         # Create a new model and load the state dict
         with torch.device("meta"):
             new_model = ToyLinearModel(16, 32)
+        if mode == "static":
+            quantize_(new_model, factory)
         new_model.load_state_dict(loaded_state_dict, assign=True)

         # Compare the original and loaded models
-        if mode == "weight-only":
-            model_weight_1 = model.linear1.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_model_weight_1 = new_model.linear1.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-
-            model_weight_2 = model.linear2.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_model_weight_2 = new_model.linear2.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-
-        else:
-            model_weight_1 = model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_model_weight_1 = new_model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
-            )
-
-            model_weight_2 = model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_model_weight_2 = new_model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
-            )
-
-        assert torch.allclose(model_weight_1, new_model_weight_1)
-        assert torch.allclose(model_weight_2, new_model_weight_2)
+        for layer_name in ["linear1", "linear2"]:
+            original_layer = getattr(model, layer_name)
+            new_layer = getattr(new_model, layer_name)
+
+            # Compare weights
+            if mode == "weight-only":
+                original_weight = original_layer.weight.layout_tensor.float8_data.to(
+                    torch.float32
+                )
+                new_weight = new_layer.weight.layout_tensor.float8_data.to(
+                    torch.float32
+                )
+            else:
+                original_weight = original_layer.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                    torch.float32
+                )
+                new_weight = new_layer.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                    torch.float32
+                )
+
+            assert torch.allclose(
+                original_weight, new_weight
+            ), f"Weights do not match for {layer_name}"
+
+            # Compare scales
+            if hasattr(original_layer.weight, "scale"):
+                assert torch.allclose(
+                    original_layer.weight.scale, new_layer.weight.scale
+                ), f"Scales do not match for {layer_name}"


 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
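The serialization round trip being tested reduces to the sketch below. It mirrors the test's logic under the same assumptions: the torchao factories from this diff, `nn.Linear` in place of `ToyLinearModel`, and `weights_only=True` loading, which per the TODO above does not yet work for the dynamic mode:

```python
import io

import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.quantization.quant_api import float8_static_activation_float8_weight
from torchao.quantization.observer import PerTensor

factory = float8_static_activation_float8_weight(
    scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
    granularity=PerTensor(),
)
model = nn.Sequential(nn.Linear(16, 32)).to("cuda")
quantize_(model, factory)

# Round-trip the state dict through an in-memory buffer.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)
loaded_state_dict = torch.load(buffer, weights_only=True)

# As in the test's static branch: quantize the fresh model *before*
# loading so the state-dict keys match the tensor-subclass weights,
# then load with assign=True to adopt the saved tensors directly.
with torch.device("meta"):
    new_model = nn.Sequential(nn.Linear(16, 32))
quantize_(new_model, factory)
new_model.load_state_dict(loaded_state_dict, assign=True)
```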