@@ -19,7 +19,7 @@
 import re
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
@@ -39,13 +39,17 @@
     apply_quantization_config,
     load_pretrained_quantization,
 )
-from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
+from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
 )
-from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
+from compressed_tensors.utils import (
+    get_safetensors_folder,
+    merge_names,
+    update_parameter_data,
+)
 from compressed_tensors.utils.helpers import (
     fix_fsdp_module_name,
     is_compressed_tensors_config,
@@ -254,6 +258,107 @@ def __init__(
                 quantization_config.format, config=quantization_config
             )
 
+    def get_missing_module_keys(self, model: Module) -> List[str]:
+        """
+        Identifies the expected missing weight keys in the compressed state_dict.
+
+        When a model undergoes sparsity or quantization compression, certain
+        weight tensors may be absent from the checkpoint by virtue of compression.
+        This function determines which weight keys are missing based on the
+        applied compression techniques.
+
+
+        :param model: The PyTorch model to check for missing keys.
+        :return: A list of missing keys expected in the compressed state_dict.
+        """
+        missing_keys = set()
+
+        # Determine missing keys due to sparsity compression
+        if (
+            self.sparsity_compressor
+            and self.sparsity_config.format != CompressionFormat.dense.value
+        ):
+            sparse_targets = expand_target_names(
+                model=model,
+                targets=self.sparsity_config.targets,
+                ignore=self.sparsity_config.ignore,
+            )
+            missing_keys.update(
+                merge_names(target, "weight") for target in sparse_targets
+            )
+
+        # Determine missing keys due to pack quantization
+        if (
+            self.quantization_compressor
+            and self.quantization_config.format
+            == CompressionFormat.pack_quantized.value
+        ):
+            for scheme in self.quantization_config.config_groups.values():
+                quant_targets = expand_target_names(
+                    model=model,
+                    targets=scheme.targets,
+                    ignore=self.quantization_config.ignore,
+                )
+                missing_keys.update(
+                    merge_names(target, "weight") for target in quant_targets
+                )
+
+        return list(missing_keys)
+
+    def get_unexpected_file_keys(self, model: Module) -> List[str]:
+        """
+        Identifies extra keys introduced by the compression process in the
+        compressed state_dict that are not expected by the model graph.
+
+        During sparsity or quantization compression, additional metadata or
+        auxiliary parameters may be stored in the checkpoint, which do not
+        correspond to any parameter in the original model. These keys are
+        typically introduced to support the reconstruction of compressed weights.
+
+        For example, Sparse24Bitmask compression may introduce keys such as
+        'compressed', 'bitmask', and 'shape' in the checkpoint, which are
+        not part of the original model parameters.
+
+        :param model: The PyTorch model to check for unexpected keys.
+        :return: A list of extra keys introduced by the compression process
+            that are not expected by the model.
+        """
+
+        unexpected_keys = set()
+
+        # Identify unexpected keys from sparsity compression
+        if (
+            self.sparsity_compressor
+            and self.sparsity_config.format != CompressionFormat.dense.value
+        ):
+            sparse_targets: Set[str] = expand_target_names(
+                model=model,
+                targets=self.sparsity_config.targets,
+                ignore=self.sparsity_config.ignore,
+            )
+            unexpected_keys.update(
+                merge_names(target, param)
+                for target in sparse_targets
+                for param in self.sparsity_compressor.compression_param_names
+            )
+
+        # Identify unexpected keys from quantization compression
+        if self.quantization_compressor:
+            for scheme in self.quantization_config.config_groups.values():
+                quant_targets: Set[str] = expand_target_names(
+                    model=model,
+                    targets=scheme.targets,
+                    ignore=self.quantization_config.ignore,
+                )
+                unexpected_keys.update(
+                    merge_names(target, param)
+                    for target in quant_targets
+                    for param in self.quantization_compressor.compression_param_names
+                    if param != "weight"
+                )
+
+        return list(unexpected_keys)
+
     def compress(
         self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
     ) -> Dict[str, Tensor]:
@@ -283,7 +388,7 @@ def compress(
         )
 
         if self.sparsity_compressor is not None:
-            sparse_compression_targets: Set[str] = expand_sparse_target_names(
+            sparse_compression_targets: Set[str] = expand_target_names(
                 model=model,
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
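For a single targeted module, the key bookkeeping in the two helpers works out as sketched below. This is an illustration only: it assumes merge_names(a, b) produces the dotted join f"{a}.{b}" (consistent with how the diff uses it to build checkpoint keys) and reuses the Sparse24Bitmask parameter names from the docstring's example; the module name is made up.

# One hypothetical targeted module inside a transformer block
target = "model.layers.0.self_attn.q_proj"

# get_missing_module_keys: the dense weight itself is absent from the
# compressed checkpoint
missing = [f"{target}.weight"]

# get_unexpected_file_keys: compression parameters appear in its place
# (names taken from the Sparse24Bitmask example in the docstring)
compression_param_names = ["compressed", "bitmask", "shape"]
unexpected = [f"{target}.{param}" for param in compression_param_names]

print(missing)     # ['model.layers.0.self_attn.q_proj.weight']
print(unexpected)  # ['model.layers.0.self_attn.q_proj.compressed', ...]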