
[Performance] Reduce compression memory requirements via structure change #301


Open: wants to merge 34 commits into base: main

Commits (34)
c1b06de
reduce memory requirements, clarify map_modules_to_quant_scheme
kylesayrs Apr 2, 2025
48fbbd8
rename to map_module_to_scheme
kylesayrs Apr 2, 2025
90cf72d
style
kylesayrs Apr 2, 2025
783a081
update docstring
kylesayrs Apr 2, 2025
735660d
remove todo
kylesayrs Apr 2, 2025
6017f05
update docstring
kylesayrs Apr 2, 2025
9524c7f
update marlin test, marlin uses scheme
kylesayrs Apr 2, 2025
ac27709
fix tests
kylesayrs Apr 3, 2025
72dd867
wip
kylesayrs Apr 14, 2025
2fc0403
replacement shows reduction
kylesayrs Apr 17, 2025
37da099
clean up
kylesayrs Apr 21, 2025
e3359eb
Merge remote-tracking branch 'origin' into kylesayrs/reduce-quantized…
kylesayrs Apr 21, 2025
b5374ae
fix merge
kylesayrs Apr 21, 2025
dfef94d
Merge remote-tracking branch 'origin' into kylesayrs/reduce-quantized…
kylesayrs Apr 21, 2025
700d4b6
fix test
kylesayrs Apr 21, 2025
8ae9004
fix tests
kylesayrs Apr 21, 2025
ab70311
remove unneeded code
kylesayrs Apr 22, 2025
ca600ca
fix typos
kylesayrs Apr 22, 2025
97bda13
use map_module_to_scheme, _should_save_zp
kylesayrs Apr 22, 2025
6060bbe
remove unused import
kylesayrs Apr 22, 2025
d4e96d1
Merge branch 'kylesayrs/map_module_to_scheme' into kylesayrs/reduce-q…
kylesayrs Apr 22, 2025
b1d384f
allow get_execution_device to be used when initializing a model
kylesayrs Apr 22, 2025
436929a
formatting
kylesayrs Apr 22, 2025
43736a9
rename to module_map_replace
kylesayrs Apr 22, 2025
9e82ddb
remove unused imports
kylesayrs Apr 22, 2025
f324af7
rename to _skip_zp
kylesayrs Apr 22, 2025
d4affd4
Merge branch 'kylesayrs/map_module_to_scheme' into kylesayrs/reduce-q…
kylesayrs Apr 22, 2025
f2898df
Merge branch 'kylesayrs/get_execution_device-meta' into kylesayrs/red…
kylesayrs Apr 22, 2025
0272c1c
add unwrapping, tests
kylesayrs Apr 28, 2025
1862e0f
Merge remote-tracking branch 'origin' into kylesayrs/reduce-quantized…
kylesayrs Apr 28, 2025
3ac19fa
Merge remote-tracking branch 'origin' into kylesayrs/reduce-quantized…
kylesayrs Apr 28, 2025
16f9f1f
don't use compressedlinear
kylesayrs May 2, 2025
25e1ec3
Merge remote-tracking branch 'origin' into kylesayrs/reduce-quantized…
kylesayrs May 2, 2025
ef4dd02
cleanup
kylesayrs May 2, 2025
@@ -33,13 +33,15 @@
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.sparse_compressors import DenseCompressor
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization import (
DEFAULT_QUANTIZATION_METHOD,
QuantizationConfig,
QuantizationScheme,
QuantizationStatus,
apply_quantization_config,
load_pretrained_quantization_parameters,
unwrap_module_forward_quantized,
)
from compressed_tensors.quantization.lifecycle import expand_target_names
from compressed_tensors.quantization.utils import (
@@ -50,13 +52,15 @@
get_safetensors_folder,
has_offloaded_params,
merge_names,
module_map_replace,
register_offload_parameter,
update_parameter_data,
)
from compressed_tensors.utils.helpers import (
fix_fsdp_module_name,
is_compressed_tensors_config,
)
from compressed_tensors.utils.offload import disable_hf_hook, update_offload_parameter
from torch import Tensor
from torch.nn import Module
from tqdm import tqdm
@@ -98,6 +102,9 @@ class ModelCompressor:
:param quantization_config: config specifying quantization compression parameters
"""

sparsity_config: Optional[SparsityCompressionConfig] = None
quantization_config: Optional[QuantizationConfig] = None

@classmethod
def from_pretrained(
cls,
@@ -362,8 +369,54 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:

return list(unexpected_keys)

def apply_compression_status(self, model: Module):
# sparsity compression
if self.quantization_config is None:
for module in model.modules():
module.quantization_status = QuantizationStatus.COMPRESSED
Member: Could we run compression here too, for cases when we only have sparsity!


# hack: compress state dict upfront, since CompressedLinear doesn't have
# support for sparsified models
model_state_dict = self.compress(model)

def state_dict_hook(module, prefix, keep_vars):
return model_state_dict if prefix == "" else {}

model.register_state_dict_pre_hook(state_dict_hook)

return

def replace_with_compressed(module: Module) -> Module:
scheme = getattr(module, "quantization_scheme", None)
if isinstance(module, torch.nn.Linear) and scheme is not None:
# TODO: after refactored into hook, just remove hook
if hasattr(module, "quantization_status"):
with disable_hf_hook(module):
unwrap_module_forward_quantized(module)

state_dict = self.compress(module, show_progress=False)

# CompressedLinear initializes qparams which have to be deleted
# TODO: CompressedLinear should not initialize qparams
for name, _ in list(module.named_parameters()):
delattr(module, name)

for name, value in state_dict.items():
param = torch.nn.Parameter(value, requires_grad=False)
register_offload_parameter(module, name, param)

module.quantization_status = QuantizationStatus.COMPRESSED

return module

progress = tqdm(desc="Compressing modules", total=len(list(model.modules())))
module_map_replace(model, replace_with_compressed, progress=progress)
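
For orientation, a minimal usage sketch of the new entry point (not part of the diff; it mirrors the parametrized test added at the bottom of this PR, and the checkpoint stub is taken from that test):

from transformers import AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor

# checkpoint stub borrowed from the tests below
model = AutoModelForCausalLM.from_pretrained(
    "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed"
)
compressor = ModelCompressor.from_pretrained_model(model, None, "float-quantized")

# compress module-by-module in place rather than materializing a second,
# fully uncompressed-plus-compressed copy of the state dict
compressor.apply_compression_status(model)

# state_dict() now yields the compressed tensors directly
compressed_state_dict = dict(model.state_dict())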

def compress(
self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
self,
model: Module,
state_dict: Optional[Dict[str, Tensor]] = None,
show_progress: bool = False,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict or model with sparsity and/or quantization
@@ -379,7 +432,9 @@ def compress(
if self.quantization_compressor is not None:
module_to_scheme = map_module_to_scheme(model)
state_dict = self.quantization_compressor.compress(
state_dict, names_to_scheme=module_to_scheme
state_dict,
names_to_scheme=module_to_scheme,
show_progress=show_progress,
)

# TODO: consider sparse compression to also be compression
@@ -397,6 +452,7 @@
state_dict = self.sparsity_compressor.compress(
state_dict,
compression_targets=sparse_compression_targets,
show_progress=show_progress,
)

# HACK: Override the dtype_byte_size function in transformers to
@@ -71,6 +71,7 @@ def compress(
self,
model_state: Dict[str, Tensor],
names_to_scheme: Dict[str, QuantizationScheme],
show_progress: bool = False,
**kwargs,
) -> Dict[str, Tensor]:
"""
@@ -79,13 +80,16 @@
:param model_state: state dict of uncompressed model
:param names_to_scheme: quantization args for each quantized weight, needed for
quantize function to calculate bit depth
:param show_progress: whether to show tqdm progress
:return: compressed state dict
"""
uncompressed_names = list(model_state.keys())
compressed_dict = {}
save_device = "cpu"

uncompressed_names = list(model_state.keys())
for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
# compress values
desc = "Compressing with quantization"
for name in tqdm(uncompressed_names, desc=desc, disable=(not show_progress)):
value = model_state[name]

# compress weights
@@ -76,7 +76,7 @@ def compress(
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)
for name, value in tqdm(model_state.items(), desc="Compressing model"):
for name, value in tqdm(model_state.items(), desc="Compressing with sparsity"):
if not self.should_compress(name, compression_targets):
compressed_dict[name] = value
continue
@@ -125,6 +125,7 @@ def compress(
self,
model_state: Dict[str, Tensor],
names_to_scheme: Dict[str, QuantizationScheme],
show_progress: bool = False,
**kwargs,
) -> Dict[str, Tensor]:
"""
@@ -134,6 +135,7 @@
:param model_state: state dict of uncompressed model
:param names_to_scheme: quantization scheme for each quantized weight, needed
for quantize function to calculate bit depth
:param show_progress: whether to show tqdm progress
:return: compressed state dict
"""
self.validate_quant_compatability(names_to_scheme)
@@ -144,7 +146,9 @@
f"Compressing model with {len(model_state)} parameterized layers..."
)

for name, value in tqdm(model_state.items(), desc="Compressing model"):
for name, value in tqdm(
model_state.items(), desc="Compressing model", disable=(not show_progress)
):
if name.endswith(weight_suffix):
prefix = name[: -(len(weight_suffix))]
scale = model_state.get(merge_names(prefix, "weight_scale"), None)
5 changes: 3 additions & 2 deletions src/compressed_tensors/linear/compressed_linear.py
@@ -23,6 +23,7 @@
initialize_module_for_quantization,
)
from compressed_tensors.utils import register_offload_parameter
from compressed_tensors.utils.offload import get_execution_device
from torch import Tensor
from torch.nn import Parameter
from torch.nn.functional import linear
@@ -60,7 +61,7 @@ def from_linear(
"""
module.__class__ = CompressedLinear
module.compressor = BaseCompressor.load_from_registry(quantization_format)
device = next(module.parameters()).device
init_device = get_execution_device(module)

# this will initialize all the scales and zero points
initialize_module_for_quantization(
@@ -79,7 +80,7 @@
# populate compressed weights and quantization parameters
for name, (shape, dtype) in compression_params.items():
param = Parameter(
torch.empty(shape, device=device, dtype=dtype), requires_grad=False
torch.empty(shape, device=init_device, dtype=dtype), requires_grad=False
)
register_offload_parameter(module, name, param)
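
An illustrative aside on why the device lookup changes (editor's sketch, not part of the diff): reading the device from the first parameter breaks down when a module's weights live on the meta device or are offloaded, which is exactly the situation while a model is still being initialized.

import torch

linear = torch.nn.Linear(4, 4, device="meta")   # parameters are not materialized
print(next(linear.parameters()).device)          # meta: unusable for real allocation

# get_execution_device(module) from compressed_tensors.utils.offload is used instead;
# per the commit "allow get_execution_device to be used when initializing a model",
# it resolves a device the module can actually allocate on in this situation.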

5 changes: 5 additions & 0 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -37,6 +37,7 @@
"dequantize",
"fake_quantize",
"wrap_module_forward_quantized",
"unwrap_module_forward_quantized",
"forward_quantize",
]

@@ -312,6 +313,10 @@ def wrapped_forward(self, *args, **kwargs):
setattr(module, "forward", bound_wrapped_forward)


def unwrap_module_forward_quantized(module: Module):
delattr(module, "forward") # revert to class implementation
Comment on lines +316 to +317
Contributor: does this then expose the forward method on the Linear or CompressedLinear class? Why do we want to delete the attr?

Contributor Author: Yes, this exposes the forward method of the class implementation.

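A small, self-contained sketch of the attribute shadowing being discussed (a toy class, not the actual quantization wrapper):

import types

class Toy:
    def forward(self):
        return "class forward"

toy = Toy()
# wrapping binds a new function onto the *instance*, shadowing the class method
toy.forward = types.MethodType(lambda self: "wrapped forward", toy)
assert toy.forward() == "wrapped forward"

# deleting the instance attribute (what unwrap_module_forward_quantized does)
# makes attribute lookup fall back to the class implementation again
delattr(toy, "forward")
assert toy.forward() == "class forward"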


def forward_quantize(
module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
) -> torch.Tensor:
43 changes: 42 additions & 1 deletion src/compressed_tensors/utils/helpers.py
@@ -14,10 +14,11 @@

import warnings
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import numpy
import torch
import tqdm
from transformers import AutoConfig


@@ -39,6 +40,7 @@
"pack_bitmasks",
"unpack_bitmasks",
"remove_suffix",
"module_map_replace",
]

FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -335,3 +337,42 @@ def remove_suffix(value: str, suffix: str) -> str:
# can replace with str.removesuffix in python3.9+
assert value.endswith(suffix)
return value[: -len(suffix)]


def module_map_replace(
module: torch.nn.Module,
func: Callable[[torch.nn.Module], torch.nn.Module],
progress: Union[bool, tqdm.tqdm] = False,
pre: bool = True,
) -> torch.nn.Module:
"""
Replaces modules in a given `torch.nn.Module` recursively using a provided function.

This function traverses the module hierarchy and applies the `func` transformation
either before (`pre=True`) or after (`pre=False`) recursing into children modules.
Optionally displays progress using tqdm.

:param module: root module to replace
:param func: module mapping function
:param progress: if True, display a tqdm progress bar.
If a `tqdm.tqdm` instance is provided, the instance will be updated
:param pre: if True, apply with pre-order, post-order otherwise
:return: the modified module after applying the function to all submodules
"""
if progress is True:
total = len(list(module.modules()))
progress = tqdm.tqdm(total=total)

if pre:
module = func(module)

for name, child in list(module.named_children()):
module.add_module(name, module_map_replace(child, func, progress=progress, pre=pre))

if not pre:
module = func(module)

if isinstance(progress, tqdm.tqdm):
progress.update(1)

return module
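
A usage sketch for the new helper (illustrative only; the mapping function here is a toy, not the replace_with_compressed closure used by ModelCompressor):

import torch
from compressed_tensors.utils.helpers import module_map_replace

def freeze_linears(module: torch.nn.Module) -> torch.nn.Module:
    # toy mapping function: freeze Linear layers, leave everything else unchanged
    if isinstance(module, torch.nn.Linear):
        for param in module.parameters():
            param.requires_grad = False
    return module

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)
)
# pre-order traversal by default; progress=True shows a tqdm bar over all modules
model = module_map_replace(model, freeze_linears, progress=True)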
@@ -21,9 +21,11 @@
import torch.nn as nn
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationConfig
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus
from safetensors.torch import save_file
from tests.testing_utils import induce_sparsity, requires_hf_quantizer
from transformers import AutoModelForCausalLM


def sparsity_config():
@@ -365,3 +367,54 @@ def _get_combined_config(s_config, q_config):
combined["sparsity_config"] = s_config

return combined


@pytest.mark.parametrize(
"model_stub,q_format,s_format",
[
(
"nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
"float-quantized",
None,
),
(
"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
None,
"sparse-24-bitmask",
),
(
"nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
"float-quantized",
"sparse-24-bitmask",
),
],
)
def test_apply_compression_status(model_stub, q_format, s_format):
model = AutoModelForCausalLM.from_pretrained(model_stub)
compressor = ModelCompressor.from_pretrained_model(model, s_format, q_format)
original_compressed_state_dict = dict(compressor.compress(model))
original_compressed_state_dict = {
key: value.clone() for key, value in original_compressed_state_dict.items()
}

compressor.apply_compression_status(model)

for module in model.modules():
# scheme <=> CompressedLinear
has_scheme = hasattr(module, "quantization_scheme")
is_compressed = (
getattr(module, "quantization_status", None)
== QuantizationStatus.COMPRESSED
)
# assert has_scheme == is_compressed

# equivalent to eagerly compressing state dict
compressed_state_dict = dict(model.state_dict())
assert compressed_state_dict.keys() == original_compressed_state_dict.keys()
for key in compressed_state_dict.keys():
assert torch.all(
compressed_state_dict[key] == original_compressed_state_dict[key]
), f"{key}"

# can run to completion
# model(**model.dummy_inputs)