
Commit 6696810

added support for compression on meta device
Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
1 parent f5b3e71 · commit 6696810

File tree: 7 files changed (+78, -37 lines)

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 14 additions & 7 deletions
@@ -378,7 +378,7 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
 
     # ----- model memory compression/decompression pathways ----- #
 
-    def compress_model(self, model: Module):
+    def compress_model(self, model: Module, is_meta: bool = False):
         """
         Compress a model in memory. Because the model structure is modified in place,
         this method is more memory-efficient than `self.compress`
@@ -394,8 +394,11 @@ def compress_model(self, model: Module):
 
         for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
             if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                exec_device = "meta" if is_meta else "cpu"
+                onloading_device = "meta" if is_meta else get_execution_device(module)
+
                 # in the future, support compression on same device
-                with align_module_device(module, execution_device="cpu"):
+                with align_module_device(module, execution_device=exec_device):
                     state_dict = module.state_dict(prefix=f"{prefix}.")
 
                 # quantization first
@@ -404,6 +407,7 @@ def compress_model(self, model: Module):
                         state_dict,
                         names_to_scheme=module_to_scheme,
                         show_progress=False,
+                        save_device=exec_device,
                     )
 
                 # sparsity second
@@ -412,18 +416,18 @@ def compress_model(self, model: Module):
                         state_dict,
                         compression_targets=sparse_compression_targets,
                         show_progress=False,
+                        module=module,
                     )
 
                 # remove any existing parameters
-                exec_device = get_execution_device(module)
                 offload_device = get_offloaded_device(module)
                 for name, _ in list(module.named_parameters()):
                     delete_offload_parameter(module, name)
 
                 # replace with compressed parameters
                 for name, value in state_dict.items():
                     name = name.removeprefix(f"{prefix}.")
-                    value = value.to(exec_device)
+                    value = value.to(onloading_device)
                     param = torch.nn.Parameter(value, requires_grad=False)
                     register_offload_parameter(module, name, param, offload_device)
 
@@ -485,7 +489,9 @@ def decompress_model(self, model: Module):
                 # replace with decompressed parameters
                 for name, value in state_dict.items():
                     name = name.removeprefix(f"{prefix}.")
-                    value = value.to(exec_device)
+                    # skipping save if we're just registering the model on meta device
+                    if exec_device != "meta":
+                        value = value.to(exec_device)
                     param = torch.nn.Parameter(value, requires_grad=False)
                     register_offload_parameter(module, name, param, offload_device)
 
@@ -693,7 +699,9 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
             params_device = next(module.parameters()).device
             device = "cpu" if has_offloaded_params(module) else params_device
             delattr(module, param_name)
-            requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
+            # requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
+            requires_grad = torch.is_floating_point(data)
+
             param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
             register_offload_parameter(module, param_name, param)
 
@@ -711,7 +719,6 @@ def _replace_weights(self, dense_weight_generator, model: Module):
             'data' is the updated param data
         :param model: The model whose weights are to be updated.
         """
-
         for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"):
             module = operator.attrgetter(mod_path)(model)

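For context, a minimal sketch of how the new flag might be exercised. The checkpoint id, the transformers usage, and the ModelCompressor.from_pretrained construction are illustrative assumptions; only compress_model(model, is_meta=True) comes from this diff:

import torch
from transformers import AutoConfig, AutoModelForCausalLM   # assumption: transformers is available
from compressed_tensors.compressors import ModelCompressor

# Hypothetical checkpoint id; assumes its config carries a compression config.
model_id = "org/compressed-model"

# Build a weightless model skeleton on the meta device (no tensor data allocated).
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_id))

# Assumption: the compressor is built from the checkpoint's compression config.
compressor = ModelCompressor.from_pretrained(model_id)

# New in this commit: is_meta=True keeps compression on the meta device, so
# parameters are replaced by empty tensors of the compressed shapes instead of
# being onloaded to CPU first.
compressor.compress_model(model, is_meta=True)
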
src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 3 additions & 4 deletions
@@ -72,6 +72,7 @@ def compress(
         model_state: Dict[str, Tensor],
         names_to_scheme: Dict[str, QuantizationScheme],
         show_progress: bool = False,
+        save_device: str = "cpu",
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -85,7 +86,6 @@ def compress(
         """
         uncompressed_names = list(model_state.keys())
         compressed_dict = {}
-        save_device = "cpu"
 
         # compress values
         desc = "Compressing with quantization"
@@ -107,7 +107,7 @@ def compress(
                     compressed_dict[name] = value.to(save_device)
                     continue
 
-                # compress values on cpu (memory movement too expensive)
+                # compress values on meta if loading from meta otherwise on cpu (memory movement too expensive)
                 module_path = prefix[:-1] if prefix.endswith(".") else prefix
                 quant_args = names_to_scheme[module_path].weights
                 compressed_values = self.compress_weight(
@@ -117,7 +117,7 @@ def compress(
                     global_scale=global_scale,
                     g_idx=g_idx,
                     quantization_args=quant_args,
-                    device="cpu",
+                    device=save_device,
                 )
 
                 # update state dict
@@ -133,7 +133,6 @@ def compress(
                 # TODO: does this case actually occur?
                 elif name.endswith("g_idx") and torch.any(value <= -1):
                     continue
-
                 compressed_dict[name] = value.to(save_device)
 
         return compressed_dict

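The new save_device parameter lets the caller keep compressed outputs on the meta device rather than always landing on CPU. A quick standalone illustration of why that matters (plain torch, nothing library-specific):

import torch

# Tensors on "meta" carry only shape/dtype metadata, so building the compressed
# state dict for a model loaded on meta allocates no memory and moves no data.
w_cpu = torch.zeros((4096, 4096), dtype=torch.bfloat16, device="cpu")    # ~32 MiB materialized
w_meta = torch.zeros((4096, 4096), dtype=torch.bfloat16, device="meta")  # shape only, no storage
print(w_cpu.element_size() * w_cpu.nelement())   # 33554432 bytes actually allocated
print(w_meta.device, w_meta.shape)               # meta torch.Size([4096, 4096])
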
src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ def compress_weight(
         compressed_dict = {}
         weight_packed = pack_fp4_to_uint8(quantized_weight)
         if device is not None:
+            breakpoint()
             weight_packed = weight_packed.to(device)
         compressed_dict["weight_packed"] = weight_packed
         return compressed_dict

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 24 additions & 18 deletions
@@ -220,30 +220,36 @@ def pack_to_int32(
     if num_bits < 1:
         raise ValueError(f"num_bits must be at least 1, got {num_bits}")
 
-    # convert to unsigned for packing
+    # Convert to unsigned range for packing, matching quantization offset
     offset = 1 << (num_bits - 1)
     value = (value + offset).to(torch.uint8)
-    value = value.cpu().numpy().astype(np.uint32)
+    device = value.device
+
     pack_factor = 32 // num_bits
 
-    # pad input tensor and initialize packed output
-    packed_size = math.ceil(value.shape[packed_dim] / pack_factor)
-    padding = packed_size * pack_factor - value.shape[packed_dim]
-    value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)
+    if packed_dim == 0:
+        value = value.transpose(0, 1)
 
-    # pack values
-    if packed_dim == 1:
-        packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
-        for i in range(pack_factor):
-            packed |= value[:, i::pack_factor] << num_bits * i
-    else:
-        packed = np.zeros((packed_size, value.shape[1]), dtype=np.uint32)
-        for i in range(pack_factor):
-            packed |= value[i::pack_factor, :] << num_bits * i
+    rows, cols = value.shape
+    padded_cols = math.ceil(cols / pack_factor) * pack_factor
+    pad_len = padded_cols - cols
+
+    if pad_len > 0:
+        value = torch.nn.functional.pad(value, (0, pad_len))
+
+    num_groups = padded_cols // pack_factor
+
+    # Use int32 here
+    reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
+
+    bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
+
+    packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)
+
+    if packed_dim == 0:
+        packed = packed.transpose(0, 1)
 
-    # convert back to signed and torch
-    packed = np.ascontiguousarray(packed).view(np.int32)
-    return torch.from_numpy(packed)
+    return packed
 
 
 def unpack_from_int32(

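The rewritten pack_to_int32 replaces the NumPy/CPU loop with vectorized tensor ops that can also run on meta or GPU tensors. A self-contained sketch of the same reshape-and-shift idea (a standalone helper with illustrative names, not the library function):

import torch

def pack_nbit_rows(value: torch.Tensor, num_bits: int) -> torch.Tensor:
    """Pack unsigned num_bits-wide integers along dim 1 into int32 words."""
    pack_factor = 32 // num_bits
    rows, cols = value.shape
    pad_len = (-cols) % pack_factor                        # pad columns up to a full word
    if pad_len:
        value = torch.nn.functional.pad(value, (0, pad_len))
    groups = value.shape[1] // pack_factor
    reshaped = value.view(rows, groups, pack_factor).to(torch.int32)
    shifts = torch.arange(pack_factor, dtype=torch.int32) * num_bits
    # The shifted fields never overlap, so summing them assembles each 32-bit word.
    return (reshaped << shifts).sum(dim=2, dtype=torch.int32)

vals = torch.randint(0, 16, (2, 9), dtype=torch.uint8)    # 4-bit values, ragged width
packed = pack_nbit_rows(vals, num_bits=4)
print(packed.shape)                                       # torch.Size([2, 2]): 9 cols pad to 16, 16 // 8 words
assert int(packed[0, 0]) & 0xF == int(vals[0, 0])         # low nibble of word 0 is the first value
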
src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 7 additions & 1 deletion
@@ -25,6 +25,8 @@
 from torch import Tensor
 from tqdm import tqdm
 
+from torch.nn import Module
+
 
 __all__ = ["BaseSparseCompressor"]
 
@@ -68,6 +70,7 @@ def compress(
         model_state: Dict[str, Tensor],
         compression_targets: Optional[Set[str]] = None,
         show_progress: bool = False,
+        module: Optional[Module] = None,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
@@ -93,7 +96,10 @@ def compress(
             if prefix.endswith(".weight"):
                 prefix = prefix[: -(len(".weight"))]
 
-            compression_data = self.compress_weight(prefix, value)
+            compression_data = self.compress_weight(
+                prefix, value, module=module
+            )
+
             for key in compression_data.keys():
                 if key in compressed_dict:
                     _LOGGER.warn(

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 28 additions & 6 deletions
@@ -52,12 +52,22 @@ def compression_param_names(self) -> Tuple[str]:
             "bitmask",
         )
 
-    def compress_weight(self, name, value):
+    def compress_weight(self, name, value, *, module=None):
         bitmask_tensor = Sparse24BitMaskTensor.from_dense(
             value, self.config.sparsity_structure
         )
-        bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
-        return bitmask_dict
+        if value.device.type == "meta":
+            if module is None:
+                raise ValueError("compress_weight requires module argument when is_meta=True")
+            # Create empty parameter matching compressed shape
+            empty_weight = torch.empty_like(bitmask_tensor.compressed, device="meta")
+            module.weight = torch.nn.Parameter(empty_weight, requires_grad=False)
+
+        # Normal flow: return compression dict
+        return bitmask_tensor.dict(
+            name_prefix=name,
+            device="meta" if value.device.type == "meta" else "cpu",
+        )
 
     def decompress_weight(self, weight_data):
         data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
@@ -90,9 +100,14 @@ def from_dense(
         :return: instantiated compressed tensor
         """
         shape = list(tensor.shape)
-        compressed, bitmask = sparse24_bitmask_compress(
-            tensor.cpu(), sparsity_structure=sparsity_structure
-        )
+        if tensor.device.type == "meta":
+            compressed, bitmask = sparse24_bitmask_compress(
+                tensor, sparsity_structure=sparsity_structure
+            )
+        else:
+            compressed, bitmask = sparse24_bitmask_compress(
+                tensor.cpu(), sparsity_structure=sparsity_structure
+            )
         return Sparse24BitMaskTensor(
             shape=shape,
             compressed=compressed,
@@ -169,6 +184,13 @@ def sparse24_bitmask_compress(
         SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
     ), "Only 2:4 sparsity is supported"
 
+    if tensor.device.type == "meta":
+        num_rows, num_cols = tensor.shape
+        compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta")
+        packed_cols = (num_cols + 7) // 8
+        bitmasks_packed = torch.empty((num_rows, packed_cols), dtype=torch.uint8, device="meta")
+        return compressed_values, bitmasks_packed
+
     bytemasks = get_24_bytemasks(tensor=tensor)
 
     if tensor.dtype == FP8_DTYPE:

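On the meta device there is no data to scan, so the new early return above only has to produce tensors of the right shapes and dtypes. A small standalone sketch of that shape arithmetic for a 2:4-sparse weight (illustrative names, not the library function):

import torch

def sparse24_meta_shapes(rows: int, cols: int, dtype: torch.dtype = torch.bfloat16):
    # 2:4 sparsity keeps 2 of every 4 values, so compressed values have cols // 2 columns;
    # the bitmask stores one bit per original element, packed 8 per uint8 byte.
    compressed = torch.empty((rows, cols // 2), dtype=dtype, device="meta")
    bitmask = torch.empty((rows, (cols + 7) // 8), dtype=torch.uint8, device="meta")
    return compressed, bitmask

values, bitmask = sparse24_meta_shapes(4096, 11008)
print(values.shape, bitmask.shape)   # torch.Size([4096, 5504]) torch.Size([4096, 1376])
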
src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def compression_param_names(self) -> Tuple[str]:
         """
         return ("shape", "compressed", "bitmask", "row_offsets")
 
-    def compress_weight(self, name, value):
+    def compress_weight(self, name, value, **kwargs):
         bitmask_tensor = BitmaskTensor.from_dense(value)
         bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
         return bitmask_dict
