     iter_named_leaf_modules,
 )
 from compressed_tensors.utils import (
+    align_module_device,
+    delete_offload_parameter,
+    get_execution_device,
     get_safetensors_folder,
     has_offloaded_params,
     merge_names,
@@ -98,6 +101,9 @@ class ModelCompressor:
     :param quantization_config: config specifying quantization compression parameters
     """

+    sparsity_config: Optional[SparsityCompressionConfig] = None
+    quantization_config: Optional[QuantizationConfig] = None
+
     @classmethod
     def from_pretrained(
         cls,
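A quick sketch (not part of the diff) of what the new class-level defaults buy: code that receives a ModelCompressor, such as the HF quantizer pathway below, can read either config without guarding against a missing attribute.

    # `compressor` is any ModelCompressor instance (hypothetical caller)
    if compressor.sparsity_config is None:
        print("no sparsity compression configured")
    if compressor.quantization_config is not None:
        print(f"quantization format: {compressor.quantization_config.format}")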
@@ -261,6 +267,8 @@ def __init__(
                 quantization_config.format, config=quantization_config
             )

+    # ----- used by hf quantizer ----- #
+
     def get_missing_module_keys(self, model: Module) -> List[str]:
         """
         Identifies the expected missing weight keys in the compressed state_dict.
@@ -270,7 +278,6 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
         This function determines which weight keys are missing based on the
         applied compression techniques.

-
         :param model: The PyTorch model to check for missing keys.
         :return: A list of missing keys expected in the compressed state_dict.
         """
@@ -362,8 +369,124 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:

         return list(unexpected_keys)

+    # ----- model memory compression/decompression pathways ----- #
+
+    def compress_model(self, model: Module):
+        """
+        Compress a model in memory. Because the model structure is modified in place,
+        this method is more memory-efficient than `self.compress`.
+
+        :param model: model containing parameters to compress
+        """
+        module_to_scheme = map_module_to_scheme(model)
+        sparse_compression_targets: Set[str] = expand_target_names(
+            model=model,
+            targets=self.sparsity_config.targets if self.sparsity_config else [],
+            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+        )
+
+        for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
+            if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                # in the future, support compression on the same device
+                with align_module_device(module, execution_device="cpu"):
+                    state_dict = module.state_dict(prefix=f"{prefix}.")
+
+                # quantization first
+                if prefix in module_to_scheme:
+                    state_dict = self.quantization_compressor.compress(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                        show_progress=False,
+                    )
+
+                # sparsity second
+                if prefix in sparse_compression_targets:
+                    state_dict = self.sparsity_compressor.compress(
+                        state_dict,
+                        compression_targets=sparse_compression_targets,
+                        show_progress=False,
+                    )
+
+                # remove any existing parameters
+                device = get_execution_device(module)
+                for name, _ in list(module.named_parameters()):
+                    delattr(module, name)
+
+                # replace with compressed parameters
+                for name, value in state_dict.items():
+                    name = name.removeprefix(f"{prefix}.")
+                    value = value.to(device)
+                    param = torch.nn.Parameter(value, requires_grad=False)
+                    register_offload_parameter(module, name, param)
+
+                module.quantization_status = QuantizationStatus.COMPRESSED
+
+    def decompress_model(self, model: Module):
+        """
+        Decompress a model in memory. Because the model structure is modified in place,
+        this method does not require loading compression parameters from disk.
+
+        :param model: model containing parameters to decompress
+        """
+        module_to_scheme = map_module_to_scheme(model)
+        sparse_compression_targets: Set[str] = expand_target_names(
+            model=model,
+            targets=self.sparsity_config.targets if self.sparsity_config else [],
+            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+        )
+
+        for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
+            if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                # in the future, support decompression on the same device
+                with align_module_device(module, execution_device="cpu"):
+                    state_dict = module.state_dict(prefix=f"{prefix}.")
+
+                # sparsity first
+                if prefix in sparse_compression_targets:
+                    # sparse_compression_targets are automatically inferred by this fn
+                    generator = self.sparsity_compressor.decompress_from_state_dict(
+                        state_dict,
+                    )
+                    # generates (param_path, param_val)
+                    # of compressed and unused params
+                    state_dict = {key: value for key, value in generator}
+
+                # quantization second
+                if prefix in module_to_scheme:
+                    generator = self.quantization_compressor.decompress_from_state_dict(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                    )
+                    # generates (mod_path, {param_name: param_val})
+                    # of compressed params and used params, but not unused params
+                    # some used params are removed by get_unexpected_file_keys
+                    state_dict = {
+                        merge_names(module_path, param_name): param_value
+                        for module_path, compressed_data in generator
+                        for param_name, param_value in compressed_data.items()
+                    }
+
+                # remove any existing parameters
+                device = get_execution_device(module)
+                for name, _ in list(module.named_parameters()):
+                    delete_offload_parameter(module, name)
+
+                # replace with decompressed parameters
+                for name, value in state_dict.items():
+                    name = name.removeprefix(f"{prefix}.")
+                    value = value.to(device)
+                    param = torch.nn.Parameter(value, requires_grad=False)
+                    register_offload_parameter(module, name, param)
+
+                module.quantization_status = QuantizationStatus.FROZEN
+
+    # ----- state dict compression pathways ----- #
+
     def compress(
-        self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
+        self,
+        model: Module,
+        state_dict: Optional[Dict[str, Tensor]] = None,
+        show_progress: bool = False,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict or model with sparsity and/or quantization
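A usage sketch for the two in-memory pathways added above (not part of the diff; `compressor` and `model` are assumed to already exist). Both calls mutate the module tree in place, so no second copy of the weights is materialized:

    compressor.compress_model(model)    # params swapped for compressed tensors
    # ... serialize the now-compressed state_dict, move devices, etc. ...
    compressor.decompress_model(model)  # params restored, status set to FROZEN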
@@ -379,7 +502,9 @@ def compress(
         if self.quantization_compressor is not None:
             module_to_scheme = map_module_to_scheme(model)
             state_dict = self.quantization_compressor.compress(
-                state_dict, names_to_scheme=module_to_scheme
+                state_dict,
+                names_to_scheme=module_to_scheme,
+                show_progress=show_progress,
             )

         # TODO: consider sparse compression to also be compression
@@ -397,6 +522,7 @@ def compress(
             state_dict = self.sparsity_compressor.compress(
                 state_dict,
                 compression_targets=sparse_compression_targets,
+                show_progress=show_progress,
             )

         # HACK: Override the dtype_byte_size function in transformers to
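Sketch of the new flag in use (not part of the diff): `show_progress` is threaded through both compressors, so a single argument toggles their tqdm bars during state-dict compression.

    compressed_state_dict = compressor.compress(model, show_progress=True)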
@@ -406,6 +532,8 @@ def compress(

         return state_dict

+    # ----- disk decompression pathways ----- #
+
     def decompress(self, model_path: str, model: Module):
         """
         Overwrites the weights in model with weights decompressed from model_path
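A sketch of the disk pathway under this divider (not part of the diff; the checkpoint path is illustrative): unlike `decompress_model`, this reads compressed weights from a saved folder and writes the decompressed values into `model`.

    compressor.decompress("path/to/compressed_checkpoint", model)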