
Commit 2fc0403

replacement shows reduction

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 72dd867

4 files changed: +100 -51 lines

src/compressed_tensors/compressors/model_compressors/model_compressor.py
Lines changed: 62 additions & 41 deletions

@@ -19,9 +19,11 @@
 import re
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, Callable

 import compressed_tensors
+from compressed_tensors.linear.compressed_linear import CompressedLinear
+from compressed_tensors.utils.offload import update_offload_parameter
 import torch
 import transformers
 from compressed_tensors.base import (
@@ -65,37 +67,36 @@

 _LOGGER: logging.Logger = logging.getLogger(__name__)

-import tracemalloc
-import linecache
-import objgraph
+def module_replace_dfs(
+    module: Module,
+    func: Callable[[Module], Module],
+    pre: bool = True,
+    progress: Union[bool, tqdm] = False,
+) -> Module:
+    """Depth-first traversal which re-registers each submodule as func(submodule)"""
+    if progress is True:
+        total = len(list(module.modules()))
+        progress = tqdm(total=total)
+
+    if pre:
+        module = func(module)
+
+    for name, child in list(module.named_children()):
+        module.add_module(name, module_replace_dfs(child, func, pre, progress))
+
+    if not pre:
+        module = func(module)
+
+    if isinstance(progress, tqdm):
+        progress.update(1)
+
+    return module
+
+
 if TYPE_CHECKING:
     # dummy type if not available from transformers
     CompressedTensorsConfig = TypeVar("CompressedTensorsConfig")

-def display_top(snapshot, key_type='lineno', limit=3):
-    snapshot = snapshot.filter_traces((
-        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
-        tracemalloc.Filter(False, "<unknown>"),
-    ))
-    top_stats = snapshot.statistics(key_type)
-
-    print("Top %s lines" % limit)
-    for index, stat in enumerate(top_stats[:limit], 1):
-        frame = stat.traceback[0]
-        print("#%s: %s:%s: %.1f MB"
-              % (index, frame.filename, frame.lineno, stat.size / (1024 * 1024)))
-        line = linecache.getline(frame.filename, frame.lineno).strip()
-        if line:
-            print('    %s' % line)
-
-    other = top_stats[limit:]
-    if other:
-        size = sum(stat.size for stat in other)
-        print("%s other: %.1f MB" % (len(other), size / (1024 * 1024)))
-    total = sum(stat.size for stat in top_stats)
-    print(f"Total Python-tracked memory: {total / (1024 * 1024):.2f} MB")

 class ModelCompressor:
     """
@@ -384,6 +385,30 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
         )

         return list(unexpected_keys)
+
+    def apply_compression_status(self, model: Module) -> Module:
+        quantization_format = self.quantization_config.format
+
+        def replace_with_compressed(module: Module) -> Module:
+            scheme = getattr(module, "quantization_scheme", None)
+            if isinstance(module, torch.nn.Linear) and scheme is not None:
+                module = CompressedLinear.from_linear(
+                    module,
+                    quantization_scheme=scheme,
+                    quantization_format=quantization_format,
+                )
+                # compress the parameters initialized by CompressedLinear
+                state_dict = module.compressor.compress(module.state_dict(), {"": scheme})
+                for name, value in state_dict.items():
+                    update_offload_parameter(module, name, value)
+
+            return module
+
+        progress = tqdm(total=len(list(model.modules())))
+        return module_replace_dfs(model, replace_with_compressed, progress=progress)

     def compress(
         self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
@@ -403,13 +428,11 @@ def compress(

         if self.quantization_compressor is not None:
             #with profile(activities=[ProfilerActivity.CUDA], profile_memory=True, record_shapes=True, with_stack=True) as prof:
-            with TrackTensorAllocations() as prof:
-                module_to_scheme = map_module_to_scheme(model)
-                state_dict = self.quantization_compressor.compress(
-                    state_dict, names_to_scheme=module_to_scheme
-                )
-            print(prof.total_tensor_memory_mib)
-            breakpoint()
+            #with TrackTensorAllocations() as prof:
+            module_to_scheme = map_module_to_scheme(model)
+            state_dict = self.quantization_compressor.compress(
+                state_dict, names_to_scheme=module_to_scheme
+            )
         # if self.quantization_config.format != CompressionFormat.dense.value:
         #     self.quantization_config.quantization_status = (
         #         QuantizationStatus.COMPRESSED
@@ -559,13 +582,11 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
     Returns a dictionary which maps quantized module names to their quantization schemes
     """
-    quantized_modules_to_args = {}
-    for name, submodule in iter_named_leaf_modules(model):
-        if is_module_quantized(submodule):
-            name = fix_fsdp_module_name(name)
-            quantized_modules_to_args[name] = submodule.quantization_scheme
-
-    return quantized_modules_to_args
+    return {
+        fix_fsdp_module_name(name): module.quantization_scheme
+        for name, module in iter_named_leaf_modules(model)
+        if is_module_quantized(module)
+    }


 # HACK: Override the dtype_byte_size function in transformers to support float8 types
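
For orientation, module_replace_dfs visits every submodule depth-first (pre-order by default) and re-registers each child with whatever the callback returns; that is what lets apply_compression_status swap Linear modules for CompressedLinear in place. A minimal sketch of the pattern with a toy callback follows; the model and callback are illustrative, not part of this commit, and the import path assumes where this commit places the helper:

# Illustrative only: swap every Linear for Identity to show in-place replacement
from compressed_tensors.compressors.model_compressors.model_compressor import (
    module_replace_dfs,
)
import torch
from torch import nn

def swap_linear_for_identity(module: nn.Module) -> nn.Module:
    # the callback must always return a module; returning a new instance
    # makes the parent re-register that child via add_module
    if isinstance(module, nn.Linear):
        return nn.Identity()
    return module

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
model = module_replace_dfs(model, swap_linear_for_identity, progress=True)
# all Linear children are now Identity modules; the root Sequential is untouched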

src/compressed_tensors/compressors/model_compressors/track_tensor_memory.py
Lines changed: 23 additions & 0 deletions

@@ -1,5 +1,6 @@
 from typing import Callable, Any, Type, List, Set
 from functools import partial
+import matplotlib.pyplot as plt

 import gc
 import torch
@@ -66,3 +67,25 @@ def total_tensor_memory_mib(self):

     def _add_to_timeline(self):
         self.memory_timeline.append(self.total_tensor_memory)
+
+    def plot_values_over_time(self, dpi=300):
+        """
+        Plots the tracked memory timeline over time using matplotlib and
+        saves the figure to file.png
+        """
+        values = self.memory_timeline
+        if not values:
+            print("The list of values is empty.")
+            return
+
+        plt.figure(figsize=(10, 4))
+        plt.plot(range(len(values)), values, marker='o', linestyle='-')
+        plt.title("Tensor Memory Over Time")
+        plt.xlabel("Time")
+        plt.ylabel("Tensor memory (bytes)")
+        plt.grid(True)
+        plt.tight_layout()
+        plt.savefig("file.png", dpi=dpi)
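
The tracker appears as a context manager in model_compressor.py above; a hypothetical usage sketch combining it with the new plotting helper (the workload is made up, and treating total_tensor_memory_mib as a property is inferred from the surrounding diff):

import torch
from compressed_tensors.compressors.model_compressors.track_tensor_memory import (
    TrackTensorAllocations,
)

with TrackTensorAllocations() as prof:
    x = torch.randn(1024, 1024)  # tensor allocations inside the context are recorded
    y = x @ x

print(prof.total_tensor_memory_mib)  # tracked tensor memory in MiB
prof.plot_values_over_time(dpi=150)  # writes the memory timeline to file.png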

src/compressed_tensors/compressors/quantized_compressors/base.py
Lines changed: 12 additions & 8 deletions

@@ -88,21 +88,22 @@ def compress(
         value = model_state[name]

         # compress weights
-        if name.endswith(".weight"):
-            prefix = remove_suffix(name, ".weight")
+        if name.endswith("weight"):
+            prefix = remove_suffix(name, "weight")

             # gather qparams
-            scale = model_state.get(merge_names(prefix, "weight_scale"), None)
-            g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
-            zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+            scale = model_state.get(prefix + "weight_scale", None)
+            g_idx = model_state.get(prefix + "weight_g_idx", None)
+            zp = model_state.get(prefix + "weight_zero_point", None)

             # if scale does not exist, then weight cannot be compressed
             if scale is None:
                 model_state[name] = value.to(save_device)
                 continue

             # compress values on cpu (memory movement too expensive)
-            quant_args = names_to_scheme[prefix].weights
+            module_path = prefix[:-1] if prefix.endswith(".") else prefix
+            quant_args = names_to_scheme[module_path].weights
             compressed_values = self.compress_weight(
                 weight=value,
                 scale=scale,
@@ -115,7 +116,7 @@ def compress(
             # update state dict
             del model_state[name]
             for key, value in compressed_values.items():
-                model_state[merge_names(prefix, key)] = value.to(save_device)
+                model_state[prefix + key] = value.to(save_device)

         else:
             # omit saving zero points for symmetric quantization
@@ -202,7 +203,10 @@ def _decompress_from_state_dict(


 def _is_symmetric(name: str, names_to_scheme: Dict[str, QuantizationScheme]) -> bool:
-    weight_name, zp_name = name.rsplit(".", 1)
+    # handle bare parameter names with no module prefix (e.g. "weight_zero_point")
+    weight_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
     scheme = names_to_scheme[weight_name]

     if zp_name == "weight_zero_point":
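
The relaxed suffix match works together with apply_compression_status above: CompressedLinear.state_dict() yields bare keys like "weight" with no module prefix, which the old ".weight" check would never match. A worked example of the new prefix handling, with a local stand-in for the repo's remove_suffix helper (key names are illustrative):

def remove_suffix(value: str, suffix: str) -> str:
    # stand-in for the helper used in the diff above
    return value[: -len(suffix)] if value.endswith(suffix) else value

# ordinary checkpoint key: the trailing dot is kept for building qparam keys
name = "model.layers.0.q_proj.weight"
prefix = remove_suffix(name, "weight")          # "model.layers.0.q_proj."
module_path = prefix[:-1] if prefix.endswith(".") else prefix
assert prefix + "weight_scale" == "model.layers.0.q_proj.weight_scale"
assert module_path == "model.layers.0.q_proj"   # key into names_to_scheme

# bare key from a CompressedLinear state dict: no dot anywhere
name = "weight"
prefix = remove_suffix(name, "weight")          # ""
module_path = prefix[:-1] if prefix.endswith(".") else prefix
assert module_path == ""                        # matches the {"": scheme} mapping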

src/compressed_tensors/linear/compressed_linear.py
Lines changed: 3 additions & 2 deletions

@@ -15,6 +15,7 @@
 import warnings
 from typing import Dict, Tuple

+from compressed_tensors.utils.offload import get_execution_device
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import (
@@ -60,7 +61,7 @@ def from_linear(
        """
        module.__class__ = CompressedLinear
        module.compressor = BaseCompressor.load_from_registry(quantization_format)
-       device = next(module.parameters()).device
+       init_device = get_execution_device(module)

        # this will initialize all the scales and zero points
        initialize_module_for_quantization(
@@ -79,7 +80,7 @@ def from_linear(
        # populate compressed weights and quantization parameters
        for name, (shape, dtype) in compression_params.items():
            param = Parameter(
-               torch.empty(shape, device=device, dtype=dtype), requires_grad=False
+               torch.empty(shape, device=init_device, dtype=dtype), requires_grad=False
            )
            register_offload_parameter(module, name, param)
