From 2846b6a23c6053a1f83c34a721e9b56254548f7a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 28 Apr 2025 13:30:46 -0400 Subject: [PATCH] Revert "Enable module state_dict compression, simplify compression logic (#302)" This reverts commit 4438d08d7573b26313763250d0497d4ac0974f47. --- .../model_compressors/model_compressor.py | 54 ++++-- .../compressors/quantized_compressors/base.py | 161 +++++++++--------- .../sparse_quantized_compressors/marlin_24.py | 34 ++-- src/compressed_tensors/utils/helpers.py | 7 - .../quantized_compressors/test_fp8_quant.py | 14 +- .../quantized_compressors/test_int_quant.py | 18 +- .../quantized_compressors/test_pack_quant.py | 38 +++-- .../test_marlin_24.py | 8 +- 8 files changed, 165 insertions(+), 169 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 13a9f375..7a7a5e88 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -19,7 +19,7 @@ import re from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union import compressed_tensors import torch @@ -36,12 +36,12 @@ from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_METHOD, QuantizationConfig, - QuantizationScheme, QuantizationStatus, apply_quantization_config, load_pretrained_quantization_parameters, ) from compressed_tensors.quantization.lifecycle import expand_target_names +from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.utils import ( is_module_quantized, iter_named_leaf_modules, @@ -64,7 +64,7 @@ from transformers.file_utils import CONFIG_NAME -__all__ = ["ModelCompressor", "map_module_to_scheme"] +__all__ = ["ModelCompressor", "map_modules_to_quant_args"] _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -372,17 +372,20 @@ def compress( :param state_dict: optional uncompressed state_dict to insert into model :return: compressed state dict """ - if state_dict is None: state_dict = model.state_dict() + compressed_state_dict = state_dict + + quantized_modules_to_args: Dict[ + str, QuantizationArgs + ] = map_modules_to_quant_args(model) + if self.quantization_compressor is not None: - module_to_scheme = map_module_to_scheme(model) - state_dict = self.quantization_compressor.compress( - state_dict, names_to_scheme=module_to_scheme + compressed_state_dict = self.quantization_compressor.compress( + state_dict, names_to_scheme=quantized_modules_to_args ) - # TODO: consider sparse compression to also be compression if self.quantization_config.format != CompressionFormat.dense.value: self.quantization_config.quantization_status = ( QuantizationStatus.COMPRESSED @@ -394,8 +397,8 @@ def compress( targets=self.sparsity_config.targets, ignore=self.sparsity_config.ignore, ) - state_dict = self.sparsity_compressor.compress( - state_dict, + compressed_state_dict = self.sparsity_compressor.compress( + compressed_state_dict, compression_targets=sparse_compression_targets, ) @@ -404,7 +407,7 @@ def compress( # https://github.com/huggingface/transformers/pull/30488 transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size - return state_dict + return compressed_state_dict def decompress(self, model_path: str, model: 
Module): """ @@ -602,15 +605,30 @@ def _replace_weights(self, dense_weight_generator, model: Module): update_parameter_data(module, param_data, param_name) -def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]: +def map_modules_to_quant_args( + model: Module, +) -> Dict[str, Union[QuantizationArgs, Tuple[QuantizationArgs, QuantizationArgs]]]: """ - Returns a dictionary which maps quantized module names to their quantization schemes + Given a pytorch model, map out the submodule name (usually linear layers) + to the weight QuantizationArgs. If running input activation quantization, will also + map to the input QuantizationArgs in a tuple. + + :param model: pytorch model """ - return { - fix_fsdp_module_name(name): module.quantization_scheme - for name, module in iter_named_leaf_modules(model) - if is_module_quantized(module) - } + quantized_modules_to_args = {} + for name, submodule in iter_named_leaf_modules(model): + if is_module_quantized(submodule): + if submodule.quantization_scheme.weights is not None: + name = fix_fsdp_module_name(name) + quantized_modules_to_args[name] = submodule.quantization_scheme.weights + if submodule.quantization_scheme.input_activations is not None: + weight_args = quantized_modules_to_args.get(name) + quantized_modules_to_args[name] = ( + weight_args, + submodule.quantization_scheme.input_activations, + ) + + return quantized_modules_to_args # HACK: Override the dtype_byte_size function in transformers to support float8 types diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 832cc4c0..098328be 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -14,16 +14,15 @@ import logging from pathlib import Path -from typing import Any, Dict, Generator, Tuple, Union +from typing import Any, Dict, Generator, Optional, Tuple, Union import torch from compressed_tensors.compressors.base import BaseCompressor -from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.utils import ( get_nested_mappings_from_state_dict, get_nested_weight_mappings, merge_names, - remove_suffix, ) from safetensors import safe_open from torch import Tensor @@ -70,7 +69,7 @@ class BaseQuantizationCompressor(BaseCompressor): def compress( self, model_state: Dict[str, Tensor], - names_to_scheme: Dict[str, QuantizationScheme], + names_to_scheme: Dict[str, QuantizationArgs], **kwargs, ) -> Dict[str, Tensor]: """ @@ -82,87 +81,87 @@ def compress( :return: compressed state dict """ compressed_dict = {} - save_device = "cpu" - - uncompressed_names = list(model_state.keys()) - for name in tqdm(uncompressed_names, desc="Compressing with quantization"): - value = model_state[name] - - # compress weights - if name.endswith("weight"): - prefix = remove_suffix(name, "weight") - - # gather qparams - scale = model_state.get(prefix + "weight_scale", None) - g_idx = model_state.get(prefix + "weight_g_idx", None) - zp = model_state.get(prefix + "weight_zero_point", None) - - # is scale does not exist, then weight cannot be compressed - if scale is None: - compressed_dict[name] = value.to(save_device) - continue - - # compress values on cpu (memory movement too expensive) - module_path = prefix[:-1] if prefix.endswith(".") else prefix - quant_args = 
names_to_scheme[module_path].weights - compressed_values = self.compress_weight( - weight=value, - scale=scale, - zero_point=zp, - g_idx=g_idx, - quantization_args=quant_args, - device="cpu", - ) - - # update state dict - for key, value in compressed_values.items(): - compressed_dict[prefix + key] = value.to(save_device) + weight_suffix = ".weight" + input_zp_suffix = ".input_zero_point" + weight_zp_suffix = ".weight_zero_point" + _LOGGER.debug( + f"Compressing model with {len(model_state)} parameterized layers..." + ) + for name, value in tqdm(model_state.items(), desc="Quantized Compression"): + # check if the parameter we're compressing is the weight zp + # or the input zp + is_weight_zp = name.endswith(weight_zp_suffix) + is_input_zp = name.endswith(input_zp_suffix) + + # if we're saving the weight zp, fetch weight quant args + if is_weight_zp: + quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))]) + if isinstance(quant_args_zp, tuple): + # If tuple, first value is weight args, second is input args + quant_args_zp = quant_args_zp[0] + + # if we're saving the input zp, fetch input quant args + if is_input_zp: + input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))]) + if isinstance(input_args_zp, tuple): + # If tuple, first value is weight args, second is input args + input_args_zp = input_args_zp[-1] + + if name.endswith(weight_suffix): + prefix = name[: -(len(weight_suffix))] + scale = model_state.get(merge_names(prefix, "weight_scale"), None) + zp = model_state.get(merge_names(prefix, "weight_zero_point"), None) + g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None) + if scale is not None: + # weight is quantized, compress it + if isinstance(names_to_scheme[prefix], tuple): + quant_args = names_to_scheme[prefix][0] + else: + quant_args = names_to_scheme[prefix] + + compressed_data = self.compress_weight( + weight=value, + scale=scale, + zero_point=zp, + g_idx=g_idx, + quantization_args=quant_args, + device="cpu", + ) + for key, value in compressed_data.items(): + compressed_dict[merge_names(prefix, key)] = value + else: + compressed_dict[name] = value.to("cpu") + # only save zp if asym and not packed zp + elif is_weight_zp and ( + quant_args_zp.symmetric or self._check_if_zp_pack_quantized(quant_args) + ): + continue + # only save if asym + elif is_input_zp and input_args_zp.symmetric: + continue + elif name.endswith("g_idx") and torch.any(value <= -1): + continue else: - # omit saving zero points for symmetric or packed quantization - if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme): - continue - - # omit saving for g_idx if uninitialized - # TODO: does this case actually occur? - elif name.endswith("g_idx") and torch.any(value <= -1): - continue - - compressed_dict[name] = value.to(save_device) + compressed_dict[name] = value.to("cpu") return compressed_dict - def _skip_zp( - self, name: str, names_to_scheme: Dict[str, QuantizationScheme] - ) -> bool: + def _check_if_zp_pack_quantized(self, quant_args): from compressed_tensors.compressors import PackedQuantizationCompressor - module_name, zp_name = name.rsplit(".", 1) if "." 
in name else ("", name) - scheme = names_to_scheme[module_name] - - if zp_name == "weight_zero_point": - args = scheme.weights - if zp_name == "input_zero_point": - args = scheme.input_activations - if zp_name == "output_zero_point": - args = scheme.output_activations - - symmetric = args.symmetric - packable_strategies = [ - QuantizationStrategy.GROUP.value, - QuantizationStrategy.CHANNEL.value, - ] - packed = ( - isinstance(self, PackedQuantizationCompressor) - and args.strategy in packable_strategies - ) - - return symmetric or packed + if isinstance(self, PackedQuantizationCompressor): + if not quant_args.symmetric and quant_args.strategy in [ + QuantizationStrategy.GROUP.value, + QuantizationStrategy.CHANNEL.value, + ]: + return True + return False def decompress( self, path_to_model_or_tensors: Union[str, Path, Dict[str, Any]], - names_to_scheme: Dict[str, QuantizationScheme], + names_to_scheme: Dict[str, QuantizationArgs], device: str = "cpu", ) -> Generator[Tuple[str, Tensor], None, None]: """ @@ -171,9 +170,8 @@ def decompress( dense state dict :param path_to_model_or_tensors: path to compressed safetensors model (directory with one or more safetensors files) or compressed tensors file - :param names_to_scheme: quantization scheme for each quantized weight - :param device: optional device to load intermediate weights into (must be `str`, - not `torch.device`) + :param names_to_scheme: quantization args for each quantized weight + :param device: optional device to load intermediate weights into :return: compressed state dict """ if isinstance(path_to_model_or_tensors, (str, Path)): @@ -186,12 +184,7 @@ def decompress( path_to_model_or_tensors, names_to_scheme ) - def _decompress_from_path( - self, - path_to_model: Union[str, Path, Dict[str, Any]], - names_to_scheme: Dict[str, QuantizationScheme], - device: str, - ): + def _decompress_from_path(self, path_to_model, names_to_scheme, device): weight_mappings = get_nested_weight_mappings( path_to_model, self.compression_param_names ) @@ -202,7 +195,7 @@ def _decompress_from_path( with safe_open(safe_path, framework="pt", device=device) as f: weight_data[param_name] = f.get_tensor(full_name) if "weight_scale" in weight_data: - quant_args = names_to_scheme[weight_name].weights + quant_args = names_to_scheme[weight_name] decompressed = self.decompress_weight( compressed_data=weight_data, quantization_args=quant_args ) diff --git a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py index a8487cb7..24f9cbf0 100644 --- a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +++ b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py @@ -19,11 +19,7 @@ import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization import ( - QuantizationArgs, - QuantizationScheme, - QuantizationStrategy, -) +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization.lifecycle.forward import quantize from compressed_tensors.utils import ( get_permutations_24, @@ -48,25 +44,19 @@ class Marlin24Compressor(BaseCompressor): @staticmethod def validate_quant_compatability( - names_to_scheme: Dict[str, QuantizationScheme] + model_quant_args: Dict[str, QuantizationArgs] ) -> bool: """ Checks if every quantized module in the model is compatible with 
Marlin24 compression. Quantization must be channel or group strategy with group_size of 128. Only symmetric quantization is supported - :param names_to_scheme: dictionary of mapping module names to their - quantization schemes + :param model_quant_args: dictionary of mapping module names to their + quantization configuration :return: True if all modules are compatible with Marlin24 compression, raises a ValueError otherwise """ - for name, scheme in names_to_scheme.items(): - quant_args = scheme.weights - if quant_args is None: - raise ValueError( - "Marlin24 Compressor is only valid for weight quantization schemes" - ) - + for name, quant_args in model_quant_args.items(): strategy = quant_args.strategy group_size = quant_args.group_size symmetric = quant_args.symmetric @@ -124,7 +114,7 @@ def compression_param_names(self) -> Tuple[str]: def compress( self, model_state: Dict[str, Tensor], - names_to_scheme: Dict[str, QuantizationScheme], + names_to_scheme: Dict[str, QuantizationArgs], **kwargs, ) -> Dict[str, Tensor]: """ @@ -132,8 +122,8 @@ def compress( with the Marlin24 kernel :param model_state: state dict of uncompressed model - :param names_to_scheme: quantization scheme for each quantized weight, needed - for quantize function to calculate bit depth + :param names_to_scheme: quantization args for each quantized weight, needed for + quantize function to calculate bit depth :return: compressed state dict """ self.validate_quant_compatability(names_to_scheme) @@ -156,7 +146,7 @@ def compress( value = value.to(torch.float16) # quantize weight, keeping it as a float16 for now - quant_args = names_to_scheme[prefix].weights + quant_args = names_to_scheme[prefix] value = quantize( x=value, scale=scale, zero_point=zp, args=quant_args ) @@ -225,7 +215,7 @@ def pack_weight_24( weight: Tensor, quantization_args: QuantizationArgs, tile: int = 16, -) -> torch.Tensor: +): size_k = weight.shape[0] size_n = weight.shape[1] num_bits = quantization_args.num_bits @@ -246,9 +236,7 @@ def pack_weight_24( return q_packed -def pack_scales_24( - scales: torch.Tensor, quantization_args: QuantizationArgs, w_shape: torch.Size -) -> torch.Tensor: +def pack_scales_24(scales, quantization_args, w_shape): size_k = w_shape[0] size_n = w_shape[1] num_bits = quantization_args.num_bits diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index af0b9159..a842d00e 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -38,7 +38,6 @@ "shard_tensor", "pack_bitmasks", "unpack_bitmasks", - "remove_suffix", ] FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" @@ -329,9 +328,3 @@ def unpack_bitmasks( ) return unpacked_bitmasks_torch - - -def remove_suffix(value: str, suffix: str) -> str: - # can replace with str.removesuffix in python3.9+ - assert value.endswith(suffix) - return value[: -len(suffix)] diff --git a/tests/test_compressors/quantized_compressors/test_fp8_quant.py b/tests/test_compressors/quantized_compressors/test_fp8_quant.py index e8bbc0d2..f5a85e57 100644 --- a/tests/test_compressors/quantized_compressors/test_fp8_quant.py +++ b/tests/test_compressors/quantized_compressors/test_fp8_quant.py @@ -84,9 +84,9 @@ def test_quant_format(strategy, group_size, sc, zp): quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size) compressor = FloatQuantizationCompressor(config=quant_config) - module_name_to_scheme = {"dummy": quant_config.config_groups["group_1"]} + quantized_modules_to_args = {"dummy": 
quant_config.config_groups["group_1"].weights} compressed_state_dict = compressor.compress( - dense_state_dict, names_to_scheme=module_name_to_scheme + dense_state_dict, names_to_scheme=quantized_modules_to_args ) # state_dict params should be the same, minus the zero_point if symmetric @@ -140,15 +140,15 @@ def test_reload_match( ) compressor = FloatQuantizationCompressor(config=quant_config) - module_name_to_scheme = { - "dummy": quant_config.config_groups["group_1"], + quantized_modules_to_args = { + "dummy": quant_config.config_groups["group_1"].weights, } compressed_state_dict = compressor.compress( - model.state_dict(), names_to_scheme=module_name_to_scheme + model.state_dict(), names_to_scheme=quantized_modules_to_args ) save_file(compressed_state_dict, tmp_path / "model.safetensors") reconstructed_dense_gen = compressor.decompress( - tmp_path, names_to_scheme=module_name_to_scheme + tmp_path, names_to_scheme=quantized_modules_to_args ) reconstructed_dense = {} for name, value in reconstructed_dense_gen: @@ -158,7 +158,7 @@ def test_reload_match( model.dummy.weight, scale=model.dummy.weight_scale, zero_point=model.dummy.weight_zero_point, - args=module_name_to_scheme["dummy"].weights, + args=quantized_modules_to_args["dummy"], ) assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy"].get("weight")) diff --git a/tests/test_compressors/quantized_compressors/test_int_quant.py b/tests/test_compressors/quantized_compressors/test_int_quant.py index 991444cc..ebbdb9cf 100644 --- a/tests/test_compressors/quantized_compressors/test_int_quant.py +++ b/tests/test_compressors/quantized_compressors/test_int_quant.py @@ -76,9 +76,9 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp): ) compressor = IntQuantizationCompressor(config=quant_config) - quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]} + quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights} compressed_state_dict = compressor.compress( - dense_state_dict, names_to_scheme=quantized_modules_to_scheme + dense_state_dict, names_to_scheme=quantized_modules_to_args ) # state_dict params should be the same, minus the zero_point if symmetric @@ -124,16 +124,16 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path): quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size) compressor = IntQuantizationCompressor(config=quant_config) - module_name_to_scheme = { - "dummy": quant_config.config_groups["group_1"], - "dummy2": quant_config.config_groups["group_1"], + quantized_modules_to_args = { + "dummy": quant_config.config_groups["group_1"].weights, + "dummy2": quant_config.config_groups["group_1"].weights, } compressed_state_dict = compressor.compress( - dense_state_dict, names_to_scheme=module_name_to_scheme + dense_state_dict, names_to_scheme=quantized_modules_to_args ) save_file(compressed_state_dict, tmp_path / "model.safetensors") reconstructed_dense_gen = compressor.decompress( - tmp_path, names_to_scheme=module_name_to_scheme + tmp_path, names_to_scheme=quantized_modules_to_args ) reconstructed_dense = {} for name, value in reconstructed_dense_gen: @@ -143,7 +143,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path): dense_state_dict["dummy.weight"], scale=dense_state_dict["dummy.weight_scale"], zero_point=dense_state_dict["dummy.weight_zero_point"], - args=module_name_to_scheme["dummy"].weights, + args=quantized_modules_to_args["dummy"], ) assert torch.equal( fake_quant_dummy, 
reconstructed_dense["dummy"].get("weight").to(torch.float32) @@ -153,7 +153,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path): dense_state_dict["dummy2.weight"], scale=dense_state_dict["dummy2.weight_scale"], zero_point=dense_state_dict["dummy2.weight_zero_point"], - args=module_name_to_scheme["dummy2"].weights, + args=quantized_modules_to_args["dummy2"], ) assert torch.equal( fake_quant_dummy2, reconstructed_dense["dummy2"].get("weight").to(torch.float32) diff --git a/tests/test_compressors/quantized_compressors/test_pack_quant.py b/tests/test_compressors/quantized_compressors/test_pack_quant.py index 00d61275..dc8b69bd 100644 --- a/tests/test_compressors/quantized_compressors/test_pack_quant.py +++ b/tests/test_compressors/quantized_compressors/test_pack_quant.py @@ -40,7 +40,7 @@ def get_dummy_quant_config( num_bits=4, strategy=None, group_size=None, actorder=None, symmetric=True -) -> QuantizationConfig: +): config_groups = { "group_1": QuantizationScheme( targets=["Linear"], @@ -82,9 +82,9 @@ def test_quant_format(shape): quant_config = get_dummy_quant_config() compressor = PackedQuantizationCompressor(config=quant_config) - quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]} + quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights} compressed_state_dict = compressor.compress( - dense_state_dict, names_to_scheme=quantized_modules_to_scheme + dense_state_dict, names_to_scheme=quantized_modules_to_args ) # compressed state_dict adds one entry for shape @@ -157,21 +157,25 @@ def test_reload_match(tmp_path, num_bits): # pack-compressor only needs the number of bits from the quant-args to decompress # all other information is extracted from the compressed data directly + names_to_scheme = { + "dummy": QuantizationArgs(num_bits=num_bits), + "dummy2": QuantizationArgs(num_bits=num_bits), + } quant_config = get_dummy_quant_config(num_bits, symmetric=False) compressor = PackedQuantizationCompressor(config=quant_config) - quantized_modules_to_scheme = { - "dummy": quant_config.config_groups["group_1"], - "dummy2": quant_config.config_groups["group_1"], + quantized_modules_to_args = { + "dummy": quant_config.config_groups["group_1"].weights, + "dummy2": quant_config.config_groups["group_1"].weights, } compressed_state_dict = compressor.compress( - dense_state_dict.copy(), names_to_scheme=quantized_modules_to_scheme + dense_state_dict, names_to_scheme=quantized_modules_to_args ) save_file(compressed_state_dict, tmp_path / "model.safetensors") reconstructed_dense_gen = compressor.decompress( - tmp_path, names_to_scheme=quantized_modules_to_scheme + tmp_path, names_to_scheme=names_to_scheme ) reconstructed_dense = {} for name, value in reconstructed_dense_gen: @@ -181,7 +185,7 @@ def test_reload_match(tmp_path, num_bits): dense_state_dict["dummy.weight"], scale=dense_state_dict["dummy.weight_scale"], zero_point=dense_state_dict["dummy.weight_zero_point"], - args=quantized_modules_to_scheme["dummy"].weights, + args=quantized_modules_to_args["dummy"], ) assert torch.equal( fake_quant_dummy, reconstructed_dense["dummy"].get("weight").to(torch.float32) @@ -191,7 +195,7 @@ def test_reload_match(tmp_path, num_bits): dense_state_dict["dummy2.weight"], scale=dense_state_dict["dummy2.weight_scale"], zero_point=dense_state_dict["dummy2.weight_zero_point"], - args=quantized_modules_to_scheme["dummy2"].weights, + args=quantized_modules_to_args["dummy2"], ) assert torch.equal( fake_quant_dummy2, 
reconstructed_dense["dummy2"].get("weight").to(torch.float32) @@ -228,9 +232,9 @@ def test_asymmetric_packed_support(strategy): ) compressor = PackedQuantizationCompressor(config=quant_config) - quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]} + quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights} compressed_state_dict = compressor.compress( - dense_state_dict, names_to_scheme=quantized_modules_to_scheme + dense_state_dict, names_to_scheme=quantized_modules_to_args ) # compressed state_dict adds one entry for shape @@ -285,17 +289,17 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration): # compress compressor = PackedQuantizationCompressor(config=quant_config) - quantized_modules_to_scheme = { - "dummy": quant_config.config_groups["group_1"], + quantized_modules_to_args = { + "dummy": quant_config.config_groups["group_1"].weights, } compressed_state_dict = compressor.compress( - model.state_dict(), names_to_scheme=quantized_modules_to_scheme + model.state_dict(), names_to_scheme=quantized_modules_to_args ) save_file(compressed_state_dict, tmp_path / "model.safetensors") # decompress reconstructed_dense_gen = compressor.decompress( - tmp_path, names_to_scheme=quantized_modules_to_scheme + tmp_path, names_to_scheme=quantized_modules_to_args ) reconstructed_dense = {} for name, value in reconstructed_dense_gen: @@ -306,7 +310,7 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration): scale=model.dummy.weight_scale, zero_point=model.dummy.weight_zero_point, g_idx=getattr(model.dummy, "weight_g_idx", None), - args=quantized_modules_to_scheme["dummy"].weights, + args=quantized_modules_to_args["dummy"], ) assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy"].get("weight")) diff --git a/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py b/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py index 598a69c5..ddc51110 100644 --- a/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +++ b/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py @@ -19,7 +19,7 @@ from compressed_tensors.compressors import ( BaseCompressor, Marlin24Compressor, - map_module_to_scheme, + map_modules_to_quant_args, ) from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization import ( @@ -92,9 +92,9 @@ def test_marlin24_format( assert f"{NOT_QUANT_NAME}.weight_scale" not in state_dict assert f"{QUANT_NAME}.weight_scale" in state_dict - module_to_scheme = map_module_to_scheme(model) + model_to_quant_args = map_modules_to_quant_args(model) compressor = Marlin24Compressor() - compressor.validate_quant_compatability(module_to_scheme) + compressor.validate_quant_compatability(model_to_quant_args) compressor.validate_sparsity_structure( QUANT_NAME, state_dict[f"{QUANT_NAME}.weight"] ) @@ -104,7 +104,7 @@ def test_marlin24_format( ) compressor = Marlin24Compressor() - compressed_state_dict = compressor.compress(state_dict, module_to_scheme) + compressed_state_dict = compressor.compress(state_dict, model_to_quant_args) assert len(compressed_state_dict) == 4 assert torch.equal(