Skip to content

Commit e554fba

Browse files
authored
[Decompression] Keep unused parameters when decompressing from memory (#340)
* keep unused during decompression

  Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* update docstring and typehint

  Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 56cf39c commit e554fba

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -462,18 +462,13 @@ def decompress_model(self, model: Module):
462462

463463
# quantization second
464464
if prefix in module_to_scheme:
465-
generator = self.quantization_compressor.decompress_from_state_dict(
466-
state_dict,
467-
names_to_scheme=module_to_scheme,
465+
state_dict = (
466+
self.quantization_compressor.decompress_module_from_state_dict(
467+
prefix,
468+
state_dict,
469+
scheme=module_to_scheme[prefix],
470+
)
468471
)
469-
# generates (mod_path, {param_name, param_val})
470-
# of compressed params and used params, but not unused params
471-
# some used params are removed by get_unexpected_file_keys
472-
state_dict = {
473-
merge_names(module_path, param_name): param_value
474-
for module_path, compressed_data in generator
475-
for param_name, param_value in compressed_data.items()
476-
}
477472

478473
# remove any existing parameters
479474
exec_device = get_execution_device(module)

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
get_nested_weight_mappings,
2525
merge_names,
2626
)
27+
from compressed_tensors.utils.safetensors_load import match_param_name
2728
from safetensors import safe_open
2829
from torch import Tensor
2930
from tqdm import tqdm
@@ -223,9 +224,7 @@ def decompress_from_state_dict(
223224
state_dict, self.compression_param_names
224225
)
225226
for module_path in weight_mappings.keys():
226-
weight_data = {}
227-
for param_name, param_value in weight_mappings[module_path].items():
228-
weight_data[param_name] = param_value
227+
weight_data = weight_mappings[module_path].copy()
229228

230229
if "weight_scale" in weight_data:
231230
quant_args = names_to_scheme[module_path].weights
@@ -234,3 +233,31 @@ def decompress_from_state_dict(
234233
)
235234
weight_data["weight"] = decompressed
236235
yield module_path, weight_data
236+
237+
def decompress_module_from_state_dict(
    self,
    prefix: str,
    state_dict: Dict[str, torch.Tensor],
    scheme: QuantizationScheme,
) -> Dict[str, torch.Tensor]:
    """
    Decompress the parameters of a single module for in-memory decompression
    pathways.

    Parameters that are not part of the compressed representation are carried
    through untouched, so unused params survive decompression.

    :param prefix: prefix of state_dict, typically the path to the module
    :param state_dict: state dict containing module parameter values
    :param scheme: quantization scheme of module to decompress
    :return: state dict with weight decompressed if applicable
    """
    dot_prefix = f"{prefix}."

    # work on module-local names while decompressing
    local_params = {}
    for name, tensor in state_dict.items():
        local_params[name.removeprefix(dot_prefix)] = tensor

    # a weight_scale entry marks the weight as compressed/quantized
    if "weight_scale" in local_params:
        local_params["weight"] = self.decompress_weight(
            compressed_data=local_params, quantization_args=scheme.weights
        )

    # restore the fully-qualified parameter names
    return {dot_prefix + name: tensor for name, tensor in local_params.items()}

0 commit comments

Comments (0)