39
39
40
40
def get_dummy_quant_config (
41
41
num_bits = 4 , strategy = None , group_size = None , actorder = None , symmetric = True
42
- ):
42
+ ) -> QuantizationConfig :
43
43
config_groups = {
44
44
"group_1" : QuantizationScheme (
45
45
targets = ["Linear" ],
@@ -81,9 +81,9 @@ def test_quant_format(shape):
81
81
quant_config = get_dummy_quant_config ()
82
82
83
83
compressor = PackedQuantizationCompressor (config = quant_config )
84
- quantized_modules_to_args = {"dummy" : quant_config .config_groups ["group_1" ]. weights }
84
+ quantized_modules_to_scheme = {"dummy" : quant_config .config_groups ["group_1" ]}
85
85
compressed_state_dict = compressor .compress (
86
- dense_state_dict , names_to_scheme = quantized_modules_to_args
86
+ dense_state_dict , names_to_scheme = quantized_modules_to_scheme
87
87
)
88
88
89
89
# compressed state_dict adds one entry for shape
@@ -156,25 +156,21 @@ def test_reload_match(tmp_path, num_bits):
156
156
157
157
# pack-compressor only needs the number of bits from the quant-args to decompress
158
158
# all other information is extracted from the compressed data directly
159
- names_to_scheme = {
160
- "dummy" : QuantizationArgs (num_bits = num_bits ),
161
- "dummy2" : QuantizationArgs (num_bits = num_bits ),
162
- }
163
159
quant_config = get_dummy_quant_config (num_bits , symmetric = False )
164
160
165
161
compressor = PackedQuantizationCompressor (config = quant_config )
166
- quantized_modules_to_args = {
167
- "dummy" : quant_config .config_groups ["group_1" ]. weights ,
168
- "dummy2" : quant_config .config_groups ["group_1" ]. weights ,
162
+ quantized_modules_to_scheme = {
163
+ "dummy" : quant_config .config_groups ["group_1" ],
164
+ "dummy2" : quant_config .config_groups ["group_1" ],
169
165
}
170
166
171
167
compressed_state_dict = compressor .compress (
172
- dense_state_dict , names_to_scheme = quantized_modules_to_args
168
+ dense_state_dict . copy () , names_to_scheme = quantized_modules_to_scheme
173
169
)
174
170
save_file (compressed_state_dict , tmp_path / "model.safetensors" )
175
171
176
172
reconstructed_dense_gen = compressor .decompress (
177
- tmp_path , names_to_scheme = names_to_scheme
173
+ tmp_path , names_to_scheme = quantized_modules_to_scheme
178
174
)
179
175
reconstructed_dense = {}
180
176
for name , value in reconstructed_dense_gen :
@@ -184,7 +180,7 @@ def test_reload_match(tmp_path, num_bits):
184
180
dense_state_dict ["dummy.weight" ],
185
181
scale = dense_state_dict ["dummy.weight_scale" ],
186
182
zero_point = dense_state_dict ["dummy.weight_zero_point" ],
187
- args = quantized_modules_to_args ["dummy" ],
183
+ args = quantized_modules_to_scheme ["dummy" ]. weights ,
188
184
)
189
185
assert torch .equal (
190
186
fake_quant_dummy , reconstructed_dense ["dummy.weight" ].to (torch .float32 )
@@ -194,7 +190,7 @@ def test_reload_match(tmp_path, num_bits):
194
190
dense_state_dict ["dummy2.weight" ],
195
191
scale = dense_state_dict ["dummy2.weight_scale" ],
196
192
zero_point = dense_state_dict ["dummy2.weight_zero_point" ],
197
- args = quantized_modules_to_args ["dummy2" ],
193
+ args = quantized_modules_to_scheme ["dummy2" ]. weights ,
198
194
)
199
195
assert torch .equal (
200
196
fake_quant_dummy2 , reconstructed_dense ["dummy2.weight" ].to (torch .float32 )
@@ -231,17 +227,17 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration):
231
227
232
228
# compress
233
229
compressor = PackedQuantizationCompressor (config = quant_config )
234
- quantized_modules_to_args = {
235
- "dummy" : quant_config .config_groups ["group_1" ]. weights ,
230
+ quantized_modules_to_scheme = {
231
+ "dummy" : quant_config .config_groups ["group_1" ],
236
232
}
237
233
compressed_state_dict = compressor .compress (
238
- model .state_dict (), names_to_scheme = quantized_modules_to_args
234
+ model .state_dict (), names_to_scheme = quantized_modules_to_scheme
239
235
)
240
236
save_file (compressed_state_dict , tmp_path / "model.safetensors" )
241
237
242
238
# decompress
243
239
reconstructed_dense_gen = compressor .decompress (
244
- tmp_path , names_to_scheme = quantized_modules_to_args
240
+ tmp_path , names_to_scheme = quantized_modules_to_scheme
245
241
)
246
242
reconstructed_dense = {}
247
243
for name , value in reconstructed_dense_gen :
@@ -252,7 +248,7 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration):
252
248
scale = model .dummy .weight_scale ,
253
249
zero_point = model .dummy .weight_zero_point ,
254
250
g_idx = getattr (model .dummy , "weight_g_idx" , None ),
255
- args = quantized_modules_to_args ["dummy" ],
251
+ args = quantized_modules_to_scheme ["dummy" ]. weights ,
256
252
)
257
253
assert torch .equal (fake_quant_dummy , reconstructed_dense ["dummy.weight" ])
258
254
0 commit comments