diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 7a7a5e88..13a9f375 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -19,7 +19,7 @@ import re from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union import compressed_tensors import torch @@ -36,12 +36,12 @@ from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_METHOD, QuantizationConfig, + QuantizationScheme, QuantizationStatus, apply_quantization_config, load_pretrained_quantization_parameters, ) from compressed_tensors.quantization.lifecycle import expand_target_names -from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.utils import ( is_module_quantized, iter_named_leaf_modules, @@ -64,7 +64,7 @@ from transformers.file_utils import CONFIG_NAME -__all__ = ["ModelCompressor", "map_modules_to_quant_args"] +__all__ = ["ModelCompressor", "map_module_to_scheme"] _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -372,20 +372,17 @@ def compress( :param state_dict: optional uncompressed state_dict to insert into model :return: compressed state dict """ + if state_dict is None: state_dict = model.state_dict() - compressed_state_dict = state_dict - - quantized_modules_to_args: Dict[ - str, QuantizationArgs - ] = map_modules_to_quant_args(model) - if self.quantization_compressor is not None: - compressed_state_dict = self.quantization_compressor.compress( - state_dict, names_to_scheme=quantized_modules_to_args + module_to_scheme = map_module_to_scheme(model) + state_dict = self.quantization_compressor.compress( + state_dict, names_to_scheme=module_to_scheme ) + # TODO: consider sparse compression to also be compression if self.quantization_config.format != CompressionFormat.dense.value: self.quantization_config.quantization_status = ( QuantizationStatus.COMPRESSED @@ -397,8 +394,8 @@ def compress( targets=self.sparsity_config.targets, ignore=self.sparsity_config.ignore, ) - compressed_state_dict = self.sparsity_compressor.compress( - compressed_state_dict, + state_dict = self.sparsity_compressor.compress( + state_dict, compression_targets=sparse_compression_targets, ) @@ -407,7 +404,7 @@ def compress( # https://github.com/huggingface/transformers/pull/30488 transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size - return compressed_state_dict + return state_dict def decompress(self, model_path: str, model: Module): """ @@ -605,30 +602,15 @@ def _replace_weights(self, dense_weight_generator, model: Module): update_parameter_data(module, param_data, param_name) -def map_modules_to_quant_args( - model: Module, -) -> Dict[str, Union[QuantizationArgs, Tuple[QuantizationArgs, QuantizationArgs]]]: +def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]: """ - Given a pytorch model, map out the submodule name (usually linear layers) - to the weight QuantizationArgs. If running input activation quantization, will also - map to the input QuantizationArgs in a tuple. 
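# Reviewer sketch (illustration only, not part of this diff): shape of the new mapping.
# map_module_to_scheme() returns Dict[str, QuantizationScheme] instead of
# QuantizationArgs (or a (weight_args, input_args) tuple), so callers reach the weight
# args via `.weights`. The module name below is hypothetical.
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, symmetric=True),
)
module_to_scheme = {"model.layers.0.mlp.down_proj": scheme}

# old: value could be args or a (weight, input) tuple; new: always a full scheme
weight_args = module_to_scheme["model.layers.0.mlp.down_proj"].weights
assert weight_args.num_bits == 4 and weight_args.symmetric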
- - :param model: pytorch model + Returns a dictionary which maps quantized module names to their quantization schemes """ - quantized_modules_to_args = {} - for name, submodule in iter_named_leaf_modules(model): - if is_module_quantized(submodule): - if submodule.quantization_scheme.weights is not None: - name = fix_fsdp_module_name(name) - quantized_modules_to_args[name] = submodule.quantization_scheme.weights - if submodule.quantization_scheme.input_activations is not None: - weight_args = quantized_modules_to_args.get(name) - quantized_modules_to_args[name] = ( - weight_args, - submodule.quantization_scheme.input_activations, - ) - - return quantized_modules_to_args + return { + fix_fsdp_module_name(name): module.quantization_scheme + for name, module in iter_named_leaf_modules(model) + if is_module_quantized(module) + } # HACK: Override the dtype_byte_size function in transformers to support float8 types diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 098328be..832cc4c0 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -14,15 +14,16 @@ import logging from pathlib import Path -from typing import Any, Dict, Generator, Optional, Tuple, Union +from typing import Any, Dict, Generator, Tuple, Union import torch from compressed_tensors.compressors.base import BaseCompressor -from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy +from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy from compressed_tensors.utils import ( get_nested_mappings_from_state_dict, get_nested_weight_mappings, merge_names, + remove_suffix, ) from safetensors import safe_open from torch import Tensor @@ -69,7 +70,7 @@ class BaseQuantizationCompressor(BaseCompressor): def compress( self, model_state: Dict[str, Tensor], - names_to_scheme: Dict[str, QuantizationArgs], + names_to_scheme: Dict[str, QuantizationScheme], **kwargs, ) -> Dict[str, Tensor]: """ @@ -81,87 +82,87 @@ def compress( :return: compressed state dict """ compressed_dict = {} - weight_suffix = ".weight" - input_zp_suffix = ".input_zero_point" - weight_zp_suffix = ".weight_zero_point" - _LOGGER.debug( - f"Compressing model with {len(model_state)} parameterized layers..." 
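# Reviewer sketch (assumptions flagged inline, not part of this diff): a toy stand-in
# for the map_module_to_scheme() comprehension above. The real helper walks
# iter_named_leaf_modules() and filters with is_module_quantized(); here a
# quantization_scheme attribute is attached by hand just to show the resulting dict.
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
model[0].quantization_scheme = QuantizationScheme(
    targets=["Linear"], weights=QuantizationArgs(num_bits=8)
)

module_to_scheme = {
    name: module.quantization_scheme
    for name, module in model.named_modules()
    if hasattr(module, "quantization_scheme")
}
assert list(module_to_scheme) == ["0"]  # only the quantized Linear is mapped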
- ) + save_device = "cpu" + + uncompressed_names = list(model_state.keys()) + for name in tqdm(uncompressed_names, desc="Compressing with quantization"): + value = model_state[name] + + # compress weights + if name.endswith("weight"): + prefix = remove_suffix(name, "weight") + + # gather qparams + scale = model_state.get(prefix + "weight_scale", None) + g_idx = model_state.get(prefix + "weight_g_idx", None) + zp = model_state.get(prefix + "weight_zero_point", None) + + # is scale does not exist, then weight cannot be compressed + if scale is None: + compressed_dict[name] = value.to(save_device) + continue + + # compress values on cpu (memory movement too expensive) + module_path = prefix[:-1] if prefix.endswith(".") else prefix + quant_args = names_to_scheme[module_path].weights + compressed_values = self.compress_weight( + weight=value, + scale=scale, + zero_point=zp, + g_idx=g_idx, + quantization_args=quant_args, + device="cpu", + ) + + # update state dict + for key, value in compressed_values.items(): + compressed_dict[prefix + key] = value.to(save_device) - for name, value in tqdm(model_state.items(), desc="Quantized Compression"): - # check if the parameter we're compressing is the weight zp - # or the input zp - is_weight_zp = name.endswith(weight_zp_suffix) - is_input_zp = name.endswith(input_zp_suffix) - - # if we're saving the weight zp, fetch weight quant args - if is_weight_zp: - quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))]) - if isinstance(quant_args_zp, tuple): - # If tuple, first value is weight args, second is input args - quant_args_zp = quant_args_zp[0] - - # if we're saving the input zp, fetch input quant args - if is_input_zp: - input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))]) - if isinstance(input_args_zp, tuple): - # If tuple, first value is weight args, second is input args - input_args_zp = input_args_zp[-1] - - if name.endswith(weight_suffix): - prefix = name[: -(len(weight_suffix))] - scale = model_state.get(merge_names(prefix, "weight_scale"), None) - zp = model_state.get(merge_names(prefix, "weight_zero_point"), None) - g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None) - if scale is not None: - # weight is quantized, compress it - if isinstance(names_to_scheme[prefix], tuple): - quant_args = names_to_scheme[prefix][0] - else: - quant_args = names_to_scheme[prefix] - - compressed_data = self.compress_weight( - weight=value, - scale=scale, - zero_point=zp, - g_idx=g_idx, - quantization_args=quant_args, - device="cpu", - ) - for key, value in compressed_data.items(): - compressed_dict[merge_names(prefix, key)] = value - else: - compressed_dict[name] = value.to("cpu") - # only save zp if asym and not packed zp - elif is_weight_zp and ( - quant_args_zp.symmetric or self._check_if_zp_pack_quantized(quant_args) - ): - continue - # only save if asym - elif is_input_zp and input_args_zp.symmetric: - continue - elif name.endswith("g_idx") and torch.any(value <= -1): - continue else: - compressed_dict[name] = value.to("cpu") + # omit saving zero points for symmetric or packed quantization + if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme): + continue + + # omit saving for g_idx if uninitialized + # TODO: does this case actually occur? 
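# Reviewer sketch (made-up tensors and names, not part of this diff): how the new loop
# derives the module path and gathers q-params from a flat state dict using the
# remove_suffix() helper added in utils/helpers.py.
import torch
from compressed_tensors.utils import remove_suffix

state = {
    "decoder.fc1.weight": torch.randn(16, 16),
    "decoder.fc1.weight_scale": torch.tensor(0.02),
    "decoder.fc1.bias": torch.randn(16),
}

name = "decoder.fc1.weight"
prefix = remove_suffix(name, "weight")                          # "decoder.fc1."
module_path = prefix[:-1] if prefix.endswith(".") else prefix   # "decoder.fc1"
scale = state.get(prefix + "weight_scale")   # present -> weight is compressible
zp = state.get(prefix + "weight_zero_point") # absent  -> symmetric, nothing stored
assert module_path == "decoder.fc1" and scale is not None and zp is None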
+ elif name.endswith("g_idx") and torch.any(value <= -1): + continue + + compressed_dict[name] = value.to(save_device) return compressed_dict - def _check_if_zp_pack_quantized(self, quant_args): + def _skip_zp( + self, name: str, names_to_scheme: Dict[str, QuantizationScheme] + ) -> bool: from compressed_tensors.compressors import PackedQuantizationCompressor - if isinstance(self, PackedQuantizationCompressor): - if not quant_args.symmetric and quant_args.strategy in [ - QuantizationStrategy.GROUP.value, - QuantizationStrategy.CHANNEL.value, - ]: - return True - return False + module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name) + scheme = names_to_scheme[module_name] + + if zp_name == "weight_zero_point": + args = scheme.weights + if zp_name == "input_zero_point": + args = scheme.input_activations + if zp_name == "output_zero_point": + args = scheme.output_activations + + symmetric = args.symmetric + packable_strategies = [ + QuantizationStrategy.GROUP.value, + QuantizationStrategy.CHANNEL.value, + ] + packed = ( + isinstance(self, PackedQuantizationCompressor) + and args.strategy in packable_strategies + ) + + return symmetric or packed def decompress( self, path_to_model_or_tensors: Union[str, Path, Dict[str, Any]], - names_to_scheme: Dict[str, QuantizationArgs], + names_to_scheme: Dict[str, QuantizationScheme], device: str = "cpu", ) -> Generator[Tuple[str, Tensor], None, None]: """ @@ -170,8 +171,9 @@ def decompress( dense state dict :param path_to_model_or_tensors: path to compressed safetensors model (directory with one or more safetensors files) or compressed tensors file - :param names_to_scheme: quantization args for each quantized weight - :param device: optional device to load intermediate weights into + :param names_to_scheme: quantization scheme for each quantized weight + :param device: optional device to load intermediate weights into (must be `str`, + not `torch.device`) :return: compressed state dict """ if isinstance(path_to_model_or_tensors, (str, Path)): @@ -184,7 +186,12 @@ def decompress( path_to_model_or_tensors, names_to_scheme ) - def _decompress_from_path(self, path_to_model, names_to_scheme, device): + def _decompress_from_path( + self, + path_to_model: Union[str, Path, Dict[str, Any]], + names_to_scheme: Dict[str, QuantizationScheme], + device: str, + ): weight_mappings = get_nested_weight_mappings( path_to_model, self.compression_param_names ) @@ -195,7 +202,7 @@ def _decompress_from_path(self, path_to_model, names_to_scheme, device): with safe_open(safe_path, framework="pt", device=device) as f: weight_data[param_name] = f.get_tensor(full_name) if "weight_scale" in weight_data: - quant_args = names_to_scheme[weight_name] + quant_args = names_to_scheme[weight_name].weights decompressed = self.decompress_weight( compressed_data=weight_data, quantization_args=quant_args ) diff --git a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py index 24f9cbf0..a8487cb7 100644 --- a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +++ b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py @@ -19,7 +19,11 @@ import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy +from compressed_tensors.quantization import ( + QuantizationArgs, + 
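# Reviewer sketch (not part of this diff): the zero-point skip rule from _skip_zp()
# above, restated as a free function. A zero point is dropped when the scheme is
# symmetric, or when the PackedQuantizationCompressor will pack it (asymmetric
# group/channel strategies).
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

def would_skip_zp(args: QuantizationArgs, is_packed_compressor: bool) -> bool:
    packable = (QuantizationStrategy.GROUP.value, QuantizationStrategy.CHANNEL.value)
    return args.symmetric or (is_packed_compressor and args.strategy in packable)

sym_args = QuantizationArgs(num_bits=4, symmetric=True)
asym_group = QuantizationArgs(num_bits=4, symmetric=False, strategy="group", group_size=128)
assert would_skip_zp(sym_args, is_packed_compressor=False)
assert would_skip_zp(asym_group, is_packed_compressor=True)
assert not would_skip_zp(asym_group, is_packed_compressor=False)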
QuantizationScheme, + QuantizationStrategy, +) from compressed_tensors.quantization.lifecycle.forward import quantize from compressed_tensors.utils import ( get_permutations_24, @@ -44,19 +48,25 @@ class Marlin24Compressor(BaseCompressor): @staticmethod def validate_quant_compatability( - model_quant_args: Dict[str, QuantizationArgs] + names_to_scheme: Dict[str, QuantizationScheme] ) -> bool: """ Checks if every quantized module in the model is compatible with Marlin24 compression. Quantization must be channel or group strategy with group_size of 128. Only symmetric quantization is supported - :param model_quant_args: dictionary of mapping module names to their - quantization configuration + :param names_to_scheme: dictionary of mapping module names to their + quantization schemes :return: True if all modules are compatible with Marlin24 compression, raises a ValueError otherwise """ - for name, quant_args in model_quant_args.items(): + for name, scheme in names_to_scheme.items(): + quant_args = scheme.weights + if quant_args is None: + raise ValueError( + "Marlin24 Compressor is only valid for weight quantization schemes" + ) + strategy = quant_args.strategy group_size = quant_args.group_size symmetric = quant_args.symmetric @@ -114,7 +124,7 @@ def compression_param_names(self) -> Tuple[str]: def compress( self, model_state: Dict[str, Tensor], - names_to_scheme: Dict[str, QuantizationArgs], + names_to_scheme: Dict[str, QuantizationScheme], **kwargs, ) -> Dict[str, Tensor]: """ @@ -122,8 +132,8 @@ def compress( with the Marlin24 kernel :param model_state: state dict of uncompressed model - :param names_to_scheme: quantization args for each quantized weight, needed for - quantize function to calculate bit depth + :param names_to_scheme: quantization scheme for each quantized weight, needed + for quantize function to calculate bit depth :return: compressed state dict """ self.validate_quant_compatability(names_to_scheme) @@ -146,7 +156,7 @@ def compress( value = value.to(torch.float16) # quantize weight, keeping it as a float16 for now - quant_args = names_to_scheme[prefix] + quant_args = names_to_scheme[prefix].weights value = quantize( x=value, scale=scale, zero_point=zp, args=quant_args ) @@ -215,7 +225,7 @@ def pack_weight_24( weight: Tensor, quantization_args: QuantizationArgs, tile: int = 16, -): +) -> torch.Tensor: size_k = weight.shape[0] size_n = weight.shape[1] num_bits = quantization_args.num_bits @@ -236,7 +246,9 @@ def pack_weight_24( return q_packed -def pack_scales_24(scales, quantization_args, w_shape): +def pack_scales_24( + scales: torch.Tensor, quantization_args: QuantizationArgs, w_shape: torch.Size +) -> torch.Tensor: size_k = w_shape[0] size_n = w_shape[1] num_bits = quantization_args.num_bits diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index a842d00e..af0b9159 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -38,6 +38,7 @@ "shard_tensor", "pack_bitmasks", "unpack_bitmasks", + "remove_suffix", ] FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" @@ -328,3 +329,9 @@ def unpack_bitmasks( ) return unpacked_bitmasks_torch + + +def remove_suffix(value: str, suffix: str) -> str: + # can replace with str.removesuffix in python3.9+ + assert value.endswith(suffix) + return value[: -len(suffix)] diff --git a/tests/test_compressors/quantized_compressors/test_fp8_quant.py b/tests/test_compressors/quantized_compressors/test_fp8_quant.py index f5a85e57..e8bbc0d2 100644 --- 
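# Reviewer sketch (not part of this diff): semantics of the remove_suffix() helper
# introduced in utils/helpers.py. On Python 3.9+ str.removesuffix() could replace it,
# with the difference that removesuffix() silently returns the input when the suffix
# is absent, whereas remove_suffix() asserts the suffix is present.
from compressed_tensors.utils import remove_suffix

assert remove_suffix("model.layers.0.weight", "weight") == "model.layers.0."
assert "model.layers.0.weight".removesuffix("weight") == "model.layers.0."  # 3.9+
# remove_suffix("model.layers.0.bias", "weight") would raise AssertionError instead.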
a/tests/test_compressors/quantized_compressors/test_fp8_quant.py +++ b/tests/test_compressors/quantized_compressors/test_fp8_quant.py @@ -84,9 +84,9 @@ def test_quant_format(strategy, group_size, sc, zp): quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size) compressor = FloatQuantizationCompressor(config=quant_config) - quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights} + module_name_to_scheme = {"dummy": quant_config.config_groups["group_1"]} compressed_state_dict = compressor.compress( - dense_state_dict, names_to_scheme=quantized_modules_to_args + dense_state_dict, names_to_scheme=module_name_to_scheme ) # state_dict params should be the same, minus the zero_point if symmetric @@ -140,15 +140,15 @@ def test_reload_match( ) compressor = FloatQuantizationCompressor(config=quant_config) - quantized_modules_to_args = { - "dummy": quant_config.config_groups["group_1"].weights, + module_name_to_scheme = { + "dummy": quant_config.config_groups["group_1"], } compressed_state_dict = compressor.compress( - model.state_dict(), names_to_scheme=quantized_modules_to_args + model.state_dict(), names_to_scheme=module_name_to_scheme ) save_file(compressed_state_dict, tmp_path / "model.safetensors") reconstructed_dense_gen = compressor.decompress( - tmp_path, names_to_scheme=quantized_modules_to_args + tmp_path, names_to_scheme=module_name_to_scheme ) reconstructed_dense = {} for name, value in reconstructed_dense_gen: @@ -158,7 +158,7 @@ def test_reload_match( model.dummy.weight, scale=model.dummy.weight_scale, zero_point=model.dummy.weight_zero_point, - args=quantized_modules_to_args["dummy"], + args=module_name_to_scheme["dummy"].weights, ) assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy"].get("weight")) diff --git a/tests/test_compressors/quantized_compressors/test_int_quant.py b/tests/test_compressors/quantized_compressors/test_int_quant.py index ebbdb9cf..991444cc 100644 --- a/tests/test_compressors/quantized_compressors/test_int_quant.py +++ b/tests/test_compressors/quantized_compressors/test_int_quant.py @@ -76,9 +76,9 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp): ) compressor = IntQuantizationCompressor(config=quant_config) - quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights} + quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]} compressed_state_dict = compressor.compress( - dense_state_dict, names_to_scheme=quantized_modules_to_args + dense_state_dict, names_to_scheme=quantized_modules_to_scheme ) # state_dict params should be the same, minus the zero_point if symmetric @@ -124,16 +124,16 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path): quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size) compressor = IntQuantizationCompressor(config=quant_config) - quantized_modules_to_args = { - "dummy": quant_config.config_groups["group_1"].weights, - "dummy2": quant_config.config_groups["group_1"].weights, + module_name_to_scheme = { + "dummy": quant_config.config_groups["group_1"], + "dummy2": quant_config.config_groups["group_1"], } compressed_state_dict = compressor.compress( - dense_state_dict, names_to_scheme=quantized_modules_to_args + dense_state_dict, names_to_scheme=module_name_to_scheme ) save_file(compressed_state_dict, tmp_path / "model.safetensors") reconstructed_dense_gen = compressor.decompress( - tmp_path, names_to_scheme=quantized_modules_to_args + tmp_path, 
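# Reviewer sketch of the round-trip pattern the updated tests exercise: compress a
# dense state dict with a names-to-scheme mapping, save via safetensors, then consume
# the decompress generator. `compressor`, `dense_state_dict`, and `names_to_scheme`
# are assumed inputs; only calls already shown in this diff are used.
from pathlib import Path
from typing import Dict

from safetensors.torch import save_file
from torch import Tensor


def roundtrip(compressor, dense_state_dict: Dict[str, Tensor], names_to_scheme, tmp_path: Path) -> Dict:
    compressed = compressor.compress(dense_state_dict.copy(), names_to_scheme=names_to_scheme)
    save_file(compressed, tmp_path / "model.safetensors")

    reconstructed = {}
    for module_name, params in compressor.decompress(tmp_path, names_to_scheme=names_to_scheme):
        reconstructed[module_name] = params
    return reconstructed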
names_to_scheme=module_name_to_scheme ) reconstructed_dense = {} for name, value in reconstructed_dense_gen: @@ -143,7 +143,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path): dense_state_dict["dummy.weight"], scale=dense_state_dict["dummy.weight_scale"], zero_point=dense_state_dict["dummy.weight_zero_point"], - args=quantized_modules_to_args["dummy"], + args=module_name_to_scheme["dummy"].weights, ) assert torch.equal( fake_quant_dummy, reconstructed_dense["dummy"].get("weight").to(torch.float32) @@ -153,7 +153,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path): dense_state_dict["dummy2.weight"], scale=dense_state_dict["dummy2.weight_scale"], zero_point=dense_state_dict["dummy2.weight_zero_point"], - args=quantized_modules_to_args["dummy2"], + args=module_name_to_scheme["dummy2"].weights, ) assert torch.equal( fake_quant_dummy2, reconstructed_dense["dummy2"].get("weight").to(torch.float32) diff --git a/tests/test_compressors/quantized_compressors/test_pack_quant.py b/tests/test_compressors/quantized_compressors/test_pack_quant.py index dc8b69bd..00d61275 100644 --- a/tests/test_compressors/quantized_compressors/test_pack_quant.py +++ b/tests/test_compressors/quantized_compressors/test_pack_quant.py @@ -40,7 +40,7 @@ def get_dummy_quant_config( num_bits=4, strategy=None, group_size=None, actorder=None, symmetric=True -): +) -> QuantizationConfig: config_groups = { "group_1": QuantizationScheme( targets=["Linear"], @@ -82,9 +82,9 @@ def test_quant_format(shape): quant_config = get_dummy_quant_config() compressor = PackedQuantizationCompressor(config=quant_config) - quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights} + quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]} compressed_state_dict = compressor.compress( - dense_state_dict, names_to_scheme=quantized_modules_to_args + dense_state_dict, names_to_scheme=quantized_modules_to_scheme ) # compressed state_dict adds one entry for shape @@ -157,25 +157,21 @@ def test_reload_match(tmp_path, num_bits): # pack-compressor only needs the number of bits from the quant-args to decompress # all other information is extracted from the compressed data directly - names_to_scheme = { - "dummy": QuantizationArgs(num_bits=num_bits), - "dummy2": QuantizationArgs(num_bits=num_bits), - } quant_config = get_dummy_quant_config(num_bits, symmetric=False) compressor = PackedQuantizationCompressor(config=quant_config) - quantized_modules_to_args = { - "dummy": quant_config.config_groups["group_1"].weights, - "dummy2": quant_config.config_groups["group_1"].weights, + quantized_modules_to_scheme = { + "dummy": quant_config.config_groups["group_1"], + "dummy2": quant_config.config_groups["group_1"], } compressed_state_dict = compressor.compress( - dense_state_dict, names_to_scheme=quantized_modules_to_args + dense_state_dict.copy(), names_to_scheme=quantized_modules_to_scheme ) save_file(compressed_state_dict, tmp_path / "model.safetensors") reconstructed_dense_gen = compressor.decompress( - tmp_path, names_to_scheme=names_to_scheme + tmp_path, names_to_scheme=quantized_modules_to_scheme ) reconstructed_dense = {} for name, value in reconstructed_dense_gen: @@ -185,7 +181,7 @@ def test_reload_match(tmp_path, num_bits): dense_state_dict["dummy.weight"], scale=dense_state_dict["dummy.weight_scale"], zero_point=dense_state_dict["dummy.weight_zero_point"], - args=quantized_modules_to_args["dummy"], + args=quantized_modules_to_scheme["dummy"].weights, ) assert torch.equal( 
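# Reviewer note as code (not part of this diff): the packed test previously built a
# separate {"dummy": QuantizationArgs(num_bits=num_bits)} mapping just for decompress
# (per the old comment, only the bit width is needed there). Reusing the scheme mapping
# from compress covers that, since num_bits is reachable through scheme.weights.
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

scheme = QuantizationScheme(targets=["Linear"], weights=QuantizationArgs(num_bits=4))
assert scheme.weights.num_bits == 4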
fake_quant_dummy, reconstructed_dense["dummy"].get("weight").to(torch.float32) @@ -195,7 +191,7 @@ def test_reload_match(tmp_path, num_bits): dense_state_dict["dummy2.weight"], scale=dense_state_dict["dummy2.weight_scale"], zero_point=dense_state_dict["dummy2.weight_zero_point"], - args=quantized_modules_to_args["dummy2"], + args=quantized_modules_to_scheme["dummy2"].weights, ) assert torch.equal( fake_quant_dummy2, reconstructed_dense["dummy2"].get("weight").to(torch.float32) @@ -232,9 +228,9 @@ def test_asymmetric_packed_support(strategy): ) compressor = PackedQuantizationCompressor(config=quant_config) - quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights} + quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]} compressed_state_dict = compressor.compress( - dense_state_dict, names_to_scheme=quantized_modules_to_args + dense_state_dict, names_to_scheme=quantized_modules_to_scheme ) # compressed state_dict adds one entry for shape @@ -289,17 +285,17 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration): # compress compressor = PackedQuantizationCompressor(config=quant_config) - quantized_modules_to_args = { - "dummy": quant_config.config_groups["group_1"].weights, + quantized_modules_to_scheme = { + "dummy": quant_config.config_groups["group_1"], } compressed_state_dict = compressor.compress( - model.state_dict(), names_to_scheme=quantized_modules_to_args + model.state_dict(), names_to_scheme=quantized_modules_to_scheme ) save_file(compressed_state_dict, tmp_path / "model.safetensors") # decompress reconstructed_dense_gen = compressor.decompress( - tmp_path, names_to_scheme=quantized_modules_to_args + tmp_path, names_to_scheme=quantized_modules_to_scheme ) reconstructed_dense = {} for name, value in reconstructed_dense_gen: @@ -310,7 +306,7 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration): scale=model.dummy.weight_scale, zero_point=model.dummy.weight_zero_point, g_idx=getattr(model.dummy, "weight_g_idx", None), - args=quantized_modules_to_args["dummy"], + args=quantized_modules_to_scheme["dummy"].weights, ) assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy"].get("weight")) diff --git a/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py b/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py index ddc51110..598a69c5 100644 --- a/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +++ b/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py @@ -19,7 +19,7 @@ from compressed_tensors.compressors import ( BaseCompressor, Marlin24Compressor, - map_modules_to_quant_args, + map_module_to_scheme, ) from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization import ( @@ -92,9 +92,9 @@ def test_marlin24_format( assert f"{NOT_QUANT_NAME}.weight_scale" not in state_dict assert f"{QUANT_NAME}.weight_scale" in state_dict - model_to_quant_args = map_modules_to_quant_args(model) + module_to_scheme = map_module_to_scheme(model) compressor = Marlin24Compressor() - compressor.validate_quant_compatability(model_to_quant_args) + compressor.validate_quant_compatability(module_to_scheme) compressor.validate_sparsity_structure( QUANT_NAME, state_dict[f"{QUANT_NAME}.weight"] ) @@ -104,7 +104,7 @@ def test_marlin24_format( ) compressor = Marlin24Compressor() - compressed_state_dict = compressor.compress(state_dict, model_to_quant_args) + compressed_state_dict = 
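# Reviewer sketch (hypothetical scheme, not part of this diff): with schemes as input,
# Marlin24Compressor.validate_quant_compatability() now fails fast on modules that have
# no weight quantization, e.g. an activation-only scheme.
from compressed_tensors.compressors import Marlin24Compressor
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

act_only = QuantizationScheme(
    targets=["Linear"], input_activations=QuantizationArgs(num_bits=8)
)
try:
    Marlin24Compressor.validate_quant_compatability({"dummy": act_only})
except ValueError as err:
    print(f"rejected as expected: {err}")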
compressor.compress(state_dict, module_to_scheme) assert len(compressed_state_dict) == 4 assert torch.equal(