Skip to content

Commit 72dd867

Browse files
committed
wip
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent ac27709 commit 72dd867

File tree

4 files changed

+115
-8
lines changed

4 files changed

+115
-8
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,37 @@
6565

6666
_LOGGER: logging.Logger = logging.getLogger(__name__)
6767

68+
import tracemalloc
69+
import linecache
70+
import objgraph
6871

6972
if TYPE_CHECKING:
7073
# dummy type if not available from transformers
7174
CompressedTensorsConfig = TypeVar("CompressedTensorsConfig")
7275

76+
def display_top(snapshot, key_type='lineno', limit=3):
77+
snapshot = snapshot.filter_traces((
78+
tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
79+
tracemalloc.Filter(False, "<unknown>"),
80+
))
81+
top_stats = snapshot.statistics(key_type)
82+
83+
print("Top %s lines" % limit)
84+
for index, stat in enumerate(top_stats[:limit], 1):
85+
frame = stat.traceback[0]
86+
print("#%s: %s:%s: %.1f MB"
87+
% (index, frame.filename, frame.lineno, stat.size / (1024 * 1024)))
88+
line = linecache.getline(frame.filename, frame.lineno).strip()
89+
if line:
90+
print(' %s' % line)
91+
92+
other = top_stats[limit:]
93+
if other:
94+
size = sum(stat.size for stat in other)
95+
print("%s other: %.1f MB" % (len(other), size / (1024 * 1024)))
96+
total = sum(stat.size for stat in top_stats)
97+
print(f"Total Python-tracked memory: {total / (1024 * 1024):.2f} MB")
98+
7399

74100
class ModelCompressor:
75101
"""
@@ -362,25 +388,37 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
362388
def compress(
363389
self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
364390
) -> Dict[str, Tensor]:
391+
from torch.profiler import profile, ProfilerActivity
392+
from .track_tensor_memory import TrackTensorAllocations
365393
"""
366394
Compresses a dense state dict or model with sparsity and/or quantization
367395
368396
:param model: uncompressed model to compress
369397
:param state_dict: optional uncompressed state_dict to insert into model
370398
:return: compressed state dict
371399
"""
400+
372401
if state_dict is None:
373402
state_dict = model.state_dict()
374403

375404
if self.quantization_compressor is not None:
376-
module_to_scheme = map_module_to_scheme(model)
377-
state_dict = self.quantization_compressor.compress(
378-
state_dict, names_to_scheme=module_to_scheme
379-
)
380-
if self.quantization_config.format != CompressionFormat.dense.value:
381-
self.quantization_config.quantization_status = (
382-
QuantizationStatus.COMPRESSED
405+
#with profile(activities=[ProfilerActivity.CUDA], profile_memory=True, record_shapes=True, with_stack=True) as prof:
406+
with TrackTensorAllocations() as prof:
407+
module_to_scheme = map_module_to_scheme(model)
408+
state_dict = self.quantization_compressor.compress(
409+
state_dict, names_to_scheme=module_to_scheme
383410
)
411+
print(prof.total_tensor_memory_mib)
412+
breakpoint()
413+
# if self.quantization_config.format != CompressionFormat.dense.value:
414+
# self.quantization_config.quantization_status = (
415+
# QuantizationStatus.COMPRESSED
416+
# )
417+
418+
#prof.export_memory_timeline("memory.html")
419+
#print(prof.key_averages().table(sort_by="self_device_memory_usage", row_limit=3))
420+
#breakpoint()
421+
return state_dict
384422

385423
if self.sparsity_compressor is not None:
386424
sparse_compression_targets: Set[str] = expand_target_names(
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Callable, Any, Type, List, Set
2+
from functools import partial
3+
4+
import gc
5+
import torch
6+
import weakref
7+
8+
9+
class TrackTensorAllocations:
10+
total_tensor_memory: int
11+
memory_timeline: List[int]
12+
13+
_tracked_tensors: Set[int]
14+
_original_init_fn: Callable[[Any], None]
15+
16+
def __init__(self):
17+
self.total_tensor_memory = 0
18+
self.memory_timeline = []
19+
20+
self._tracked_tensors = set()
21+
self._original_init_fn = torch.Tensor.__init__
22+
23+
def __enter__(self):
24+
def wrapped_init(instance, *args, **kwargs):
25+
if isinstance(instance, torch.Tensor):
26+
self._original_init_fn(instance)
27+
self.track_tensor(instance)
28+
else:
29+
# parameters, ect.
30+
type(instance).__init__(instance, *args, **kwargs)
31+
32+
torch.Tensor.__init__ = wrapped_init
33+
34+
return self
35+
36+
def __exit__(self, exc_type, exc_val, exc_tb):
37+
torch.Tensor.__init__ = self._original_init_fn
38+
self._active = False
39+
gc.collect()
40+
41+
def track_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
42+
tensor_hash = hash(tensor)
43+
tensor_memory = tensor.numel() * tensor.element_size()
44+
45+
# warn when init is called twice
46+
if tensor_hash in self._tracked_tensors:
47+
print("double init")
48+
return
49+
50+
# add memory
51+
self.total_tensor_memory += tensor_memory
52+
self._add_to_timeline()
53+
self._tracked_tensors.add(tensor_hash)
54+
55+
# register hook to subtract memory
56+
weakref.finalize(tensor, partial(self._on_tensor_deallocated, tensor_memory, tensor_hash))
57+
58+
def _on_tensor_deallocated(self, tensor_memory, tensor_hash):
59+
self.total_tensor_memory -= tensor_memory
60+
self._add_to_timeline()
61+
self._tracked_tensors.remove(tensor_hash)
62+
63+
@property
64+
def total_tensor_memory_mib(self):
65+
return self.total_tensor_memory / (1024 * 1024)
66+
67+
def _add_to_timeline(self):
68+
self.memory_timeline.append(self.total_tensor_memory)

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
187187

188188
# convert back to signed and torch
189189
packed = np.ascontiguousarray(packed).view(np.int32)
190-
return torch.from_numpy(packed)
190+
return torch.Tensor(torch.from_numpy(packed))
191191

192192

193193
def unpack_from_int32(

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ def calculate_compression_ratio(model: Module) -> float:
322322
:param model: pytorch module to calculate compression ratio for
323323
:return: compression ratio of the whole model
324324
"""
325+
return 0.0
325326
total_compressed = 0.0
326327
total_uncompressed = 0.0
327328
for name, submodule in tqdm(

0 commit comments

Comments
 (0)