
Commit 0272c1c (1 parent: f2898df)

add unwrapping, tests

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

File tree: 4 files changed (+62, -4 lines)

src/compressed_tensors/compressors/model_compressors/model_compressor.py
Lines changed: 18 additions & 3 deletions

@@ -33,13 +33,15 @@
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors import DenseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+from compressed_tensors.linear.compressed_linear import CompressedLinear
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
     apply_quantization_config,
     load_pretrained_quantization_parameters,
+    unwrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import (
@@ -58,7 +60,7 @@
     fix_fsdp_module_name,
     is_compressed_tensors_config,
 )
-from compressed_tensors.utils.offload import update_offload_parameter
+from compressed_tensors.utils.offload import disable_hf_hook, update_offload_parameter
 from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
@@ -100,6 +102,9 @@ class ModelCompressor:
     :param quantization_config: config specifying quantization compression parameters
     """

+    sparsity_config: Optional[SparsityCompressionConfig] = None
+    quantization_config: Optional[QuantizationConfig] = None
+
     @classmethod
     def from_pretrained(
         cls,
@@ -364,12 +369,22 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:

         return list(unexpected_keys)

-    def apply_compression_status(self, model: Module) -> Module:
+    def apply_compression_status(self, model: Module):
+        if self.quantization_config is None:
+            for module in model.modules():
+                module.quantization_status = QuantizationStatus.COMPRESSED
+            return
+
         quantization_format = self.quantization_config.format

         def replace_with_compressed(module: Module) -> Module:
             scheme = getattr(module, "quantization_scheme", None)
             if isinstance(module, torch.nn.Linear) and scheme is not None:
+                # TODO: after refactored into hook, just remove hook
+                if hasattr(module, "quantization_status"):
+                    with disable_hf_hook(module):
+                        unwrap_module_forward_quantized(module)
+
                 module = CompressedLinear.from_linear(
                     module,
                     quantization_scheme=scheme,
@@ -385,7 +400,7 @@ def replace_with_compressed(module: Module) -> Module:
             return module

         progress = tqdm(total=len(list(model.modules())))
-        return module_map_replace(model, replace_with_compressed, progress=progress)
+        module_map_replace(model, replace_with_compressed, progress=progress)

     def compress(
         self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
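After this change, apply_compression_status mutates the model in place instead of returning a replacement module. A minimal usage sketch, mirroring the new test at the bottom of this commit (the model stub is one of the test fixtures, not a requirement):

    from transformers import AutoModelForCausalLM
    from compressed_tensors.compressors import ModelCompressor

    model = AutoModelForCausalLM.from_pretrained(
        "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed"
    )
    # argument order per the test below: (model, sparsity_format, quantization_format)
    compressor = ModelCompressor.from_pretrained_model(model, None, "float-quantized")
    compressor.apply_compression_status(model)  # in place; no return value after this commit

After the call, each Linear module carrying a quantization_scheme has been replaced by a CompressedLinear.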

src/compressed_tensors/quantization/lifecycle/forward.py
Lines changed: 5 additions & 0 deletions

@@ -37,6 +37,7 @@
     "dequantize",
     "fake_quantize",
     "wrap_module_forward_quantized",
+    "unwrap_module_forward_quantized",
     "forward_quantize",
 ]

@@ -312,6 +313,10 @@ def wrapped_forward(self, *args, **kwargs):
     setattr(module, "forward", bound_wrapped_forward)


+def unwrap_module_forward_quantized(module: Module):
+    delattr(module, "forward")  # revert to class implementation
+
+
 def forward_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
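The one-line unwrap works because wrap_module_forward_quantized installs the wrapped forward as an instance attribute (the setattr in the context above), which shadows the class's forward; deleting that attribute makes attribute lookup fall back to the class implementation. A self-contained sketch of the underlying Python mechanics (the Toy class is illustrative only, not part of the library):

    import types

    class Toy:
        def forward(self):
            return "class forward"

    toy = Toy()
    # install a bound wrapper as an instance attribute, shadowing the class method
    toy.forward = types.MethodType(lambda self: "wrapped forward", toy)
    assert toy.forward() == "wrapped forward"

    delattr(toy, "forward")  # lookup falls back to Toy.forward
    assert toy.forward() == "class forward"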

src/compressed_tensors/utils/helpers.py
Lines changed: 2 additions & 1 deletion

@@ -14,10 +14,11 @@

 import warnings
 from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

 import numpy
 import torch
+import tqdm
 from transformers import AutoConfig

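The new tqdm import suggests that module_map_replace, called from ModelCompressor above, lives in this file. A minimal sketch of what such a recursive map-and-replace helper might look like, inferred only from the call site module_map_replace(model, replace_with_compressed, progress=progress); the name module_map_replace_sketch and the progress parameter's type are assumptions, not the library's actual implementation:

    from typing import Callable, Optional

    import torch
    import tqdm

    def module_map_replace_sketch(
        module: torch.nn.Module,
        func: Callable[[torch.nn.Module], torch.nn.Module],
        progress: Optional[tqdm.tqdm] = None,  # hypothetical: the call site passes a tqdm instance
    ) -> torch.nn.Module:
        # depth-first: map children first, then the parent itself
        for name, child in module.named_children():
            setattr(module, name, module_map_replace_sketch(child, func, progress))
        if progress is not None:
            progress.update(1)
        return func(module)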

tests/test_compressors/model_compressors/test_model_compressor.py
Lines changed: 37 additions & 0 deletions

@@ -21,9 +21,11 @@
 import torch.nn as nn
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.config import SparsityCompressionConfig
+from compressed_tensors.linear.compressed_linear import CompressedLinear
 from compressed_tensors.quantization import QuantizationConfig
 from safetensors.torch import save_file
 from tests.testing_utils import induce_sparsity, requires_hf_quantizer
+from transformers import AutoModelForCausalLM


 def sparsity_config():
@@ -365,3 +367,38 @@ def _get_combined_config(s_config, q_config):
     combined["sparsity_config"] = s_config

     return combined
+
+
+@pytest.mark.parametrize(
+    "model_stub,q_format,s_format",
+    [
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
+            "float-quantized",
+            None,
+        ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
+            None,
+            "sparse-24-bitmask",
+        ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
+            "float-quantized",
+            "sparse-24-bitmask",
+        ),
+    ],
+)
+def test_apply_compression_status(model_stub, q_format, s_format):
+    model = AutoModelForCausalLM.from_pretrained(model_stub)
+    compressor = ModelCompressor.from_pretrained_model(model, s_format, q_format)
+    compressor.apply_compression_status(model)
+
+    for module in model.modules():
+        # scheme <=> CompressedLinear
+        has_scheme = hasattr(module, "quantization_scheme")
+        is_compressed = isinstance(module, CompressedLinear)
+        assert has_scheme == is_compressed
+
+    # can run to completion
+    model(**model.dummy_inputs)
