     SPARSITY_CONFIG_NAME,
 )
 from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.sparse_compressors import DenseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.linear.compressed_linear import CompressedLinear
 from compressed_tensors.quantization import (
     QuantizationScheme,
     QuantizationStatus,
     apply_quantization_config,
-    load_pretrained_quantization,
+    load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import (
 )
 from compressed_tensors.utils import (
     get_safetensors_folder,
+    has_offloaded_params,
     merge_names,
     module_replace_dfs,
+    register_offload_parameter,
     update_parameter_data,
 )
 from compressed_tensors.utils.helpers import (
@@ -448,6 +451,13 @@ def decompress(self, model_path: str, model: Module):
 
         :param model_path: path to compressed weights
         :param model: pytorch model to load decompressed weights into
+
+        Note: decompress makes use of both _replace_sparsity_weights and _replace_weights.
+        The differences between these methods stem from subtle differences between the
+        sparsity and quantization compressors. Specifically, quantization compressors return
+        not only the decompressed weight but also the quantization parameters (e.g. scales,
+        zero_point), whereas sparsity compressors return only the decompressed weight.
+
         """
         model_path = get_safetensors_folder(model_path)
         sparse_decompressed = False
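
A minimal sketch (not part of this diff) of the two generator shapes the new docstring note describes; the module names, parameter names, and tensor shapes below are illustrative placeholders rather than anything taken from this PR:

import torch

def example_sparsity_decompress():
    # sparsity compressors yield (param_name, decompressed_weight) pairs,
    # which is why _replace_sparsity_weights splits the name into prefix/param
    yield "layers.0.linear.weight", torch.zeros(4, 4)

def example_quantization_decompress():
    # quantization compressors yield (module_name, dict_of_params) pairs that
    # also carry quantization parameters, which is why _replace_weights
    # iterates over data.items()
    yield "layers.0.linear", {
        "weight": torch.zeros(4, 4, dtype=torch.bfloat16),
        "weight_scale": torch.ones(4, 1),
        "weight_zero_point": torch.zeros(4, 1, dtype=torch.int8),
    }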
@@ -456,9 +466,16 @@ def decompress(self, model_path: str, model: Module):
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
+            params_to_ignore = None
+            if self.quantization_compressor is not None:
+                params_to_ignore = self.quantization_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
-            dense_gen = self.sparsity_compressor.decompress(model_path)
-            self._replace_weights(dense_gen, model)
+            # The compressor will also try to load any quantization parameters;
+            # params_to_skip_load prevents those from being loaded here
+            dense_gen = self.sparsity_compressor.decompress(
+                model_path, params_to_skip_load=params_to_ignore
+            )
+            self._replace_sparsity_weights(dense_gen, model)
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
             sparse_decompressed = True
 
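
The hunk above passes the quantization compressor's compression_param_names so that sparse decompression leaves quantization parameters alone. A rough sketch of that filtering idea, assuming a plain dict of tensors; the helper below is hypothetical, not the library's actual implementation:

def decompress_skipping(state_dict, params_to_skip_load=None):
    # skip entries whose trailing parameter name is in the skip list
    # (e.g. "weight_scale"), leaving them for the quantization loading step
    skip = set(params_to_skip_load or ())
    for full_name, tensor in state_dict.items():
        if full_name.rsplit(".", 1)[-1] in skip:
            continue
        yield full_name, tensor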
@@ -467,13 +484,27 @@ def decompress(self, model_path: str, model: Module):
             # quantization during apply_quantization_config. This ensures
             # that the dtypes of the weights are not unintentionally updated.
             # The status is restored after quantization params are loaded.
+
             with override_quantization_status(
                 self.quantization_config, QuantizationStatus.FROZEN
             ):
+
                 names_to_scheme = apply_quantization_config(
                     model, self.quantization_config
                 )
-                load_pretrained_quantization(model, model_path)
+                # Load activation scales/zp or any other quantization parameters
+                # Conditionally load the weight quantization parameters if we have a dense compressor
+                # or if a sparsity compressor has already been applied
+                load_pretrained_quantization_parameters(
+                    model,
+                    model_path,
+                    # TODO: all weight quantization params will be moved to the compressor in a follow-up
+                    # including initialization
+                    load_weight_quantization=(
+                        sparse_decompressed
+                        or isinstance(self.quantization_compressor, DenseCompressor)
+                    ),
+                )
 
             model_path_or_state_dict = (
                 model.state_dict() if sparse_decompressed else model_path
@@ -482,6 +513,8 @@ def decompress(self, model_path: str, model: Module):
             dense_gen = self.quantization_compressor.decompress(
                 model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
+            # TODO: all weight quantization params will be moved to the compressor
+            # to prevent duplicate parameter updates in update_parameter_data
            self._replace_weights(dense_gen, model)
 
         def freeze_quantization_status(module):
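
For clarity, a hedged restatement of the load_weight_quantization condition introduced above: weight scales/zero-points are loaded directly from disk only when the quantization compressor is a (no-op) DenseCompressor, or when sparse decompression has already run and the subsequent quantized decompression reads from model.state_dict() rather than the checkpoint. The helper name below is hypothetical and only mirrors the diff's condition:

from compressed_tensors.compressors.sparse_compressors import DenseCompressor

def load_weight_qparams_directly(sparse_decompressed: bool, quantization_compressor) -> bool:
    # mirrors the condition passed to load_pretrained_quantization_parameters above
    return sparse_decompressed or isinstance(quantization_compressor, DenseCompressor)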
@@ -537,7 +570,7 @@ def update_config(self, save_directory: str):
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
-    def _replace_weights(self, dense_weight_generator, model: Module):
+    def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
         """
         Replace the weights of the model with the
         provided dense weights.
@@ -552,11 +585,60 @@ def _replace_weights(self, dense_weight_generator, model: Module):
         :param model: The model whose weights are to be updated.
         """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
-            if hasattr(module, param_name):
-                update_parameter_data(module, data, param_name)
+
+            params_device = next(module.parameters()).device
+            device = "cpu" if has_offloaded_params(module) else params_device
+            delattr(module, param_name)
+            requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
+            param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
+            register_offload_parameter(module, param_name, param)
+
+    def _replace_weights(self, dense_weight_generator, model: Module):
+        """
+        Replace the weights of the model with the
+        provided dense weights.
+
+        This method iterates over the dense_weight_generator and
+        updates the corresponding weights in the model. If a parameter
+        name does not exist in the model, it will be skipped.
+
+        :param dense_weight_generator (generator): A generator that yields
+            tuples of (name, data), where 'name' is the parameter name and
+            'data' is the updated param data
+        :param model: The model whose weights are to be updated.
+        """
+
+        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+            module = operator.attrgetter(name)(model)
+
+            params_device = next(module.parameters()).device
+            device = "cpu" if has_offloaded_params(module) else params_device
+
+            for param_name, param_data in data.items():
+                if hasattr(module, param_name):
+                    # If compressed, will have an incorrect dtype for transformers >4.49
+                    # TODO: we can also just skip initialization of scales/zp if in decompression in init
+                    # to be consistent with loading, which happens later as well;
+                    # however, update_data does a good shape check - should be moved to the compressor
+                    if param_name == "weight":
+                        delattr(module, param_name)
+                        requires_grad = param_data.dtype in (
+                            torch.float16,
+                            torch.float32,
+                            torch.bfloat16,
+                        )
+                        param = torch.nn.Parameter(
+                            param_data.to(device), requires_grad=requires_grad
+                        )
+                        register_offload_parameter(module, param_name, param)
+                    else:
+                        # Scales/zero-points should already be registered
+                        # to the correct device
+                        update_parameter_data(module, param_data, param_name)
 
 
 def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
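
A rough usage sketch for the decompress() entry point touched by this PR; the checkpoint path and model are placeholders, and driving decompression manually like this (rather than through the transformers compressed-tensors integration) is an assumption for illustration:

from compressed_tensors.compressors import ModelCompressor

def decompress_checkpoint(model, compressed_path: str) -> None:
    compressor = ModelCompressor.from_pretrained(compressed_path)
    if compressor is None:
        return  # checkpoint was saved without compression
    # applies sparse decompression first (when a non-dense sparsity config is
    # present), then quantized decompression, using the replacement helpers above
    compressor.decompress(model_path=compressed_path, model=model)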