@@ -76,9 +76,9 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
76
76
)
77
77
78
78
compressor = IntQuantizationCompressor (config = quant_config )
79
- quantized_modules_to_args = {"dummy" : quant_config .config_groups ["group_1" ]. weights }
79
+ quantized_modules_to_scheme = {"dummy" : quant_config .config_groups ["group_1" ]}
80
80
compressed_state_dict = compressor .compress (
81
- dense_state_dict , names_to_scheme = quantized_modules_to_args
81
+ dense_state_dict , names_to_scheme = quantized_modules_to_scheme
82
82
)
83
83
84
84
# state_dict params should be the same, minus the zero_point if symmetric
@@ -124,16 +124,16 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path):
124
124
quant_config = get_dummy_quant_config (strategy = strategy , group_size = group_size )
125
125
126
126
compressor = IntQuantizationCompressor (config = quant_config )
127
- quantized_modules_to_args = {
128
- "dummy" : quant_config .config_groups ["group_1" ]. weights ,
129
- "dummy2" : quant_config .config_groups ["group_1" ]. weights ,
127
+ module_name_to_scheme = {
128
+ "dummy" : quant_config .config_groups ["group_1" ],
129
+ "dummy2" : quant_config .config_groups ["group_1" ],
130
130
}
131
131
compressed_state_dict = compressor .compress (
132
- dense_state_dict , names_to_scheme = quantized_modules_to_args
132
+ dense_state_dict , names_to_scheme = module_name_to_scheme
133
133
)
134
134
save_file (compressed_state_dict , tmp_path / "model.safetensors" )
135
135
reconstructed_dense_gen = compressor .decompress (
136
- tmp_path , names_to_scheme = quantized_modules_to_args
136
+ tmp_path , names_to_scheme = module_name_to_scheme
137
137
)
138
138
reconstructed_dense = {}
139
139
for name , value in reconstructed_dense_gen :
@@ -143,7 +143,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path):
143
143
dense_state_dict ["dummy.weight" ],
144
144
scale = dense_state_dict ["dummy.weight_scale" ],
145
145
zero_point = dense_state_dict ["dummy.weight_zero_point" ],
146
- args = quantized_modules_to_args ["dummy" ],
146
+ args = module_name_to_scheme ["dummy" ]. weights ,
147
147
)
148
148
assert torch .equal (
149
149
fake_quant_dummy , reconstructed_dense ["dummy" ].get ("weight" ).to (torch .float32 )
@@ -153,7 +153,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path):
153
153
dense_state_dict ["dummy2.weight" ],
154
154
scale = dense_state_dict ["dummy2.weight_scale" ],
155
155
zero_point = dense_state_dict ["dummy2.weight_zero_point" ],
156
- args = quantized_modules_to_args ["dummy2" ],
156
+ args = module_name_to_scheme ["dummy2" ]. weight ,
157
157
)
158
158
assert torch .equal (
159
159
fake_quant_dummy2 , reconstructed_dense ["dummy2" ].get ("weight" ).to (torch .float32 )
0 commit comments