Commit 37da099

clean up
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 2fc0403 commit 37da099

File tree

7 files changed: +65 -148 lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 27 additions & 46 deletions
@@ -19,11 +19,20 @@
 import re
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, Callable
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import compressed_tensors
-from compressed_tensors.linear.compressed_linear import CompressedLinear
-from compressed_tensors.utils.offload import update_offload_parameter
 import torch
 import transformers
 from compressed_tensors.base import (
@@ -34,6 +43,7 @@
 )
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+from compressed_tensors.linear.compressed_linear import CompressedLinear
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
@@ -50,12 +60,14 @@
 from compressed_tensors.utils import (
     get_safetensors_folder,
     merge_names,
+    module_replace_dfs,
     update_parameter_data,
 )
 from compressed_tensors.utils.helpers import (
     fix_fsdp_module_name,
     is_compressed_tensors_config,
 )
+from compressed_tensors.utils.offload import update_offload_parameter
 from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
@@ -67,31 +79,6 @@
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
-def module_replace_dfs(
-    module: Module,
-    func: Callable[[Module], Module],
-    pre: bool = True,
-    progress: Union[bool, tqdm] = False,
-) -> Module:
-    if progress is True:
-        total = len(list(module.modules()))
-        progress = tqdm(total=total)
-
-    if pre:
-        module = func(module)
-
-    for name, child in list(module.named_children()):
-        module.add_module(name, module_replace_dfs(child, func, pre, progress))
-
-    if not pre:
-        module = func(module)
-
-    if isinstance(progress, tqdm):
-        progress.update(1)
-
-    return module
-
-
 
 if TYPE_CHECKING:
     # dummy type if not available from transformers
@@ -385,36 +372,35 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
         )
 
         return list(unexpected_keys)
-
+
     def apply_compression_status(self, model: Module) -> Module:
         quantization_format = self.quantization_config.format
 
         def replace_with_compressed(module: Module) -> Module:
             scheme = getattr(module, "quantization_scheme", None)
             if isinstance(module, torch.nn.Linear) and scheme is not None:
-                #compressed_state_dict_2 = self.compress(module) # debug
+                # compressed_state_dict_2 = self.compress(module) # debug
 
                 module = CompressedLinear.from_linear(
                     module,
                     quantization_scheme=scheme,
-                    quantization_format=quantization_format
+                    quantization_format=quantization_format,
                 )
-                state_dict = module.compressor.compress(module.state_dict(), {"": scheme})  # added by compressed linear
+                state_dict = module.compressor.compress(
+                    module.state_dict(), {"": scheme}
+                )  # added by compressed linear
 
                 for name, value in state_dict.items():
                     update_offload_parameter(module, name, value)
 
             return module
 
-
         progress = tqdm(total=len(list(model.modules())))
         return module_replace_dfs(model, replace_with_compressed, progress=progress)
 
     def compress(
         self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
     ) -> Dict[str, Tensor]:
-        from torch.profiler import profile, ProfilerActivity
-        from .track_tensor_memory import TrackTensorAllocations
         """
         Compresses a dense state dict or model with sparsity and/or quantization
 
@@ -427,21 +413,16 @@ def compress(
         state_dict = model.state_dict()
 
         if self.quantization_compressor is not None:
-            #with profile(activities=[ProfilerActivity.CUDA], profile_memory=True, record_shapes=True, with_stack=True) as prof:
-            #with TrackTensorAllocations() as prof:
             module_to_scheme = map_module_to_scheme(model)
             state_dict = self.quantization_compressor.compress(
                 state_dict, names_to_scheme=module_to_scheme
             )
-            # if self.quantization_config.format != CompressionFormat.dense.value:
-            #     self.quantization_config.quantization_status = (
-            #         QuantizationStatus.COMPRESSED
-            #     )
-
-            #prof.export_memory_timeline("memory.html")
-            #print(prof.key_averages().table(sort_by="self_device_memory_usage", row_limit=3))
-            #breakpoint()
-            return state_dict
+
+            # TODO: consider sparse compression to also be compression
+            if self.quantization_config.format != CompressionFormat.dense.value:
+                self.quantization_config.quantization_status = (
+                    QuantizationStatus.COMPRESSED
+                )
 
         if self.sparsity_compressor is not None:
             sparse_compression_targets: Set[str] = expand_target_names(
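
Note (not part of the diff): the new apply_compression_status path walks the model depth-first with module_replace_dfs, swapping each quantized torch.nn.Linear for a CompressedLinear and writing the compressed tensors back through update_offload_parameter. A minimal sketch of how this might be driven end to end, assuming a checkpoint whose Linear modules already carry quantization schemes; the model stub below is hypothetical:

# Sketch only; assumes quantization schemes are already attached to the
# model's Linear modules, as compressed-tensors' quantization apply step does.
from transformers import AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor

stub = "nm-testing/tinyllama-w4a16-compressed"  # hypothetical checkpoint
model = AutoModelForCausalLM.from_pretrained(stub)
compressor = ModelCompressor.from_pretrained(stub)

# Depth-first replacement: quantized Linear -> CompressedLinear,
# compressing each module's weights as it is visited.
model = compressor.apply_compression_status(model)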

src/compressed_tensors/compressors/model_compressors/track_tensor_memory.py

Lines changed: 0 additions & 91 deletions
This file was deleted.

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 8 additions & 7 deletions
@@ -76,11 +76,12 @@ def compress(
         """
         Compresses a dense state dict
 
-        :param model_state: state dict of uncompressed model, consumed by compression
+        :param model_state: state dict of uncompressed model
         :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
+        compressed_dict = {}
         save_device = "cpu"
 
         uncompressed_names = list(model_state.keys())
@@ -98,7 +99,7 @@
 
             # if scale does not exist, then weight cannot be compressed
             if scale is None:
-                model_state[name] = value.to(save_device)
+                compressed_dict[name] = value.to(save_device)
                 continue
 
             # compress values on cpu (memory movement too expensive)
@@ -116,22 +117,22 @@
                 # update state dict
                 del model_state[name]
                 for key, value in compressed_values.items():
-                    model_state[prefix + key] = value.to(save_device)
+                    compressed_dict[prefix + key] = value.to(save_device)
 
             else:
                 # omit saving zero points for symmetric quantization
                 if name.endswith("zero_point") and _is_symmetric(name, names_to_scheme):
-                    del model_state[name]
+                    continue
 
                 # omit saving for g_idx if uninitialized
                 # TODO: does this case actually occur?
                 elif name.endswith("g_idx") and torch.any(value <= -1):
-                    del model_state[name]
+                    continue
 
                 else:
-                    model_state[name] = value.to(save_device)
+                    compressed_dict[name] = value.to(save_device)
 
-        return model_state
+        return compressed_dict
 
     def decompress(
         self,
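
Aside (not part of the diff): the compress() rewrite stops mutating model_state while iterating and instead accumulates results into a fresh compressed_dict, so omitted entries (symmetric zero points, uninitialized g_idx) are simply never copied. A minimal sketch of the pattern, with a hypothetical skip predicate standing in for the real checks:

import torch

def should_skip(name: str, value: torch.Tensor) -> bool:
    # hypothetical stand-in for the real checks (symmetric zero_point,
    # uninitialized g_idx)
    return name.endswith("zero_point") and bool(torch.all(value == 0))

def compress_state_dict(model_state: dict) -> dict:
    # Copy-don't-mutate: results land in a new dict, skipped keys are
    # never written, and iteration over the source dict stays safe.
    compressed_dict = {}
    save_device = "cpu"
    for name, value in model_state.items():
        if should_skip(name, value):
            continue
        compressed_dict[name] = value.to(save_device)
    return compressed_dict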

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 1 addition & 1 deletion
@@ -187,7 +187,7 @@ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
 
     # convert back to signed and torch
     packed = np.ascontiguousarray(packed).view(np.int32)
-    return torch.Tensor(torch.from_numpy(packed))
+    return torch.from_numpy(packed)
 
 
 def unpack_from_int32(
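
Aside (not part of the diff): this one-line change is a fix, not just cleanup. torch.Tensor(...) is the legacy default-dtype constructor, so wrapping the result copy-constructed it to float32 and discarded the int32 packing; torch.from_numpy alone shares the numpy buffer and preserves dtype. A quick check, as observed on recent PyTorch versions:

import numpy as np
import torch

packed = np.array([0x7FFFFFFF, -1], dtype=np.int32)

good = torch.from_numpy(packed)   # zero-copy, dtype preserved
print(good.dtype)                 # torch.int32

bad = torch.Tensor(torch.from_numpy(packed))  # old code path
print(bad.dtype)                  # torch.float32, packed bits lost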

src/compressed_tensors/linear/compressed_linear.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,6 @@
 import warnings
 from typing import Dict, Tuple
 
-from compressed_tensors.utils.offload import get_execution_device
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import (
@@ -24,6 +23,7 @@
     initialize_module_for_quantization,
 )
 from compressed_tensors.utils import register_offload_parameter
+from compressed_tensors.utils.offload import get_execution_device
 from torch import Tensor
 from torch.nn import Parameter
 from torch.nn.functional import linear

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 0 additions & 1 deletion
@@ -322,7 +322,6 @@ def calculate_compression_ratio(model: Module) -> float:
     :param model: pytorch module to calculate compression ratio for
     :return: compression ratio of the whole model
     """
-    return 0.0
     total_compressed = 0.0
     total_uncompressed = 0.0
     for name, submodule in tqdm(
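
Aside (not part of the diff): removing the stray return 0.0 lets the function accumulate per-parameter sizes again rather than short-circuiting to a bogus ratio. The computation is, roughly, uncompressed bits over compressed bits; a hedged sketch of that arithmetic, which illustrates the idea rather than the function's exact internals:

import torch
from torch.nn import Module
from tqdm import tqdm

def compression_ratio_sketch(model: Module) -> float:
    # Illustrative: weight each parameter by element count times bit width,
    # comparing its stored dtype against an assumed fp32 baseline.
    total_compressed = 0.0
    total_uncompressed = 0.0
    for name, param in tqdm(model.named_parameters()):
        bits = (
            torch.finfo(param.dtype).bits
            if param.is_floating_point()
            else torch.iinfo(param.dtype).bits
        )
        total_uncompressed += 32 * param.numel()  # assumed fp32 baseline
        total_compressed += bits * param.numel()
    return total_uncompressed / total_compressed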

src/compressed_tensors/utils/helpers.py

Lines changed: 28 additions & 1 deletion
@@ -14,10 +14,11 @@
 
 import warnings
 from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 import numpy
 import torch
+import tqdm
 from transformers import AutoConfig
 
 
@@ -39,6 +40,7 @@
     "pack_bitmasks",
     "unpack_bitmasks",
     "remove_suffix",
+    "module_replace_dfs",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -335,3 +337,28 @@ def remove_suffix(value: str, suffix: str) -> str:
     # can replace with str.removesuffix in python3.9+
     assert value.endswith(suffix)
     return value[: -len(suffix)]
+
+
+def module_replace_dfs(
+    module: torch.nn.Module,
+    func: Callable[[torch.nn.Module], torch.nn.Module],
+    pre: bool = True,
+    progress: Union[bool, tqdm.tqdm] = False,
+) -> torch.nn.Module:
+    if progress is True:
+        total = len(list(module.modules()))
+        progress = tqdm.tqdm(total=total)
+
+    if pre:
+        module = func(module)
+
+    for name, child in list(module.named_children()):
+        module.add_module(name, module_replace_dfs(child, func, pre, progress))
+
+    if not pre:
+        module = func(module)
+
+    if isinstance(progress, tqdm.tqdm):
+        progress.update(1)
+
+    return module
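
Usage note (not part of the diff): with module_replace_dfs promoted to compressed_tensors.utils and exported via __all__, it can drive arbitrary module rewrites. A small self-contained example with an illustrative replacement function:

import torch
from compressed_tensors.utils import module_replace_dfs

def strip_linears(module: torch.nn.Module) -> torch.nn.Module:
    # illustrative replacement: swap every Linear for an Identity
    if isinstance(module, torch.nn.Linear):
        return torch.nn.Identity()
    return module

model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)

# pre=True applies func before recursing into children; progress=True
# sizes a tqdm bar from the initial module count.
model = module_replace_dfs(model, strip_linears, pre=True, progress=True)
print(model)  # both Linears replaced by Identity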
