import re
from contextlib import contextmanager
from copy import deepcopy
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union

import compressed_tensors
+ from compressed_tensors.linear.compressed_linear import CompressedLinear
+ from compressed_tensors.utils.offload import update_offload_parameter
import torch
import transformers
from compressed_tensors.base import (

_LOGGER: logging.Logger = logging.getLogger(__name__)

- import tracemalloc
- import linecache
- import objgraph
+ def module_replace_dfs(
+     module: Module,
+     func: Callable[[Module], Module],
+     pre: bool = True,
+     progress: Union[bool, tqdm] = False,
+ ) -> Module:
+     """
+     Replace modules in a depth-first traversal, optionally with a progress bar.
+
+     :param module: root module whose tree is traversed
+     :param func: mapping applied to every module; its return value replaces the module
+     :param pre: if True, apply ``func`` before recursing into children (pre-order),
+         otherwise after (post-order)
+     :param progress: True to create a progress bar, or an existing tqdm instance
+     :return: the (possibly replaced) root module
+     """
+     if progress is True:
+         total = len(list(module.modules()))
+         progress = tqdm(total=total)
+
+     if pre:
+         module = func(module)
+
+     for name, child in list(module.named_children()):
+         module.add_module(name, module_replace_dfs(child, func, pre, progress))
+
+     if not pre:
+         module = func(module)
+
+     if isinstance(progress, tqdm):
+         progress.update(1)
+
+     return module
+
+
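# Illustrative usage sketch (hypothetical, not part of this diff): the visitor's
# return value replaces each module, so an identity visitor leaves the tree
# unchanged while still touching every node.
#
#     def report_linear(module: Module) -> Module:
#         if isinstance(module, torch.nn.Linear):
#             print(f"visited Linear({module.in_features} -> {module.out_features})")
#         return module
#
#     toy = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
#     toy = module_replace_dfs(toy, report_linear, progress=True)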

if TYPE_CHECKING:
    # dummy type if not available from transformers
    CompressedTensorsConfig = TypeVar("CompressedTensorsConfig")

- def display_top(snapshot, key_type='lineno', limit=3):
-     snapshot = snapshot.filter_traces((
-         tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
-         tracemalloc.Filter(False, "<unknown>"),
-     ))
-     top_stats = snapshot.statistics(key_type)
-
-     print("Top %s lines" % limit)
-     for index, stat in enumerate(top_stats[:limit], 1):
-         frame = stat.traceback[0]
-         print("#%s: %s:%s: %.1f MB"
-               % (index, frame.filename, frame.lineno, stat.size / (1024 * 1024)))
-         line = linecache.getline(frame.filename, frame.lineno).strip()
-         if line:
-             print('    %s' % line)
-
-     other = top_stats[limit:]
-     if other:
-         size = sum(stat.size for stat in other)
-         print("%s other: %.1f MB" % (len(other), size / (1024 * 1024)))
-     total = sum(stat.size for stat in top_stats)
-     print(f"Total Python-tracked memory: {total / (1024 * 1024):.2f} MB")
-

class ModelCompressor:
    """
@@ -384,6 +385,30 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
        )

        return list(unexpected_keys)
+
+     def apply_compression_status(self, model: Module) -> Module:
+         """
+         Replace the model's quantized Linear modules with CompressedLinear
+         modules whose parameters hold the compressed weights.
+         """
+         quantization_format = self.quantization_config.format
+
+         def replace_with_compressed(module: Module) -> Module:
+             scheme = getattr(module, "quantization_scheme", None)
+             if isinstance(module, torch.nn.Linear) and scheme is not None:
+                 module = CompressedLinear.from_linear(
+                     module,
+                     quantization_scheme=scheme,
+                     quantization_format=quantization_format,
+                 )
+                 # compress the parameters registered by CompressedLinear;
+                 # "" maps the module itself to its scheme
+                 state_dict = module.compressor.compress(module.state_dict(), {"": scheme})
+
+                 for name, value in state_dict.items():
+                     update_offload_parameter(module, name, value)
+
+             return module
+
+         progress = tqdm(total=len(list(model.modules())))
+         return module_replace_dfs(model, replace_with_compressed, progress=progress)
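# Hypothetical usage sketch: compress a calibrated model in place. This assumes
# the quantized Linear layers already carry `quantization_scheme` attributes and
# that `ModelCompressor.from_pretrained` finds a quantization config (it returns
# None otherwise); the model name is made up.
#
#     compressor = ModelCompressor.from_pretrained("org/quantized-model")
#     model = compressor.apply_compression_status(model)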

    def compress(
        self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
@@ -403,13 +428,11 @@ def compress(

        if self.quantization_compressor is not None:
-            with TrackTensorAllocations() as prof:
-                module_to_scheme = map_module_to_scheme(model)
-                state_dict = self.quantization_compressor.compress(
-                    state_dict, names_to_scheme=module_to_scheme
-                )
-                print(prof.total_tensor_memory_mib)
-                breakpoint()
+            module_to_scheme = map_module_to_scheme(model)
+            state_dict = self.quantization_compressor.compress(
+                state_dict, names_to_scheme=module_to_scheme
+            )

        # if self.quantization_config.format != CompressionFormat.dense.value:
        #     self.quantization_config.quantization_status = (
        #         QuantizationStatus.COMPRESSED
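# Hypothetical end-to-end sketch: `compress` returns a compressed state dict
# keyed by the same module names, which can then be saved alongside an updated
# config; `save_path` and the `update_config` step mirror typical usage but are
# assumptions here.
#
#     compressed_state_dict = compressor.compress(model)
#     model.save_pretrained(save_path, state_dict=compressed_state_dict)
#     compressor.update_config(save_path)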
@@ -559,13 +582,11 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
    """
    Returns a dictionary which maps quantized module names to their quantization schemes
    """
-    quantized_modules_to_args = {}
-    for name, submodule in iter_named_leaf_modules(model):
-        if is_module_quantized(submodule):
-            name = fix_fsdp_module_name(name)
-            quantized_modules_to_args[name] = submodule.quantization_scheme
-
-    return quantized_modules_to_args
+    return {
+        fix_fsdp_module_name(name): module.quantization_scheme
+        for name, module in iter_named_leaf_modules(model)
+        if is_module_quantized(module)
+    }
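# Illustrative sketch (hypothetical module name): for a model with a single
# quantized projection, the comprehension above yields a flat mapping:
#
#     map_module_to_scheme(model)
#     # -> {"model.layers.0.self_attn.q_proj": QuantizationScheme(...)}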
# HACK: Override the dtype_byte_size function in transformers to support float8 types