|
65 | 65 |
|
66 | 66 | _LOGGER: logging.Logger = logging.getLogger(__name__)
|
67 | 67 |
|
| 68 | +import tracemalloc |
| 69 | +import linecache |
| 70 | +import objgraph |
68 | 71 |
|
69 | 72 | if TYPE_CHECKING:
|
70 | 73 | # dummy type if not available from transformers
|
71 | 74 | CompressedTensorsConfig = TypeVar("CompressedTensorsConfig")
|
72 | 75 |
|
def display_top(snapshot, key_type='lineno', limit=3):
    """Print a short report of the top memory-allocating source lines.

    Follows the stdlib ``tracemalloc`` recipe: filters out importlib and
    unknown frames, groups allocations by ``key_type``, prints the top
    ``limit`` entries with their source line, then a remainder and total.

    :param snapshot: a ``tracemalloc.Snapshot`` to analyze
    :param key_type: statistics grouping key (e.g. ``'lineno'``, ``'filename'``)
    :param limit: how many top entries to print individually
    """
    # Drop noise frames from the import machinery and unknown locations.
    filtered = snapshot.filter_traces((
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    ))
    stats = filtered.statistics(key_type)

    print("Top %s lines" % limit)
    for rank, entry in enumerate(stats[:limit], start=1):
        top_frame = entry.traceback[0]
        mib = entry.size / (1024 * 1024)
        print("#%s: %s:%s: %.1f MB"
              % (rank, top_frame.filename, top_frame.lineno, mib))
        # Show the offending source line when linecache can resolve it.
        source_line = linecache.getline(top_frame.filename, top_frame.lineno).strip()
        if source_line:
            print('    %s' % source_line)

    remainder = stats[limit:]
    if remainder:
        remainder_size = sum(entry.size for entry in remainder)
        print("%s other: %.1f MB" % (len(remainder), remainder_size / (1024 * 1024)))
    total = sum(entry.size for entry in stats)
    print(f"Total Python-tracked memory: {total / (1024 * 1024):.2f} MB")
73 | 99 |
|
74 | 100 | class ModelCompressor:
|
75 | 101 | """
|
@@ -362,25 +388,37 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
|
362 | 388 | def compress(
|
363 | 389 | self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
|
364 | 390 | ) -> Dict[str, Tensor]:
|
| 391 | + from torch.profiler import profile, ProfilerActivity |
| 392 | + from .track_tensor_memory import TrackTensorAllocations |
365 | 393 | """
|
366 | 394 | Compresses a dense state dict or model with sparsity and/or quantization
|
367 | 395 |
|
368 | 396 | :param model: uncompressed model to compress
|
369 | 397 | :param state_dict: optional uncompressed state_dict to insert into model
|
370 | 398 | :return: compressed state dict
|
371 | 399 | """
|
| 400 | + |
372 | 401 | if state_dict is None:
|
373 | 402 | state_dict = model.state_dict()
|
374 | 403 |
|
375 | 404 | if self.quantization_compressor is not None:
|
376 |
| - module_to_scheme = map_module_to_scheme(model) |
377 |
| - state_dict = self.quantization_compressor.compress( |
378 |
| - state_dict, names_to_scheme=module_to_scheme |
379 |
| - ) |
380 |
| - if self.quantization_config.format != CompressionFormat.dense.value: |
381 |
| - self.quantization_config.quantization_status = ( |
382 |
| - QuantizationStatus.COMPRESSED |
| 405 | + #with profile(activities=[ProfilerActivity.CUDA], profile_memory=True, record_shapes=True, with_stack=True) as prof: |
| 406 | + with TrackTensorAllocations() as prof: |
| 407 | + module_to_scheme = map_module_to_scheme(model) |
| 408 | + state_dict = self.quantization_compressor.compress( |
| 409 | + state_dict, names_to_scheme=module_to_scheme |
383 | 410 | )
|
| 411 | + print(prof.total_tensor_memory_mib) |
| 412 | + breakpoint() |
| 413 | + # if self.quantization_config.format != CompressionFormat.dense.value: |
| 414 | + # self.quantization_config.quantization_status = ( |
| 415 | + # QuantizationStatus.COMPRESSED |
| 416 | + # ) |
| 417 | + |
| 418 | + #prof.export_memory_timeline("memory.html") |
| 419 | + #print(prof.key_averages().table(sort_by="self_device_memory_usage", row_limit=3)) |
| 420 | + #breakpoint() |
| 421 | + return state_dict |
384 | 422 |
|
385 | 423 | if self.sparsity_compressor is not None:
|
386 | 424 | sparse_compression_targets: Set[str] = expand_target_names(
|
|
0 commit comments