
Commit 8ae9004

fix tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent: 700d4b6

2 files changed: +16 −16 lines


tests/test_compressors/quantized_compressors/test_fp8_quant.py

Lines changed: 7 additions & 7 deletions

@@ -84,9 +84,9 @@ def test_quant_format(strategy, group_size, sc, zp):
     quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)

     compressor = FloatQuantizationCompressor(config=quant_config)
-    quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
+    module_name_to_scheme = {"dummy": quant_config.config_groups["group_1"]}
     compressed_state_dict = compressor.compress(
-        dense_state_dict, names_to_scheme=quantized_modules_to_args
+        dense_state_dict, names_to_scheme=module_name_to_scheme
     )

     # state_dict params should be the same, minus the zero_point if symmetric
@@ -140,15 +140,15 @@ def test_reload_match(
     )

     compressor = FloatQuantizationCompressor(config=quant_config)
-    quantized_modules_to_args = {
-        "dummy": quant_config.config_groups["group_1"].weights,
+    module_name_to_scheme = {
+        "dummy": quant_config.config_groups["group_1"],
     }
     compressed_state_dict = compressor.compress(
-        model.state_dict(), names_to_scheme=quantized_modules_to_args
+        model.state_dict(), names_to_scheme=module_name_to_scheme
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
     reconstructed_dense_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_args
+        tmp_path, names_to_scheme=module_name_to_scheme
     )
     reconstructed_dense = {}
     for name, value in reconstructed_dense_gen:
@@ -158,7 +158,7 @@ def test_reload_match(
         model.dummy.weight,
         scale=model.dummy.weight_scale,
         zero_point=model.dummy.weight_zero_point,
-        args=quantized_modules_to_args["dummy"],
+        args=module_name_to_scheme["dummy"].weights,
     )
     assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy"].get("weight"))
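For reference, the shape of the mapping is what changes here: `names_to_scheme` used to receive per-module QuantizationArgs (the `.weights` of a config group) and now receives the full scheme object. Below is a minimal sketch of the new round trip; it reuses the `quant_config`, `dense_state_dict`, and `tmp_path` fixtures that the test above builds (assumed here, not redefined), and the import paths are assumptions rather than excerpts from the test file.

# Imports assumed for this sketch; the test module has its own equivalents.
from safetensors.torch import save_file
from compressed_tensors.compressors import FloatQuantizationCompressor

compressor = FloatQuantizationCompressor(config=quant_config)

# Old convention: module name -> QuantizationArgs for the weights
#   {"dummy": quant_config.config_groups["group_1"].weights}
# New convention: module name -> the full QuantizationScheme (config group)
module_name_to_scheme = {"dummy": quant_config.config_groups["group_1"]}

# compress() and decompress() both take the scheme mapping via names_to_scheme.
compressed_state_dict = compressor.compress(
    dense_state_dict, names_to_scheme=module_name_to_scheme
)
save_file(compressed_state_dict, tmp_path / "model.safetensors")

# decompress() yields (module_name, dense_params) pairs for the same mapping.
reconstructed_dense = dict(
    compressor.decompress(tmp_path, names_to_scheme=module_name_to_scheme)
)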

tests/test_compressors/quantized_compressors/test_int_quant.py

Lines changed: 9 additions & 9 deletions

@@ -76,9 +76,9 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
     )

     compressor = IntQuantizationCompressor(config=quant_config)
-    quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
+    quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]}
     compressed_state_dict = compressor.compress(
-        dense_state_dict, names_to_scheme=quantized_modules_to_args
+        dense_state_dict, names_to_scheme=quantized_modules_to_scheme
     )

     # state_dict params should be the same, minus the zero_point if symmetric
@@ -124,16 +124,16 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path):
     quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)

     compressor = IntQuantizationCompressor(config=quant_config)
-    quantized_modules_to_args = {
-        "dummy": quant_config.config_groups["group_1"].weights,
-        "dummy2": quant_config.config_groups["group_1"].weights,
+    module_name_to_scheme = {
+        "dummy": quant_config.config_groups["group_1"],
+        "dummy2": quant_config.config_groups["group_1"],
     }
     compressed_state_dict = compressor.compress(
-        dense_state_dict, names_to_scheme=quantized_modules_to_args
+        dense_state_dict, names_to_scheme=module_name_to_scheme
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
     reconstructed_dense_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_args
+        tmp_path, names_to_scheme=module_name_to_scheme
     )
     reconstructed_dense = {}
     for name, value in reconstructed_dense_gen:
@@ -143,7 +143,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path):
         dense_state_dict["dummy.weight"],
         scale=dense_state_dict["dummy.weight_scale"],
         zero_point=dense_state_dict["dummy.weight_zero_point"],
-        args=quantized_modules_to_args["dummy"],
+        args=module_name_to_scheme["dummy"].weights,
     )
     assert torch.equal(
         fake_quant_dummy, reconstructed_dense["dummy"].get("weight").to(torch.float32)
@@ -153,7 +153,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path):
         dense_state_dict["dummy2.weight"],
         scale=dense_state_dict["dummy2.weight_scale"],
         zero_point=dense_state_dict["dummy2.weight_zero_point"],
-        args=quantized_modules_to_args["dummy2"],
+        args=module_name_to_scheme["dummy2"].weights,
     )
     assert torch.equal(
         fake_quant_dummy2, reconstructed_dense["dummy2"].get("weight").to(torch.float32)
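Call sites that previously received QuantizationArgs straight out of the mapping now read them off the scheme via `.weights`. A sketch of that access pattern for the reload check follows; the helper name `fake_quantize` is an assumption (the diff only shows its keyword arguments), and the fixtures come from the test above.

# Recover the weight QuantizationArgs from the scheme before fake-quantizing;
# previously the mapping value was the args object itself.
scheme = module_name_to_scheme["dummy"]
weight_args = scheme.weights

fake_quant_dummy = fake_quantize(
    dense_state_dict["dummy.weight"],
    scale=dense_state_dict["dummy.weight_scale"],
    zero_point=dense_state_dict["dummy.weight_zero_point"],
    args=weight_args,
)
assert torch.equal(
    fake_quant_dummy, reconstructed_dense["dummy"].get("weight").to(torch.float32)
)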
