     iter_named_leaf_modules,
 )
 from compressed_tensors.utils import (
+    align_module_device,
+    delete_offload_parameter,
+    get_execution_device,
     get_safetensors_folder,
     has_offloaded_params,
     merge_names,
@@ -98,6 +101,9 @@ class ModelCompressor:
     :param quantization_config: config specifying quantization compression parameters
     """
 
+    sparsity_config: Optional[SparsityCompressionConfig] = None
+    quantization_config: Optional[QuantizationConfig] = None
+
     @classmethod
     def from_pretrained(
         cls,
@@ -261,6 +267,8 @@ def __init__(
                 quantization_config.format, config=quantization_config
             )
 
+    # ----- used by hf quantizer ----- #
+
     def get_missing_module_keys(self, model: Module) -> List[str]:
         """
         Identifies the expected missing weight keys in the compressed state_dict.
@@ -270,7 +278,6 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
         This function determines which weight keys are missing based on the
         applied compression techniques.
 
-
         :param model: The PyTorch model to check for missing keys.
         :return: A list of missing keys expected in the compressed state_dict.
         """
@@ -362,8 +369,124 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
 
         return list(unexpected_keys)
 
+    # ----- model memory compression/decompression pathways ----- #
+
+    def compress_model(self, model: Module):
+        """
+        Compress a model in memory. Because the model structure is modified in
+        place, this method is more memory-efficient than `self.compress`.
+
+        :param model: model containing parameters to compress
+        """
+        module_to_scheme = map_module_to_scheme(model)
+        sparse_compression_targets: Set[str] = expand_target_names(
+            model=model,
+            targets=self.sparsity_config.targets if self.sparsity_config else [],
+            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+        )
+
+        for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
+            if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                # in the future, support compression on same device
+                with align_module_device(module, execution_device="cpu"):
+                    state_dict = module.state_dict(prefix=f"{prefix}.")
+
+                # quantization first
+                if prefix in module_to_scheme:
+                    state_dict = self.quantization_compressor.compress(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                        show_progress=False,
+                    )
+
+                # sparsity second
+                if prefix in sparse_compression_targets:
+                    state_dict = self.sparsity_compressor.compress(
+                        state_dict,
+                        compression_targets=sparse_compression_targets,
+                        show_progress=False,
+                    )
+
+                # remove any existing parameters
+                device = get_execution_device(module)
+                for name, _ in list(module.named_parameters()):
+                    delattr(module, name)
+
+                # replace with compressed parameters
+                for name, value in state_dict.items():
+                    name = name.removeprefix(f"{prefix}.")
+                    value = value.to(device)
+                    param = torch.nn.Parameter(value, requires_grad=False)
+                    register_offload_parameter(module, name, param)
+
+                module.quantization_status = QuantizationStatus.COMPRESSED
+    def decompress_model(self, model: Module):
+        """
+        Decompress a model in memory. Because the model structure is modified in
+        place, this method does not require loading some compression parameters
+        from disk.
+
+        :param model: model containing parameters to decompress
+        """
+        module_to_scheme = map_module_to_scheme(model)
+        sparse_compression_targets: Set[str] = expand_target_names(
+            model=model,
+            targets=self.sparsity_config.targets if self.sparsity_config else [],
+            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+        )
+
+        for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
+            if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                # in the future, support decompression on same device
+                with align_module_device(module, execution_device="cpu"):
+                    state_dict = module.state_dict(prefix=f"{prefix}.")
+
+                # sparsity first
+                if prefix in sparse_compression_targets:
+                    # sparse_compression_targets are automatically inferred by this fn
+                    generator = self.sparsity_compressor.decompress_from_state_dict(
+                        state_dict,
+                    )
+                    # generates (param_path, param_val)
+                    # of compressed and unused params
+                    state_dict = {key: value for key, value in generator}
+
+                # quantization second
+                if prefix in module_to_scheme:
+                    generator = self.quantization_compressor.decompress_from_state_dict(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                    )
+                    # generates (mod_path, {param_name: param_val})
+                    # of compressed params and used params, but not unused params
+                    # some used params are removed by get_unexpected_file_keys
+                    state_dict = {
+                        merge_names(module_path, param_name): param_value
+                        for module_path, compressed_data in generator
+                        for param_name, param_value in compressed_data.items()
+                    }
+
+                # remove any existing parameters
+                device = get_execution_device(module)
+                for name, _ in list(module.named_parameters()):
+                    delete_offload_parameter(module, name)
+
+                # replace with decompressed parameters
+                for name, value in state_dict.items():
+                    name = name.removeprefix(f"{prefix}.")
+                    value = value.to(device)
+                    param = torch.nn.Parameter(value, requires_grad=False)
+                    register_offload_parameter(module, name, param)
+
+                module.quantization_status = QuantizationStatus.FROZEN
+
+    # ----- state dict compression pathways ----- #
+
     def compress(
-        self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
+        self,
+        model: Module,
+        state_dict: Optional[Dict[str, Tensor]] = None,
+        show_progress: bool = False,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict or model with sparsity and/or quantization
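For reviewers, a minimal usage sketch of the new in-memory round trip. The checkpoint paths and the `AutoModelForCausalLM` loader are hypothetical; `ModelCompressor.from_pretrained`, `compress_model`, and `decompress_model` are the APIs from this diff, and `from_pretrained` reads the sparsity/quantization configs from the checkpoint:

```python
from transformers import AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor

model = AutoModelForCausalLM.from_pretrained("org/quantized-model")  # hypothetical path
compressor = ModelCompressor.from_pretrained("org/quantized-model")  # may be None if uncompressed

compressor.compress_model(model)    # replaces params in place; modules marked COMPRESSED
model.save_pretrained("./compressed")

compressor.decompress_model(model)  # restores dense params in place; modules marked FROZEN
```

Because each module's state dict is compressed and swapped one module at a time, peak memory stays near the size of a single layer rather than the whole model.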
@@ -379,7 +502,9 @@ def compress(
         if self.quantization_compressor is not None:
             module_to_scheme = map_module_to_scheme(model)
             state_dict = self.quantization_compressor.compress(
-                state_dict, names_to_scheme=module_to_scheme
+                state_dict,
+                names_to_scheme=module_to_scheme,
+                show_progress=show_progress,
             )
 
         # TODO: consider sparse compression to also be compression
@@ -397,6 +522,7 @@ def compress(
             state_dict = self.sparsity_compressor.compress(
                 state_dict,
                 compression_targets=sparse_compression_targets,
+                show_progress=show_progress,
             )
 
         # HACK: Override the dtype_byte_size function in transformers to
@@ -406,6 +532,8 @@ def compress(
 
         return state_dict
 
+    # ----- disk decompression pathways ----- #
+
     def decompress(self, model_path: str, model: Module):
         """
         Overwrites the weights in model with weights decompressed from model_path
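For contrast with the in-memory `decompress_model` above, the existing disk pathway takes a checkpoint directory (path hypothetical) and overwrites the model's weights with their decompressed values:

```python
# Disk pathway: decompress weights read from a saved checkpoint directory.
compressor.decompress("./compressed", model=model)
```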