Skip to content

Commit ac27709

Browse files
committed
fix tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 9524c7f commit ac27709

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

tests/test_compressors/quantized_compressors/test_pack_quant.py

Lines changed: 15 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@
3939

4040
def get_dummy_quant_config(
4141
num_bits=4, strategy=None, group_size=None, actorder=None, symmetric=True
42-
):
42+
) -> QuantizationConfig:
4343
config_groups = {
4444
"group_1": QuantizationScheme(
4545
targets=["Linear"],
@@ -81,9 +81,9 @@ def test_quant_format(shape):
8181
quant_config = get_dummy_quant_config()
8282

8383
compressor = PackedQuantizationCompressor(config=quant_config)
84-
quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
84+
quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]}
8585
compressed_state_dict = compressor.compress(
86-
dense_state_dict, names_to_scheme=quantized_modules_to_args
86+
dense_state_dict, names_to_scheme=quantized_modules_to_scheme
8787
)
8888

8989
# compressed state_dict adds one entry for shape
@@ -156,25 +156,21 @@ def test_reload_match(tmp_path, num_bits):
156156

157157
# pack-compressor only needs the number of bits from the quant-args to decompress
158158
# all other information is extracted from the compressed data directly
159-
names_to_scheme = {
160-
"dummy": QuantizationArgs(num_bits=num_bits),
161-
"dummy2": QuantizationArgs(num_bits=num_bits),
162-
}
163159
quant_config = get_dummy_quant_config(num_bits, symmetric=False)
164160

165161
compressor = PackedQuantizationCompressor(config=quant_config)
166-
quantized_modules_to_args = {
167-
"dummy": quant_config.config_groups["group_1"].weights,
168-
"dummy2": quant_config.config_groups["group_1"].weights,
162+
quantized_modules_to_scheme = {
163+
"dummy": quant_config.config_groups["group_1"],
164+
"dummy2": quant_config.config_groups["group_1"],
169165
}
170166

171167
compressed_state_dict = compressor.compress(
172-
dense_state_dict, names_to_scheme=quantized_modules_to_args
168+
dense_state_dict.copy(), names_to_scheme=quantized_modules_to_scheme
173169
)
174170
save_file(compressed_state_dict, tmp_path / "model.safetensors")
175171

176172
reconstructed_dense_gen = compressor.decompress(
177-
tmp_path, names_to_scheme=names_to_scheme
173+
tmp_path, names_to_scheme=quantized_modules_to_scheme
178174
)
179175
reconstructed_dense = {}
180176
for name, value in reconstructed_dense_gen:
@@ -184,7 +180,7 @@ def test_reload_match(tmp_path, num_bits):
184180
dense_state_dict["dummy.weight"],
185181
scale=dense_state_dict["dummy.weight_scale"],
186182
zero_point=dense_state_dict["dummy.weight_zero_point"],
187-
args=quantized_modules_to_args["dummy"],
183+
args=quantized_modules_to_scheme["dummy"].weights,
188184
)
189185
assert torch.equal(
190186
fake_quant_dummy, reconstructed_dense["dummy.weight"].to(torch.float32)
@@ -194,7 +190,7 @@ def test_reload_match(tmp_path, num_bits):
194190
dense_state_dict["dummy2.weight"],
195191
scale=dense_state_dict["dummy2.weight_scale"],
196192
zero_point=dense_state_dict["dummy2.weight_zero_point"],
197-
args=quantized_modules_to_args["dummy2"],
193+
args=quantized_modules_to_scheme["dummy2"].weights,
198194
)
199195
assert torch.equal(
200196
fake_quant_dummy2, reconstructed_dense["dummy2.weight"].to(torch.float32)
@@ -231,17 +227,17 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration):
231227

232228
# compress
233229
compressor = PackedQuantizationCompressor(config=quant_config)
234-
quantized_modules_to_args = {
235-
"dummy": quant_config.config_groups["group_1"].weights,
230+
quantized_modules_to_scheme = {
231+
"dummy": quant_config.config_groups["group_1"],
236232
}
237233
compressed_state_dict = compressor.compress(
238-
model.state_dict(), names_to_scheme=quantized_modules_to_args
234+
model.state_dict(), names_to_scheme=quantized_modules_to_scheme
239235
)
240236
save_file(compressed_state_dict, tmp_path / "model.safetensors")
241237

242238
# decompress
243239
reconstructed_dense_gen = compressor.decompress(
244-
tmp_path, names_to_scheme=quantized_modules_to_args
240+
tmp_path, names_to_scheme=quantized_modules_to_scheme
245241
)
246242
reconstructed_dense = {}
247243
for name, value in reconstructed_dense_gen:
@@ -252,7 +248,7 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration):
252248
scale=model.dummy.weight_scale,
253249
zero_point=model.dummy.weight_zero_point,
254250
g_idx=getattr(model.dummy, "weight_g_idx", None),
255-
args=quantized_modules_to_args["dummy"],
251+
args=quantized_modules_to_scheme["dummy"].weights,
256252
)
257253
assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy.weight"])
258254

0 commit comments

Comments (0)