Commit 7ab17ce

Merge branch 'main' into kylesayrs/transform_permutations
2 parents 4ae491d + da19b0f

5 files changed: +45 additions, −21 deletions

src/compressed_tensors/compressors/model_compressors/model_compressor.py
Lines changed: 7 additions & 1 deletion

@@ -68,6 +68,10 @@
 from transformers.file_utils import CONFIG_NAME
 
 
+if TYPE_CHECKING:
+    from compressed_tensors.compressors import BaseQuantizationCompressor
+
+
 __all__ = ["ModelCompressor", "map_module_to_scheme"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -257,7 +261,9 @@ def __init__(
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.sparsity_compressor = None
-        self.quantization_compressor = None
+        self.quantization_compressor: Optional[
+            Union[BaseQuantizationCompressor, DenseCompressor]
+        ] = None
 
         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
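
Aside: the `TYPE_CHECKING` guard added above is the standard way to expose `BaseQuantizationCompressor` to static type checkers without importing it at runtime, which sidesteps circular imports between the compressor modules. A minimal sketch of the pattern (module and class names below are hypothetical, not taken from this repo):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # executed only by static type checkers (mypy, pyright); at runtime
    # this import never runs, so no import cycle is created
    from heavy_module import Helper  # hypothetical module

class Container:
    def __init__(self) -> None:
        # the annotation is a string (forward reference), so it does not
        # require Helper to be importable when __init__ executes
        self.helper: Optional["Helper"] = None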

src/compressed_tensors/compressors/sparse_compressors/base.py
Lines changed: 19 additions & 1 deletion

@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Generator, Optional, Set, Tuple
 
+import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,

@@ -26,6 +27,10 @@
 from tqdm import tqdm
 
 
+if TYPE_CHECKING:
+    from compressed_tensors.quantization import QuantizationScheme
+
+
 __all__ = ["BaseSparseCompressor"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -200,3 +205,16 @@ def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
         return (
             name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
         )
+
+    def decompress_module_from_state_dict(
+        self,
+        prefix: str,
+        state_dict: Dict[str, torch.Tensor],
+        scheme: "QuantizationScheme",
+    ) -> Dict[str, torch.Tensor]:
+        """
+        This function is implemented as a workaround because of how
+        `ModelCompressor.quantization_compressor` can be set to either
+        an instance of `BaseQuantizationCompressor` or `BaseSparseCompressor`.
+        """
+        return state_dict.copy()
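
The no-op `decompress_module_from_state_dict` above gives `BaseSparseCompressor` the same method surface as the quantization compressors, so code holding `ModelCompressor.quantization_compressor` can call it without first checking the concrete type. A rough sketch of the duck-typed call site this enables (the caller below is illustrative, not the repo's actual code):

from typing import Dict

import torch

# Illustrative only: `compressor` may be a BaseQuantizationCompressor
# (which would reconstruct dense weights) or a BaseSparseCompressor
# (whose new override returns the state dict unchanged); the caller no
# longer needs an isinstance check either way.
def load_module_weights(
    compressor, prefix: str, state_dict: Dict[str, torch.Tensor], scheme
) -> Dict[str, torch.Tensor]:
    return compressor.decompress_module_from_state_dict(prefix, state_dict, scheme)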

src/compressed_tensors/quantization/lifecycle/apply.py
Lines changed: 4 additions & 6 deletions

@@ -331,12 +331,10 @@ def find_name_or_class_matches(
     3. matches on module names
     """
     targets = sorted(targets, key=lambda x: ("re:" in x, x))
-    if isinstance(targets, Iterable):
-        matches = _find_matches(name, targets) + _find_matches(
-            module.__class__.__name__, targets, check_contains
-        )
-        matches = [match for match in matches if match is not None]
-        return matches
+    matches = _find_matches(name, targets) + _find_matches(
+        module.__class__.__name__, targets, check_contains
+    )
+    return [match for match in matches if match is not None]
 
 
 def _find_matches(
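
With the `isinstance` guard removed, the function relies on `targets` already being an iterable of strings, and precedence is carried entirely by the sort key `("re:" in x, x)`: tuples compare `False < True`, so plain-name targets sort ahead of `re:` regex targets, each group alphabetically. A small standalone illustration (target strings are made up):

targets = ["re:.*q_proj", "Linear", "re:.*k_proj", "Embedding"]
# plain names first, then regex patterns, each group in alphabetical order
print(sorted(targets, key=lambda x: ("re:" in x, x)))
# ['Embedding', 'Linear', 're:.*k_proj', 're:.*q_proj']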

src/compressed_tensors/transform/utils/hadamard.py
Lines changed: 4 additions & 7 deletions

@@ -25,14 +25,11 @@
 # https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
 def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
     """
-    Construct an Hadamard matrix.
+    Construct an n-by-n Hadamard matrix, using Sylvester's construction.
+    `n` must be a power of 2.
 
-    Constructs an n-by-n Hadamard matrix, using Sylvester's
-    construction. `n` must be a power of 2.
-
-    :param size: order of the matrix; must be a power of 2
-
-    returns a (size, size) hadamard matrix
+    :param size: order of the matrix, must be a power of 2
+    :return: hadamard matrix of size `size`
     """
     if size <= 0:
         raise ValueError("Cannot construct deterministic hadamard of size <= 0")

src/compressed_tensors/utils/offload.py
Lines changed: 11 additions & 6 deletions

@@ -87,13 +87,18 @@ def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:
 
             if fallback == "error":
-                raise ValueError(
-                    "Please install `accelerate` in order to use this function"
-                )
 
-            @wraps(func)
-            def fallback_fn(*args, **kwargs):
-                return fallback
+                @wraps(func)
+                def fallback_fn(*args, **kwargs):
+                    raise ValueError(
+                        "Please install `accelerate` in order to use this function"
+                    )
+
+            else:
+
+                @wraps(func)
+                def fallback_fn(*args, **kwargs):
+                    return fallback
 
             return fallback_fn
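
The effect of the change: with `fallback="error"`, the missing-`accelerate` ValueError is now raised when the decorated function is called rather than at decoration (import) time, while other fallback values still short-circuit to a plain return. A hedged usage sketch; the factory name `check_accelerate` follows the repo's offload utilities but should be treated as an assumption here:

@check_accelerate(fallback="error")
def offload_to_cpu(module):
    ...  # real body would require accelerate

@check_accelerate(fallback=False)
def has_offloaded_params(module):
    ...  # real body would require accelerate

# with accelerate not installed:
has_offloaded_params(model)  # returns False (the fallback value)
offload_to_cpu(model)        # raises ValueError, but only now at call time,
                             # so importing the defining module stays safe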
