From 669681043d38ad1e1d07ea14fd1d85d4afa86a67 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Wed, 2 Jul 2025 09:56:29 -0400 Subject: [PATCH 01/12] added support for compression on meta device Signed-off-by: shanjiaz --- .../model_compressors/model_compressor.py | 21 ++++++---- .../compressors/quantized_compressors/base.py | 7 ++-- .../quantized_compressors/nvfp4_quantized.py | 1 + .../quantized_compressors/pack_quantized.py | 42 +++++++++++-------- .../compressors/sparse_compressors/base.py | 8 +++- .../sparse_compressors/sparse_24_bitmask.py | 34 ++++++++++++--- .../sparse_compressors/sparse_bitmask.py | 2 +- 7 files changed, 78 insertions(+), 37 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 700c1769..ad9b66cc 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -378,7 +378,7 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]: # ----- model memory compression/decompression pathways ----- # - def compress_model(self, model: Module): + def compress_model(self, model: Module, is_meta: bool = False): """ Compress a model in memory. Because the model structure is modified in place, this method is more memory-efficient than `self.compress` @@ -394,8 +394,11 @@ def compress_model(self, model: Module): for prefix, module in tqdm(model.named_modules(), desc="Compressing model"): if prefix in module_to_scheme or prefix in sparse_compression_targets: + exec_device = "meta" if is_meta else "cpu" + onloading_device = "meta" if is_meta else get_execution_device(module) + # in the future, support compression on same device - with align_module_device(module, execution_device="cpu"): + with align_module_device(module, execution_device=exec_device): state_dict = module.state_dict(prefix=f"{prefix}.") # quantization first @@ -404,6 +407,7 @@ def compress_model(self, model: Module): state_dict, names_to_scheme=module_to_scheme, show_progress=False, + save_device=exec_device, ) # sparsity second @@ -412,10 +416,10 @@ def compress_model(self, model: Module): state_dict, compression_targets=sparse_compression_targets, show_progress=False, + module=module, ) # remove any existing parameters - exec_device = get_execution_device(module) offload_device = get_offloaded_device(module) for name, _ in list(module.named_parameters()): delete_offload_parameter(module, name) @@ -423,7 +427,7 @@ def compress_model(self, model: Module): # replace with compressed parameters for name, value in state_dict.items(): name = name.removeprefix(f"{prefix}.") - value = value.to(exec_device) + value = value.to(onloading_device) param = torch.nn.Parameter(value, requires_grad=False) register_offload_parameter(module, name, param, offload_device) @@ -485,7 +489,9 @@ def decompress_model(self, model: Module): # replace with decompressed parameters for name, value in state_dict.items(): name = name.removeprefix(f"{prefix}.") - value = value.to(exec_device) + # skipping save if we're just registering the model on meta device + if exec_device != "meta": + value = value.to(exec_device) param = torch.nn.Parameter(value, requires_grad=False) register_offload_parameter(module, name, param, offload_device) @@ -693,7 +699,9 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module): params_device = next(module.parameters()).device device = "cpu" if 
has_offloaded_params(module) else params_device delattr(module, param_name) - requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16) + #requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16) + requires_grad = torch.is_floating_point(data) + param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad) register_offload_parameter(module, param_name, param) @@ -711,7 +719,6 @@ def _replace_weights(self, dense_weight_generator, model: Module): 'data' is the updated param data :param model: The model whose weights are to be updated. """ - for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"): module = operator.attrgetter(mod_path)(model) diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index d0a07302..b6c0fbcf 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -72,6 +72,7 @@ def compress( model_state: Dict[str, Tensor], names_to_scheme: Dict[str, QuantizationScheme], show_progress: bool = False, + save_device: str = "cpu", **kwargs, ) -> Dict[str, Tensor]: """ @@ -85,7 +86,6 @@ def compress( """ uncompressed_names = list(model_state.keys()) compressed_dict = {} - save_device = "cpu" # compress values desc = "Compressing with quantization" @@ -107,7 +107,7 @@ def compress( compressed_dict[name] = value.to(save_device) continue - # compress values on cpu (memory movement too expensive) + # compress values on meta if loading from meta otherwise on cpu (memory movement too expensive) module_path = prefix[:-1] if prefix.endswith(".") else prefix quant_args = names_to_scheme[module_path].weights compressed_values = self.compress_weight( @@ -117,7 +117,7 @@ def compress( global_scale=global_scale, g_idx=g_idx, quantization_args=quant_args, - device="cpu", + device=save_device, ) # update state dict @@ -133,7 +133,6 @@ def compress( # TODO: does this case actually occur? 
elif name.endswith("g_idx") and torch.any(value <= -1): continue - compressed_dict[name] = value.to(save_device) return compressed_dict diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py index 5f348e91..41131589 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py @@ -82,6 +82,7 @@ def compress_weight( compressed_dict = {} weight_packed = pack_fp4_to_uint8(quantized_weight) if device is not None: + breakpoint() weight_packed = weight_packed.to(device) compressed_dict["weight_packed"] = weight_packed return compressed_dict diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index 902ac26c..4242207e 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -220,30 +220,36 @@ def pack_to_int32( if num_bits < 1: raise ValueError(f"num_bits must be at least 1, got {num_bits}") - # convert to unsigned for packing + # Convert to unsigned range for packing, matching quantization offset offset = 1 << (num_bits - 1) value = (value + offset).to(torch.uint8) - value = value.cpu().numpy().astype(np.uint32) + device = value.device + pack_factor = 32 // num_bits - # pad input tensor and initialize packed output - packed_size = math.ceil(value.shape[packed_dim] / pack_factor) - padding = packed_size * pack_factor - value.shape[packed_dim] - value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0) + if packed_dim == 0: + value = value.transpose(0, 1) - # pack values - if packed_dim == 1: - packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32) - for i in range(pack_factor): - packed |= value[:, i::pack_factor] << num_bits * i - else: - packed = np.zeros((packed_size, value.shape[1]), dtype=np.uint32) - for i in range(pack_factor): - packed |= value[i::pack_factor, :] << num_bits * i + rows, cols = value.shape + padded_cols = math.ceil(cols / pack_factor) * pack_factor + pad_len = padded_cols - cols + + if pad_len > 0: + value = torch.nn.functional.pad(value, (0, pad_len)) + + num_groups = padded_cols // pack_factor + + # Use int32 here + reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32) + + bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits + + packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32) + + if packed_dim == 0: + packed = packed.transpose(0, 1) - # convert back to signed and torch - packed = np.ascontiguousarray(packed).view(np.int32) - return torch.from_numpy(packed) + return packed def unpack_from_int32( diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index e29b8284..a54f8df1 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -25,6 +25,8 @@ from torch import Tensor from tqdm import tqdm +from torch.nn import Module + __all__ = ["BaseSparseCompressor"] @@ -68,6 +70,7 @@ def compress( model_state: Dict[str, Tensor], compression_targets: Optional[Set[str]] = None, show_progress: bool = False, + module: Optional[Module] = None, ) -> Dict[str, Tensor]: """ Compresses a dense 
state dict using bitmask compression @@ -93,7 +96,10 @@ def compress( if prefix.endswith(".weight"): prefix = prefix[: -(len(".weight"))] - compression_data = self.compress_weight(prefix, value) + compression_data = self.compress_weight( + prefix, value, module=module + ) + for key in compression_data.keys(): if key in compressed_dict: _LOGGER.warn( diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py index 7a97faa3..ad12835f 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -52,12 +52,22 @@ def compression_param_names(self) -> Tuple[str]: "bitmask", ) - def compress_weight(self, name, value): + def compress_weight(self, name, value, *, module=None): bitmask_tensor = Sparse24BitMaskTensor.from_dense( value, self.config.sparsity_structure ) - bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu") - return bitmask_dict + if value.device.type == "meta": + if module is None: + raise ValueError("compress_weight requires module argument when is_meta=True") + # Create empty parameter matching compressed shape + empty_weight = torch.empty_like(bitmask_tensor.compressed, device="meta") + module.weight = torch.nn.Parameter(empty_weight, requires_grad=False) + + # Normal flow: return compression dict + return bitmask_tensor.dict( + name_prefix=name, + device="meta" if value.device.type == "meta" else "cpu", + ) def decompress_weight(self, weight_data): data = Sparse24BitMaskTensor.from_compressed_data(**weight_data) @@ -90,9 +100,14 @@ def from_dense( :return: instantiated compressed tensor """ shape = list(tensor.shape) - compressed, bitmask = sparse24_bitmask_compress( - tensor.cpu(), sparsity_structure=sparsity_structure - ) + if tensor.device.type == "meta": + compressed, bitmask = sparse24_bitmask_compress( + tensor, sparsity_structure=sparsity_structure + ) + else: + compressed, bitmask = sparse24_bitmask_compress( + tensor.cpu(), sparsity_structure=sparsity_structure + ) return Sparse24BitMaskTensor( shape=shape, compressed=compressed, @@ -169,6 +184,13 @@ def sparse24_bitmask_compress( SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR ), "Only 2:4 sparsity is supported" + if tensor.device.type == "meta": + num_rows, num_cols = tensor.shape + compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta") + packed_cols = (num_cols + 7) // 8 + bitmasks_packed = torch.empty((num_rows, packed_cols), dtype=torch.uint8, device="meta") + return compressed_values, bitmasks_packed + bytemasks = get_24_bytemasks(tensor=tensor) if tensor.dtype == FP8_DTYPE: diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py index 0e08be03..332206d2 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py @@ -46,7 +46,7 @@ def compression_param_names(self) -> Tuple[str]: """ return ("shape", "compressed", "bitmask", "row_offsets") - def compress_weight(self, name, value): + def compress_weight(self, name, value, **kwargs): bitmask_tensor = BitmaskTensor.from_dense(value) bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu") return bitmask_dict From f36f5506434d440928a6b3f098af942ca5423905 Mon Sep 17 00:00:00 
2001 From: shanjiaz Date: Wed, 2 Jul 2025 09:58:45 -0400 Subject: [PATCH 02/12] remove breakpoint Signed-off-by: shanjiaz --- .../compressors/quantized_compressors/nvfp4_quantized.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py index 41131589..5f348e91 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py @@ -82,7 +82,6 @@ def compress_weight( compressed_dict = {} weight_packed = pack_fp4_to_uint8(quantized_weight) if device is not None: - breakpoint() weight_packed = weight_packed.to(device) compressed_dict["weight_packed"] = weight_packed return compressed_dict From be74a474c6bd078bad47b046aec2df6bd53ce49c Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Wed, 2 Jul 2025 13:14:09 -0400 Subject: [PATCH 03/12] remove comment Signed-off-by: shanjiaz --- .../compressors/model_compressors/model_compressor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index ad9b66cc..c977edd9 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -699,8 +699,7 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module): params_device = next(module.parameters()).device device = "cpu" if has_offloaded_params(module) else params_device delattr(module, param_name) - #requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16) - requires_grad = torch.is_floating_point(data) + requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16) param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad) register_offload_parameter(module, param_name, param) From 1e71b0a65ed992c5a1de5d678c69a2af012ae90a Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Thu, 3 Jul 2025 09:44:40 -0400 Subject: [PATCH 04/12] address reviewed issues Signed-off-by: shanjiaz --- .../compressors/model_compressors/model_compressor.py | 5 ++--- .../compressors/quantized_compressors/base.py | 10 +++++----- .../quantized_compressors/pack_quantized.py | 2 -- .../compressors/sparse_compressors/base.py | 5 +---- .../sparse_compressors/sparse_24_bitmask.py | 9 +-------- 5 files changed, 9 insertions(+), 22 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index c977edd9..0c5de9a6 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -407,7 +407,7 @@ def compress_model(self, model: Module, is_meta: bool = False): state_dict, names_to_scheme=module_to_scheme, show_progress=False, - save_device=exec_device, + compression_device=exec_device, ) # sparsity second @@ -416,7 +416,6 @@ def compress_model(self, model: Module, is_meta: bool = False): state_dict, compression_targets=sparse_compression_targets, show_progress=False, - module=module, ) # remove any existing parameters @@ -700,7 +699,6 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module): device = "cpu" if has_offloaded_params(module) else 
params_device delattr(module, param_name) requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16) - param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad) register_offload_parameter(module, param_name, param) @@ -718,6 +716,7 @@ def _replace_weights(self, dense_weight_generator, model: Module): 'data' is the updated param data :param model: The model whose weights are to be updated. """ + for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"): module = operator.attrgetter(mod_path)(model) diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index b6c0fbcf..302107d5 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -72,7 +72,7 @@ def compress( model_state: Dict[str, Tensor], names_to_scheme: Dict[str, QuantizationScheme], show_progress: bool = False, - save_device: str = "cpu", + compression_device: str = "cpu", **kwargs, ) -> Dict[str, Tensor]: """ @@ -104,7 +104,7 @@ def compress( # is scale does not exist, then weight cannot be compressed if scale is None: - compressed_dict[name] = value.to(save_device) + compressed_dict[name] = value.to(compression_device) continue # compress values on meta if loading from meta otherwise on cpu (memory movement too expensive) @@ -117,12 +117,12 @@ def compress( global_scale=global_scale, g_idx=g_idx, quantization_args=quant_args, - device=save_device, + device=compression_device, ) # update state dict for key, value in compressed_values.items(): - compressed_dict[prefix + key] = value.to(save_device) + compressed_dict[prefix + key] = value.to(compression_device) else: # omit saving zero points for symmetric or packed quantization @@ -133,7 +133,7 @@ def compress( # TODO: does this case actually occur? 
elif name.endswith("g_idx") and torch.any(value <= -1): continue - compressed_dict[name] = value.to(save_device) + compressed_dict[name] = value.to(compression_device) return compressed_dict diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index 4242207e..d5188d23 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -241,9 +241,7 @@ def pack_to_int32( # Use int32 here reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32) - bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits - packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32) if packed_dim == 0: diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index a54f8df1..34f99a28 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -70,7 +70,6 @@ def compress( model_state: Dict[str, Tensor], compression_targets: Optional[Set[str]] = None, show_progress: bool = False, - module: Optional[Module] = None, ) -> Dict[str, Tensor]: """ Compresses a dense state dict using bitmask compression @@ -96,9 +95,7 @@ def compress( if prefix.endswith(".weight"): prefix = prefix[: -(len(".weight"))] - compression_data = self.compress_weight( - prefix, value, module=module - ) + compression_data = self.compress_weight(prefix, value) for key in compression_data.keys(): if key in compressed_dict: diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py index ad12835f..73c8a7ff 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -52,17 +52,10 @@ def compression_param_names(self) -> Tuple[str]: "bitmask", ) - def compress_weight(self, name, value, *, module=None): + def compress_weight(self, name, value): bitmask_tensor = Sparse24BitMaskTensor.from_dense( value, self.config.sparsity_structure ) - if value.device.type == "meta": - if module is None: - raise ValueError("compress_weight requires module argument when is_meta=True") - # Create empty parameter matching compressed shape - empty_weight = torch.empty_like(bitmask_tensor.compressed, device="meta") - module.weight = torch.nn.Parameter(empty_weight, requires_grad=False) - # Normal flow: return compression dict return bitmask_tensor.dict( name_prefix=name, From fe9c1e6389aa40aec562787c5b1bde6ee50561bd Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Thu, 3 Jul 2025 09:46:41 -0400 Subject: [PATCH 05/12] fix style Signed-off-by: shanjiaz --- src/compressed_tensors/compressors/sparse_compressors/base.py | 2 -- .../compressors/sparse_compressors/sparse_24_bitmask.py | 1 - .../compressors/sparse_compressors/sparse_bitmask.py | 2 +- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index 34f99a28..8d02f459 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -25,7 +25,6 @@ from torch import Tensor from tqdm 
import tqdm -from torch.nn import Module __all__ = ["BaseSparseCompressor"] @@ -96,7 +95,6 @@ def compress( prefix = prefix[: -(len(".weight"))] compression_data = self.compress_weight(prefix, value) - for key in compression_data.keys(): if key in compressed_dict: _LOGGER.warn( diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py index 73c8a7ff..a1faa779 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -56,7 +56,6 @@ def compress_weight(self, name, value): bitmask_tensor = Sparse24BitMaskTensor.from_dense( value, self.config.sparsity_structure ) - # Normal flow: return compression dict return bitmask_tensor.dict( name_prefix=name, device="meta" if value.device.type == "meta" else "cpu", diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py index 332206d2..0e08be03 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py @@ -46,7 +46,7 @@ def compression_param_names(self) -> Tuple[str]: """ return ("shape", "compressed", "bitmask", "row_offsets") - def compress_weight(self, name, value, **kwargs): + def compress_weight(self, name, value): bitmask_tensor = BitmaskTensor.from_dense(value) bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu") return bitmask_dict From 18aebfc395ead61b6b04cc93bb2e1101a67d95fe Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Thu, 3 Jul 2025 09:47:19 -0400 Subject: [PATCH 06/12] new line Signed-off-by: shanjiaz --- src/compressed_tensors/compressors/sparse_compressors/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index 8d02f459..e29b8284 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -26,7 +26,6 @@ from tqdm import tqdm - __all__ = ["BaseSparseCompressor"] _LOGGER: logging.Logger = logging.getLogger(__name__) From e8e7a7dbc38e57a03b8af4e8a729f1ec09911981 Mon Sep 17 00:00:00 2001 From: shanjiaz <43143795+shanjiaz@users.noreply.github.com> Date: Mon, 7 Jul 2025 10:40:43 -0400 Subject: [PATCH 07/12] Update src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py Co-authored-by: Brian Dellabetta --- .../compressors/sparse_compressors/sparse_24_bitmask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py index a1faa779..39f4fe83 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -92,7 +92,7 @@ def from_dense( :return: instantiated compressed tensor """ shape = list(tensor.shape) - if tensor.device.type == "meta": + if tensor.is_meta: compressed, bitmask = sparse24_bitmask_compress( tensor, sparsity_structure=sparsity_structure ) From 4ea589bd7a9737b89728319fdec0ecda530a67dd Mon Sep 17 00:00:00 2001 From: shanjiaz <43143795+shanjiaz@users.noreply.github.com> Date: Mon, 7 Jul 2025 10:40:54 
-0400 Subject: [PATCH 08/12] Update src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py Co-authored-by: Brian Dellabetta --- .../compressors/sparse_compressors/sparse_24_bitmask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py index 39f4fe83..9aa3f682 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -176,7 +176,7 @@ def sparse24_bitmask_compress( SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR ), "Only 2:4 sparsity is supported" - if tensor.device.type == "meta": + if tensor.is_meta: num_rows, num_cols = tensor.shape compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta") packed_cols = (num_cols + 7) // 8 From 4a5f064f64b6b3c2ff08a3fa09c470ea880af1ab Mon Sep 17 00:00:00 2001 From: shanjiaz <43143795+shanjiaz@users.noreply.github.com> Date: Mon, 7 Jul 2025 10:41:02 -0400 Subject: [PATCH 09/12] Update src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py Co-authored-by: Brian Dellabetta --- .../compressors/sparse_compressors/sparse_24_bitmask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py index 9aa3f682..c9663095 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -58,7 +58,7 @@ def compress_weight(self, name, value): ) return bitmask_tensor.dict( name_prefix=name, - device="meta" if value.device.type == "meta" else "cpu", + device="meta" if value.is_meta else "cpu", ) def decompress_weight(self, weight_data): From 53b63b116fd7d516c5dbe10ae5a88212d1d717a2 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Mon, 7 Jul 2025 12:51:02 -0400 Subject: [PATCH 10/12] Added docstring Signed-off-by: shanjiaz --- .../compressors/model_compressors/model_compressor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 0c5de9a6..ec59d3cf 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -384,6 +384,8 @@ def compress_model(self, model: Module, is_meta: bool = False): this method is more memory-efficient than `self.compress` :param model: model containing parameters to compress + :param is_meta: whether the model is on the meta device, in which case + we do not need move parameters to CPU """ module_to_scheme = map_module_to_scheme(model) sparse_compression_targets: Set[str] = expand_target_names( @@ -488,9 +490,7 @@ def decompress_model(self, model: Module): # replace with decompressed parameters for name, value in state_dict.items(): name = name.removeprefix(f"{prefix}.") - # skipping save if we're just registering the model on meta device - if exec_device != "meta": - value = value.to(exec_device) + value = value.to(exec_device) param = torch.nn.Parameter(value, requires_grad=False) register_offload_parameter(module, name, param, offload_device) 
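[Editor's note, not part of the patch series] A minimal usage sketch of the pathway these commits enable, mirroring the new test added in PATCH 12: the model stub, the meta-device `to_empty` loop, and the positional argument order of `ModelCompressor.from_pretrained_model(model, sparsity_config, quantization_format)` are taken from that test; the printed summary at the end is illustrative only.

```python
import torch
from transformers import AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor

# Load an uncompressed, quantized-only checkpoint, then drop the weight
# storage by moving every module onto the meta device (as the new test does).
stub = "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed"
model = AutoModelForCausalLM.from_pretrained(stub, torch_dtype=torch.float32)
for module in model.modules():
    if hasattr(module, "to_empty"):
        module.to_empty(device="meta")

# After PATCH 11, compress_model() detects the meta device per module,
# runs the compressors on meta tensors, and re-registers the compressed
# parameters on meta -- only shapes and dtypes are materialized.
compressor = ModelCompressor.from_pretrained_model(
    model, None, "float-quantized"  # (model, sparsity_config, quantization_format)
)
compressor.compress_model(model)

# The state dict now holds compressed tensors (e.g. weight_packed), still on
# the meta device, ready to be filled in when real weights are loaded.
for name, param in model.state_dict().items():
    print(name, tuple(param.shape), param.dtype, param.device)
```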
From 0916ca5f672d6ebb04cd7a8d3d1079feb87ec34f Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Mon, 7 Jul 2025 14:29:04 -0400 Subject: [PATCH 11/12] removed is_meta input Signed-off-by: shanjiaz --- .../compressors/model_compressors/model_compressor.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index ec59d3cf..d7de413d 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -378,14 +378,12 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]: # ----- model memory compression/decompression pathways ----- # - def compress_model(self, model: Module, is_meta: bool = False): + def compress_model(self, model: Module): """ Compress a model in memory. Because the model structure is modified in place, this method is more memory-efficient than `self.compress` :param model: model containing parameters to compress - :param is_meta: whether the model is on the meta device, in which case - we do not need move parameters to CPU """ module_to_scheme = map_module_to_scheme(model) sparse_compression_targets: Set[str] = expand_target_names( @@ -395,9 +393,13 @@ def compress_model(self, model: Module, is_meta: bool = False): ) for prefix, module in tqdm(model.named_modules(), desc="Compressing model"): + if prefix in module_to_scheme or prefix in sparse_compression_targets: + module_device = get_execution_device(module) + is_meta = (module_device == torch.device("meta")) + exec_device = "meta" if is_meta else "cpu" - onloading_device = "meta" if is_meta else get_execution_device(module) + onloading_device = "meta" if is_meta else module_device # in the future, support compression on same device with align_module_device(module, execution_device=exec_device): From f884e3e9421c3ebb40659312edce9260ca13a100 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Mon, 7 Jul 2025 16:30:36 -0400 Subject: [PATCH 12/12] added test Signed-off-by: shanjiaz --- .../model_compressors/model_compressor.py | 4 +- .../test_model_compressor.py | 60 +++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index d7de413d..bdd01237 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -395,8 +395,8 @@ def compress_model(self, model: Module): for prefix, module in tqdm(model.named_modules(), desc="Compressing model"): if prefix in module_to_scheme or prefix in sparse_compression_targets: - module_device = get_execution_device(module) - is_meta = (module_device == torch.device("meta")) + module_device = get_execution_device(module).type + is_meta = (module_device == "meta") exec_device = "meta" if is_meta else "cpu" onloading_device = "meta" if is_meta else module_device diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index b1d040f4..52d4ad23 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -412,6 +412,66 @@ def test_compress_model(model_stub, q_format, 
s_config, tmpdir): assert torch.all(compressed[key] == true_compressed[key]), f"{key}" +@pytest.mark.parametrize( + "model_stub,q_format,s_config", + [ + ( + "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed", + "float-quantized", + None, + ), + ( + "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed", + None, + "sparse-24-bitmask", + ), + ( + "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed", + "float-quantized", + "sparse-24-bitmask", + ), + ( + "nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed", + "pack-quantized", + None, + ), + ], +) +def test_compress_model_meta(model_stub, q_format, s_config): + # Load model on CPU to get expected compressed state_dict + cpu_model = AutoModelForCausalLM.from_pretrained( + model_stub, torch_dtype=torch.float32 + ) + reference_compressor = ModelCompressor.from_pretrained_model( + cpu_model, s_config, q_format + ) + # Only stores dtype because meta model does not store values + expected = { + k: v.dtype + for k, v in reference_compressor.compress(cpu_model).items() + } + + # Load model on meta device + meta_model = AutoModelForCausalLM.from_pretrained( + model_stub, + torch_dtype=torch.float32, + low_cpu_mem_usage=True, + ) + for module in meta_model.modules(): + if hasattr(module, "to_empty"): + module.to_empty(device="meta") + + # Compress in-place on meta model + compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, q_format) + compressor.compress_model(meta_model) + + # Compare keys and dtypes + compressed = dict(meta_model.state_dict()) + assert set(compressed.keys()) == set(expected.keys()) + for key, dtype in expected.items(): + assert compressed[key].dtype == dtype, f"{key} has incorrect dtype" + + @pytest.mark.parametrize( "model_stub,comp_stub", [