Skip to content

Commit c354826

Browse files
authored
Add: missing and unexpected keys in ModelCompressor (#250)
* Add missing and unexpected keys methods in ModelCompressor * Update missing, unexpected keys to take care of ignores and targets Rename is_sparse_target -> is_target Rename expand_sparse_target -> expand_target since these are more general functions * Add tests * Move private methods below public * Fix nits * Address: Review Comment from @dsikka Signed-off-by: Rahul Tuli <rahul@neuralmagic.com> --------- Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
1 parent 5f24384 commit c354826

File tree

4 files changed

+309
-90
lines changed

4 files changed

+309
-90
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 109 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -19,7 +19,7 @@
1919
import re
2020
from contextlib import contextmanager
2121
from copy import deepcopy
22-
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
22+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
2323

2424
import compressed_tensors
2525
import torch
@@ -39,13 +39,17 @@
3939
apply_quantization_config,
4040
load_pretrained_quantization,
4141
)
42-
from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
42+
from compressed_tensors.quantization.lifecycle import expand_target_names
4343
from compressed_tensors.quantization.quant_args import QuantizationArgs
4444
from compressed_tensors.quantization.utils import (
4545
is_module_quantized,
4646
iter_named_leaf_modules,
4747
)
48-
from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
48+
from compressed_tensors.utils import (
49+
get_safetensors_folder,
50+
merge_names,
51+
update_parameter_data,
52+
)
4953
from compressed_tensors.utils.helpers import (
5054
fix_fsdp_module_name,
5155
is_compressed_tensors_config,
@@ -254,6 +258,107 @@ def __init__(
254258
quantization_config.format, config=quantization_config
255259
)
256260

261+
def get_missing_module_keys(self, model: Module) -> List[str]:
262+
"""
263+
Identifies the expected missing weight keys in the compressed state_dict.
264+
265+
When a model undergoes sparsity or quantization compression, certain
266+
weight tensors may be absent from the checkpoint by virtue of compression.
267+
This function determines which weight keys are missing based on the
268+
applied compression techniques.
269+
270+
271+
:param model: The PyTorch model to check for missing keys.
272+
:return: A list of missing keys expected in the compressed state_dict.
273+
"""
274+
missing_keys = set()
275+
276+
# Determine missing keys due to sparsity compression
277+
if (
278+
self.sparsity_compressor
279+
and self.sparsity_config.format != CompressionFormat.dense.value
280+
):
281+
sparse_targets = expand_target_names(
282+
model=model,
283+
targets=self.sparsity_config.targets,
284+
ignore=self.sparsity_config.ignore,
285+
)
286+
missing_keys.update(
287+
merge_names(target, "weight") for target in sparse_targets
288+
)
289+
290+
# Determine missing keys due to pack quantization
291+
if (
292+
self.quantization_compressor
293+
and self.quantization_config.format
294+
== CompressionFormat.pack_quantized.value
295+
):
296+
for scheme in self.quantization_config.config_groups.values():
297+
quant_targets = expand_target_names(
298+
model=model,
299+
targets=scheme.targets,
300+
ignore=self.quantization_config.ignore,
301+
)
302+
missing_keys.update(
303+
merge_names(target, "weight") for target in quant_targets
304+
)
305+
306+
return list(missing_keys)
307+
308+
def get_unexpected_file_keys(self, model: Module) -> List[str]:
309+
"""
310+
Identifies extra keys introduced by the compression process in the
311+
compressed state_dict that are not expected by the model graph.
312+
313+
During sparsity or quantization compression, additional metadata or
314+
auxiliary parameters may be stored in the checkpoint, which do not
315+
correspond to any parameter in the original model. These keys are
316+
typically introduced to support the reconstruction of compressed weights.
317+
318+
For example, Sparse24Bitmask compression may introduce keys such as
319+
'compressed', 'bitmask', and 'shape' in the checkpoint, which are
320+
not part of the original model parameters.
321+
322+
:param model: The PyTorch model to check for unexpected keys.
323+
:return: A list of extra keys introduced by the compression process
324+
that are not expected by the model.
325+
"""
326+
327+
unexpected_keys = set()
328+
329+
# Identify unexpected keys from sparsity compression
330+
if (
331+
self.sparsity_compressor
332+
and self.sparsity_config.format != CompressionFormat.dense.value
333+
):
334+
sparse_targets: Set[str] = expand_target_names(
335+
model=model,
336+
targets=self.sparsity_config.targets,
337+
ignore=self.sparsity_config.ignore,
338+
)
339+
unexpected_keys.update(
340+
merge_names(target, param)
341+
for target in sparse_targets
342+
for param in self.sparsity_compressor.compression_param_names
343+
)
344+
345+
# Identify unexpected keys from quantization compression
346+
if self.quantization_compressor:
347+
for scheme in self.quantization_config.config_groups.values():
348+
quant_targets: Set[str] = expand_target_names(
349+
model=model,
350+
targets=scheme.targets,
351+
ignore=self.quantization_config.ignore,
352+
)
353+
unexpected_keys.update(
354+
merge_names(target, param)
355+
for target in quant_targets
356+
for param in self.quantization_compressor.compression_param_names
357+
if param != "weight"
358+
)
359+
360+
return list(unexpected_keys)
361+
257362
def compress(
258363
self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
259364
) -> Dict[str, Tensor]:
@@ -283,7 +388,7 @@ def compress(
283388
)
284389

285390
if self.sparsity_compressor is not None:
286-
sparse_compression_targets: Set[str] = expand_sparse_target_names(
391+
sparse_compression_targets: Set[str] = expand_target_names(
287392
model=model,
288393
targets=self.sparsity_config.targets,
289394
ignore=self.sparsity_config.ignore,

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 17 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -52,8 +52,8 @@
5252
"apply_quantization_config",
5353
"apply_quantization_status",
5454
"find_name_or_class_matches",
55-
"expand_sparse_target_names",
56-
"is_sparse_target",
55+
"expand_target_names",
56+
"is_target",
5757
]
5858

5959
from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -247,8 +247,10 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
247247
model.apply(compress_quantized_weights)
248248

249249

250-
def expand_sparse_target_names(
251-
model: Module, targets: Iterable[str], ignore: Iterable[str]
250+
def expand_target_names(
251+
model: Module,
252+
targets: Optional[Iterable[str]] = None,
253+
ignore: Optional[Iterable[str]] = None,
252254
) -> Set[str]:
253255
"""
254256
Finds all unique module names in the model that match the given
@@ -257,20 +259,23 @@ def expand_sparse_target_names(
257259
Note: Targets must be regexes, layer types, or full layer names.
258260
259261
:param model: model to search for targets in
260-
:param targets: list of targets to search for
261-
:param ignore: list of targets to ignore
262+
:param targets: Iterable of targets to search for
263+
:param ignore: Iterable of targets to ignore
262264
:return: set of all targets that match the given targets and should
263265
not be ignored
264266
"""
265267
return {
266268
name
267269
for name, module in iter_named_leaf_modules(model)
268-
if is_sparse_target(name, module, targets, ignore)
270+
if is_target(name, module, targets, ignore)
269271
}
270272

271273

272-
def is_sparse_target(
273-
name: str, module: Module, targets: Iterable[str], ignore: Iterable[str]
274+
def is_target(
275+
name: str,
276+
module: Module,
277+
targets: Optional[Iterable[str]] = None,
278+
ignore: Optional[Iterable[str]] = None,
274279
) -> bool:
275280
"""
276281
Determines if a module should be included in the targets based on the
@@ -280,12 +285,12 @@ def is_sparse_target(
280285
281286
:param name: name of the module
282287
:param module: the module itself
283-
:param targets: list of targets to search for
284-
:param ignore: list of targets to ignore
288+
:param targets: Iterable of targets to search for
289+
:param ignore: Iterable of targets to ignore
285290
:return: True if the module is a target and not ignored, False otherwise
286291
"""
287292
return bool(
288-
find_name_or_class_matches(name, module, targets)
293+
find_name_or_class_matches(name, module, targets or [])
289294
and not find_name_or_class_matches(name, module, ignore or [])
290295
)
291296

0 commit comments

Comments (0)