Commit 1e71b0a

address reviewed issues

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
1 parent be74a47 · commit 1e71b0a

5 files changed: +9 -22 lines changed
src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 2 additions & 3 deletions

@@ -407,7 +407,7 @@ def compress_model(self, model: Module, is_meta: bool = False):
                 state_dict,
                 names_to_scheme=module_to_scheme,
                 show_progress=False,
-                save_device=exec_device,
+                compression_device=exec_device,
             )
 
             # sparsity second
@@ -416,7 +416,6 @@ def compress_model(self, model: Module, is_meta: bool = False):
                 state_dict,
                 compression_targets=sparse_compression_targets,
                 show_progress=False,
-                module=module,
             )
 
             # remove any existing parameters
@@ -700,7 +699,6 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
             device = "cpu" if has_offloaded_params(module) else params_device
             delattr(module, param_name)
             requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
-
             param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
             register_offload_parameter(module, param_name, param)
 
@@ -718,6 +716,7 @@ def _replace_weights(self, dense_weight_generator, model: Module):
         'data' is the updated param data
         :param model: The model whose weights are to be updated.
         """
+
         for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"):
             module = operator.attrgetter(mod_path)(model)
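For reference, the parameter replacement in _replace_sparsity_weights can be sketched in plain PyTorch. Here register_parameter stands in for the library's register_offload_parameter, and replace_param is a hypothetical wrapper, not the repo's API:

import torch

def replace_param(module: torch.nn.Module, param_name: str, data: torch.Tensor, device: str = "cpu"):
    # drop the old attribute, then register a fresh Parameter on the target
    # device; requires_grad is kept only for floating-point dtypes, as above
    delattr(module, param_name)
    requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
    param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
    module.register_parameter(param_name, param)  # stand-in for register_offload_parameter

lin = torch.nn.Linear(2, 2)
replace_param(lin, "weight", torch.zeros(2, 2), device="cpu")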

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 5 additions & 5 deletions

@@ -72,7 +72,7 @@ def compress(
         model_state: Dict[str, Tensor],
         names_to_scheme: Dict[str, QuantizationScheme],
         show_progress: bool = False,
-        save_device: str = "cpu",
+        compression_device: str = "cpu",
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -104,7 +104,7 @@ def compress(
 
             # is scale does not exist, then weight cannot be compressed
             if scale is None:
-                compressed_dict[name] = value.to(save_device)
+                compressed_dict[name] = value.to(compression_device)
                 continue
 
             # compress values on meta if loading from meta otherwise on cpu (memory movement too expensive)
@@ -117,12 +117,12 @@ def compress(
                     global_scale=global_scale,
                     g_idx=g_idx,
                     quantization_args=quant_args,
-                    device=save_device,
+                    device=compression_device,
                 )
 
                 # update state dict
                 for key, value in compressed_values.items():
-                    compressed_dict[prefix + key] = value.to(save_device)
+                    compressed_dict[prefix + key] = value.to(compression_device)
 
             else:
                 # omit saving zero points for symmetric or packed quantization
@@ -133,7 +133,7 @@ def compress(
                 # TODO: does this case actually occur?
                 elif name.endswith("g_idx") and torch.any(value <= -1):
                     continue
-                compressed_dict[name] = value.to(save_device)
+                compressed_dict[name] = value.to(compression_device)
 
         return compressed_dict
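The rename from save_device to compression_device reflects what the parameter does: every branch of compress routes its output through that device, whether or not a tensor is actually compressed. A minimal sketch of that routing, where compress_sketch and the int8 cast are stand-ins rather than the library's API:

import torch
from typing import Dict

def compress_sketch(
    model_state: Dict[str, torch.Tensor],
    compression_device: str = "cpu",
) -> Dict[str, torch.Tensor]:
    compressed: Dict[str, torch.Tensor] = {}
    for name, value in model_state.items():
        if name.endswith("weight"):
            # stand-in for compress_weight(...): a bare int8 cast
            compressed[name] = value.to(torch.int8).to(compression_device)
        else:
            # uncompressible tensors take the same device path
            compressed[name] = value.to(compression_device)
    return compressed

out = compress_sketch({"layer.weight": torch.zeros(4, 4), "layer.weight_scale": torch.ones(1)})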

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 0 additions & 2 deletions

@@ -241,9 +241,7 @@ def pack_to_int32(
 
     # Use int32 here
     reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
-
    bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
-
     packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)
 
     if packed_dim == 0:
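The two deleted lines here are blank lines only; the packing logic is unchanged. For readers of the hunk, a self-contained sketch of the shift-and-sum packing it performs, where the wrapper name and the 2-D unsigned-input assumption are mine:

import torch

def pack_rows_to_int32(value: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    # Pack unsigned num_bits-wide integers along dim 1 into int32 words using
    # the same shift-and-sum as pack_to_int32 above; this sketch omits the
    # padding, packed_dim, and sign handling of the real function.
    pack_factor = 32 // num_bits                  # values per int32 word
    rows, cols = value.shape
    assert cols % pack_factor == 0, "pad columns to a multiple of pack_factor"
    num_groups = cols // pack_factor
    reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
    bit_shifts = torch.arange(pack_factor, device=value.device, dtype=torch.int32) * num_bits
    return (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)

packed = pack_rows_to_int32(torch.randint(0, 16, (2, 8)), num_bits=4)
print(packed.shape)  # torch.Size([2, 1]): eight 4-bit values per int32 word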

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 1 addition & 4 deletions

@@ -70,7 +70,6 @@ def compress(
         model_state: Dict[str, Tensor],
         compression_targets: Optional[Set[str]] = None,
         show_progress: bool = False,
-        module: Optional[Module] = None,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
@@ -96,9 +95,7 @@ def compress(
             if prefix.endswith(".weight"):
                 prefix = prefix[: -(len(".weight"))]
 
-            compression_data = self.compress_weight(
-                prefix, value, module=module
-            )
+            compression_data = self.compress_weight(prefix, value)
 
             for key in compression_data.keys():
                 if key in compressed_dict:
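With the module kwarg gone, compress_weight takes just the prefix and value. A toy version of the surrounding loop, with a hypothetical compress_weight stub in place of a real compressor, showing why the ".weight" suffix is stripped first so output keys share the module's prefix:

import torch
from typing import Dict

def compress_weight(prefix: str, value: torch.Tensor) -> Dict[str, torch.Tensor]:
    # stub matching the new two-argument signature; real compressors return
    # several tensors keyed under the module prefix (values, bitmask, shape, ...)
    return {f"{prefix}.compressed": value[value != 0]}

compressed_dict: Dict[str, torch.Tensor] = {}
for name, value in {"model.layer.weight": torch.tensor([0.0, 1.5, 0.0, -2.0])}.items():
    prefix = name[: -len(".weight")] if name.endswith(".weight") else name
    compressed_dict.update(compress_weight(prefix, value))

print(sorted(compressed_dict))  # ['model.layer.compressed']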

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 1 addition & 8 deletions

@@ -52,17 +52,10 @@ def compression_param_names(self) -> Tuple[str]:
         "bitmask",
     )
 
-    def compress_weight(self, name, value, *, module=None):
+    def compress_weight(self, name, value):
         bitmask_tensor = Sparse24BitMaskTensor.from_dense(
             value, self.config.sparsity_structure
         )
-        if value.device.type == "meta":
-            if module is None:
-                raise ValueError("compress_weight requires module argument when is_meta=True")
-            # Create empty parameter matching compressed shape
-            empty_weight = torch.empty_like(bitmask_tensor.compressed, device="meta")
-            module.weight = torch.nn.Parameter(empty_weight, requires_grad=False)
-
         # Normal flow: return compression dict
         return bitmask_tensor.dict(
             name_prefix=name,
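With the meta-device branch removed, compress_weight depends only on the input tensor. A toy illustration of the values-plus-bitmask idea behind Sparse24BitMaskTensor; the sketch keeps a boolean mask, whereas the real class packs the mask into a bitmask tensor and records shape metadata:

import torch

def bitmask_compress_sketch(dense: torch.Tensor):
    # keep only the nonzero values plus a boolean mask marking their positions
    mask = dense != 0
    return {"compressed": dense[mask], "bitmask": mask}

w = torch.tensor([[0.0, 1.5, 0.0, -2.0]])   # 2:4 pattern: two nonzeros per group of four
parts = bitmask_compress_sketch(w)
restored = torch.zeros_like(w)
restored[parts["bitmask"]] = parts["compressed"]
assert torch.equal(restored, w)              # lossless round trip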
