diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 700c1769..bdd01237 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -393,9 +393,16 @@ def compress_model(self, model: Module):
             )
 
         for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
+
             if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                module_device = get_execution_device(module).type
+                is_meta = (module_device == "meta")
+
+                exec_device = "meta" if is_meta else "cpu"
+                onloading_device = "meta" if is_meta else module_device
+
                 # in the future, support compression on same device
-                with align_module_device(module, execution_device="cpu"):
+                with align_module_device(module, execution_device=exec_device):
                     state_dict = module.state_dict(prefix=f"{prefix}.")
 
                 # quantization first
@@ -404,6 +411,7 @@ def compress_model(self, model: Module):
                         state_dict,
                         names_to_scheme=module_to_scheme,
                         show_progress=False,
+                        compression_device=exec_device,
                     )
 
                 # sparsity second
@@ -415,7 +423,6 @@ def compress_model(self, model: Module):
                     )
 
                 # remove any existing parameters
-                exec_device = get_execution_device(module)
                 offload_device = get_offloaded_device(module)
                 for name, _ in list(module.named_parameters()):
                     delete_offload_parameter(module, name)
@@ -423,7 +430,7 @@ def compress_model(self, model: Module):
                 # replace with compressed parameters
                 for name, value in state_dict.items():
                     name = name.removeprefix(f"{prefix}.")
-                    value = value.to(exec_device)
+                    value = value.to(onloading_device)
                     param = torch.nn.Parameter(value, requires_grad=False)
                     register_offload_parameter(module, name, param, offload_device)
 
diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py
index d0a07302..302107d5 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/base.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -72,6 +72,7 @@ def compress(
         model_state: Dict[str, Tensor],
         names_to_scheme: Dict[str, QuantizationScheme],
         show_progress: bool = False,
+        compression_device: str = "cpu",
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -85,7 +86,6 @@ def compress(
         """
         uncompressed_names = list(model_state.keys())
        compressed_dict = {}
-        save_device = "cpu"
 
         # compress values
         desc = "Compressing with quantization"
@@ -104,10 +104,10 @@ def compress(
 
                 # is scale does not exist, then weight cannot be compressed
                 if scale is None:
-                    compressed_dict[name] = value.to(save_device)
+                    compressed_dict[name] = value.to(compression_device)
                     continue
 
-                # compress values on cpu (memory movement too expensive)
+                # compress values on meta if the weight is on meta, otherwise on cpu (memory movement is too expensive)
                 module_path = prefix[:-1] if prefix.endswith(".") else prefix
                 quant_args = names_to_scheme[module_path].weights
                 compressed_values = self.compress_weight(
@@ -117,12 +117,12 @@ def compress(
                     global_scale=global_scale,
                     g_idx=g_idx,
                     quantization_args=quant_args,
-                    device="cpu",
+                    device=compression_device,
                 )
 
                 # update state dict
                 for key, value in compressed_values.items():
-                    compressed_dict[prefix + key] = value.to(save_device)
+                    compressed_dict[prefix + key] = value.to(compression_device)
 
             else:
                 # omit saving zero points for symmetric or packed quantization
@@ -133,8 +133,7 @@ def compress(
 
                 # TODO: does this case actually occur?
                 elif name.endswith("g_idx") and torch.any(value <= -1):
                     continue
-
-                compressed_dict[name] = value.to(save_device)
+                compressed_dict[name] = value.to(compression_device)
 
         return compressed_dict
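A minimal sketch (not part of the patch) of how the device selection above behaves: compression still runs on "cpu" for materialized modules, but stays on "meta" when the module holds only meta tensors, and the resulting exec_device is what compress() now receives as compression_device. The helper name pick_devices and the parameter probe are illustrative stand-ins for get_execution_device.

import torch

def pick_devices(module: torch.nn.Module):
    # mirrors the selection in compress_model
    module_device = next(module.parameters()).device.type
    is_meta = module_device == "meta"
    exec_device = "meta" if is_meta else "cpu"                # passed as compression_device
    onloading_device = "meta" if is_meta else module_device   # device of the re-registered params
    return exec_device, onloading_device

with torch.device("meta"):
    layer = torch.nn.Linear(64, 64)

print(pick_devices(layer))  # ('meta', 'meta')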
diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
index 902ac26c..d5188d23 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
@@ -220,30 +220,34 @@ def pack_to_int32(
     if num_bits < 1:
         raise ValueError(f"num_bits must be at least 1, got {num_bits}")
 
-    # convert to unsigned for packing
+    # Convert to unsigned range for packing, matching quantization offset
     offset = 1 << (num_bits - 1)
     value = (value + offset).to(torch.uint8)
-    value = value.cpu().numpy().astype(np.uint32)
+    device = value.device
+
     pack_factor = 32 // num_bits
 
-    # pad input tensor and initialize packed output
-    packed_size = math.ceil(value.shape[packed_dim] / pack_factor)
-    padding = packed_size * pack_factor - value.shape[packed_dim]
-    value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)
+    if packed_dim == 0:
+        value = value.transpose(0, 1)
 
-    # pack values
-    if packed_dim == 1:
-        packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
-        for i in range(pack_factor):
-            packed |= value[:, i::pack_factor] << num_bits * i
-    else:
-        packed = np.zeros((packed_size, value.shape[1]), dtype=np.uint32)
-        for i in range(pack_factor):
-            packed |= value[i::pack_factor, :] << num_bits * i
+    rows, cols = value.shape
+    padded_cols = math.ceil(cols / pack_factor) * pack_factor
+    pad_len = padded_cols - cols
+
+    if pad_len > 0:
+        value = torch.nn.functional.pad(value, (0, pad_len))
+
+    num_groups = padded_cols // pack_factor
+
+    # Pack each group of pack_factor values into a single int32 via bit shifts
+    reshaped = value.reshape(rows, num_groups, pack_factor).to(torch.int32)
+    bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
+    packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)
+
+    if packed_dim == 0:
+        packed = packed.transpose(0, 1)
 
-    # convert back to signed and torch
-    packed = np.ascontiguousarray(packed).view(np.int32)
-    return torch.from_numpy(packed)
+    return packed
 
 
 def unpack_from_int32(
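A standalone sketch of the shift-and-sum packing introduced above, confirming that it also runs on meta tensors, so shapes and dtypes propagate without materializing any data. pack_to_int32_sketch is illustrative, only covers the packed_dim == 1 case, and is not the library function.

import math
import torch

def pack_to_int32_sketch(value: torch.Tensor, num_bits: int) -> torch.Tensor:
    # shift each group of pack_factor values into disjoint bit ranges and sum them
    offset = 1 << (num_bits - 1)
    value = (value + offset).to(torch.uint8)
    pack_factor = 32 // num_bits
    rows, cols = value.shape
    padded_cols = math.ceil(cols / pack_factor) * pack_factor
    if padded_cols != cols:
        value = torch.nn.functional.pad(value, (0, padded_cols - cols))
    groups = value.reshape(rows, padded_cols // pack_factor, pack_factor).to(torch.int32)
    shifts = torch.arange(pack_factor, device=value.device, dtype=torch.int32) * num_bits
    return (groups << shifts).sum(dim=2, dtype=torch.int32)

# 4-bit packing: 60 columns pad to 64 and pack into 8 int32 columns, even on meta
w = torch.zeros(4, 60, dtype=torch.int8, device="meta")
print(pack_to_int32_sketch(w, num_bits=4).shape)  # torch.Size([4, 8])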
diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
index 7a97faa3..c9663095 100644
--- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
+++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
@@ -56,8 +56,10 @@ def compress_weight(self, name, value):
         bitmask_tensor = Sparse24BitMaskTensor.from_dense(
             value, self.config.sparsity_structure
         )
-        bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
-        return bitmask_dict
+        return bitmask_tensor.dict(
+            name_prefix=name,
+            device="meta" if value.is_meta else "cpu",
+        )
 
     def decompress_weight(self, weight_data):
         data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
@@ -90,9 +92,14 @@ def from_dense(
         :return: instantiated compressed tensor
         """
         shape = list(tensor.shape)
-        compressed, bitmask = sparse24_bitmask_compress(
-            tensor.cpu(), sparsity_structure=sparsity_structure
-        )
+        if tensor.is_meta:
+            compressed, bitmask = sparse24_bitmask_compress(
+                tensor, sparsity_structure=sparsity_structure
+            )
+        else:
+            compressed, bitmask = sparse24_bitmask_compress(
+                tensor.cpu(), sparsity_structure=sparsity_structure
+            )
         return Sparse24BitMaskTensor(
             shape=shape,
             compressed=compressed,
@@ -169,6 +176,13 @@ def sparse24_bitmask_compress(
         SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
     ), "Only 2:4 sparsity is supported"
 
+    if tensor.is_meta:
+        num_rows, num_cols = tensor.shape
+        compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta")
+        packed_cols = (num_cols + 7) // 8
+        bitmasks_packed = torch.empty((num_rows, packed_cols), dtype=torch.uint8, device="meta")
+        return compressed_values, bitmasks_packed
+
     bytemasks = get_24_bytemasks(tensor=tensor)
 
     if tensor.dtype == FP8_DTYPE:
diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py
index b1d040f4..52d4ad23 100644
--- a/tests/test_compressors/model_compressors/test_model_compressor.py
+++ b/tests/test_compressors/model_compressors/test_model_compressor.py
@@ -412,6 +412,66 @@ def test_compress_model(model_stub, q_format, s_config, tmpdir):
         assert torch.all(compressed[key] == true_compressed[key]), f"{key}"
 
 
+@pytest.mark.parametrize(
+    "model_stub,q_format,s_config",
+    [
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
+            "float-quantized",
+            None,
+        ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
+            None,
+            "sparse-24-bitmask",
+        ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
+            "float-quantized",
+            "sparse-24-bitmask",
+        ),
+        (
+            "nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed",
+            "pack-quantized",
+            None,
+        ),
+    ],
+)
+def test_compress_model_meta(model_stub, q_format, s_config):
+    # Load model on CPU to get expected compressed state_dict
+    cpu_model = AutoModelForCausalLM.from_pretrained(
+        model_stub, torch_dtype=torch.float32
+    )
+    reference_compressor = ModelCompressor.from_pretrained_model(
+        cpu_model, s_config, q_format
+    )
+    # Only compare dtypes, since tensors on the meta device carry no values
+    expected = {
+        k: v.dtype
+        for k, v in reference_compressor.compress(cpu_model).items()
+    }
+
+    # Load model again and move its parameters to the meta device
+    meta_model = AutoModelForCausalLM.from_pretrained(
+        model_stub,
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True,
+    )
+    for module in meta_model.modules():
+        if hasattr(module, "to_empty"):
+            module.to_empty(device="meta")
+
+    # Compress in-place on meta model
+    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, q_format)
+    compressor.compress_model(meta_model)
+
+    # Compare keys and dtypes against the CPU reference
+    compressed = dict(meta_model.state_dict())
+    assert set(compressed.keys()) == set(expected.keys())
+    for key, dtype in expected.items():
+        assert compressed[key].dtype == dtype, f"{key} has incorrect dtype"
+
+
 @pytest.mark.parametrize(
     "model_stub,comp_stub",
     [
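For context, a minimal end-to-end sketch of the workflow the new test exercises: compress a model whose parameters live on the meta device. The model stub mirrors the parametrization above; the import path and loading flow are abbreviated and may differ in practice.

import torch
from transformers import AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor

model = AutoModelForCausalLM.from_pretrained(
    "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
    torch_dtype=torch.float32,
)
model.to_empty(device="meta")  # keep the module structure, drop values onto meta

compressor = ModelCompressor.from_pretrained_model(model, None, "float-quantized")
compressor.compress_model(model)

# compressed parameters are registered in place and remain on the meta device
for name, param in list(model.named_parameters())[:4]:
    print(name, tuple(param.shape), param.dtype, param.device)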