
Commit 82dfe9d

revert unrelated change
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent d955b5e commit 82dfe9d

File tree

3 files changed (+16 additions, -8 deletions)

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 14 additions & 2 deletions
@@ -370,14 +370,20 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
     # ----- model memory compression/decompression pathways ----- #

     def compress_model(self, model: Module):
+        """
+        Compress a model in memory. Because the model structure is modified in place,
+        this method is more memory-efficient than `self.compress`
+
+        :param model: model containing parameters to compress
+        """
         module_to_scheme = map_module_to_scheme(model)
         sparse_compression_targets: Set[str] = expand_target_names(
             model=model,
             targets=self.sparsity_config.targets if self.sparsity_config else [],
             ignore=self.sparsity_config.ignore if self.sparsity_config else [],
         )

-        for prefix, module in model.named_modules():
+        for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
             if prefix in module_to_scheme or prefix in sparse_compression_targets:
                 state_dict = module.state_dict(prefix=f"{prefix}.")
                 # quantization first
@@ -409,14 +415,20 @@ def compress_model(self, model: Module):
                 module.quantization_status = QuantizationStatus.COMPRESSED

     def decompress_model(self, model: Module):
+        """
+        Decompress a model in memory. Because the model structure is modified in place,
+        this method does not require loading some compression parameters from disk
+
+        :param model: model containing parameters to decompress
+        """
         module_to_scheme = map_module_to_scheme(model)
         sparse_compression_targets: Set[str] = expand_target_names(
             model=model,
             targets=self.sparsity_config.targets if self.sparsity_config else [],
             ignore=self.sparsity_config.ignore if self.sparsity_config else [],
         )

-        for prefix, module in model.named_modules():
+        for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
             if prefix in module_to_scheme or prefix in sparse_compression_targets:
                 state_dict = module.state_dict(prefix=f"{prefix}.")
                 # sparsity first
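
For context, a minimal usage sketch of the in-place pathways touched here, assuming the library's `ModelCompressor.from_pretrained` entry point; the checkpoint path and model choice are illustrative, not from this commit:

# Hedged sketch: round-tripping a model through the in-memory pathways.
# The checkpoint path is hypothetical; ModelCompressor.from_pretrained is
# assumed to resolve the compression config stored with the checkpoint.
from transformers import AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor

ckpt = "path/to/checkpoint"  # hypothetical path
model = AutoModelForCausalLM.from_pretrained(ckpt)
compressor = ModelCompressor.from_pretrained(ckpt)

if compressor is not None:
    compressor.compress_model(model)    # modifies modules in place, now with a tqdm bar
    compressor.decompress_model(model)  # in place, without reloading params from disk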

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 0 additions & 5 deletions
@@ -37,7 +37,6 @@
     "dequantize",
     "fake_quantize",
     "wrap_module_forward_quantized",
-    "unwrap_module_forward_quantized",
     "forward_quantize",
 ]

@@ -313,10 +312,6 @@ def wrapped_forward(self, *args, **kwargs):
     setattr(module, "forward", bound_wrapped_forward)


-def unwrap_module_forward_quantized(module: Module):
-    delattr(module, "forward")  # revert to class implementation
-
-
 def forward_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
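
The deleted helper relied on a standard Python mechanic rather than anything library-specific: `setattr` on an instance shadows the class-level `forward`, and `delattr` falls back to it. A standalone sketch of that mechanic (class and variable names are illustrative, not from this repo):

# Illustration of the mechanic behind the removed
# unwrap_module_forward_quantized: deleting an instance-level "forward"
# attribute reverts calls to the class implementation.
import torch.nn as nn

class Passthrough(nn.Module):
    def forward(self, x):
        return x

m = Passthrough()
setattr(m, "forward", lambda x: x * 2)  # instance attribute shadows the class method
assert m.forward(3) == 6                # wrapped behavior
delattr(m, "forward")                   # revert to class implementation
assert m.forward(3) == 3                # original behavior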

src/compressed_tensors/utils/safetensors_load.py

Lines changed: 2 additions & 1 deletion
@@ -269,7 +269,8 @@ def get_nested_mappings_from_state_dict(
     :param state_dict: state dict of the model
     :param params_to_nest: Iterable of parameter names to nest.
     :return: Nested mapping of parameterized layer names to the value of
-        each layer's compression parameters.
+        each layer's compression parameters. If `return_unmatched_params`, then
+        also return a dictionary mapping unused parameter names to their values
     """
     nested_weight_mappings = {}
     unmatched_params = {}
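
A hedged sketch of the behavior the amended docstring describes, assuming a `return_unmatched_params` keyword as implied by the diff; the state-dict contents and parameter names are invented for illustration:

# Sketch of the documented return shape; the signature is inferred from
# the docstring above, and the tensors are placeholders.
import torch
from compressed_tensors.utils.safetensors_load import (
    get_nested_mappings_from_state_dict,
)

state_dict = {
    "model.layers.0.weight": torch.zeros(4, 4),
    "model.layers.0.weight_scale": torch.ones(1),
    "lm_head.bias": torch.zeros(4),
}

nested, unmatched = get_nested_mappings_from_state_dict(
    state_dict,
    params_to_nest=["weight", "weight_scale"],
    return_unmatched_params=True,
)
# nested    -> {"model.layers.0": {"weight": ..., "weight_scale": ...}}
# unmatched -> {"lm_head.bias": ...}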
