Commit dc2e9b0

[Examples] [Bugfix] skip sparsity stats when saving checkpoints (#1528)
## Purpose ##

* Fix an example failure where sparsity stats were computed on an already-compressed model
* This happened because `get_model_compressor` cannot compute sparsity stats of a compressed model

## Prerequisites ##

* neuralmagic/compressed-tensors#345

## Changes ##

* When saving checkpoints, call `get_model_compressor` on the decompressed model (before it is compressed, not after)
* Do not compute sparsity stats when saving checkpoints by default, due to their high runtime cost

## Testing ##

* Ran the example to completion with a reduced training duration

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent: d947364

File tree: 2 files changed (+14, −7 lines)


src/llmcompressor/pytorch/model_load/helpers.py

Lines changed: 11 additions & 6 deletions
```diff
@@ -45,6 +45,16 @@ def save_checkpoint(
         get_model_compressor,  # avoid circular import
     )
 
+    # used for decompression
+    # unfortunately, if skip_sparsity_compression_stats==False, sparsity stats
+    # are computed twice. In the future, track sparsity from the recipe or
+    # share the recipe between compression and decompression
+    compressor = get_model_compressor(
+        model=model,
+        save_compressed=save_compressed,
+        skip_sparsity_compression_stats=skip_sparsity_compression_stats,
+    )
+
     # saving the model also saves the recipe
     model.save_pretrained(
         save_path,
@@ -55,13 +65,8 @@ def save_checkpoint(
     if processor is not None:
         processor.save_pretrained(save_path)
 
-    # saving the model modifies the model structure
+    # decompression: saving the model modifies the model structure
     # as this is only a checkpoint, decompress model to enable future training/oneshot
-    compressor = get_model_compressor(
-        model=model,
-        save_compressed=save_compressed,
-        skip_sparsity_compression_stats=skip_sparsity_compression_stats,
-    )
     if compressor is not None:
         compressor.decompress_model(model)
```
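Read together, the two hunks reorder the save flow so the compressor is built while the model is still decompressed. Below is a minimal sketch of `save_checkpoint` after this change; the function signature, the import path, and the exact `save_pretrained` arguments are assumptions reconstructed around the diffed lines, not verbatim from the file.

```python
# Minimal sketch of save_checkpoint after this change. Signature, import path,
# and save_pretrained kwargs are assumptions; only the diffed lines are verbatim.
def save_checkpoint(
    model,
    save_path: str,
    processor=None,
    save_compressed: bool = True,
    skip_sparsity_compression_stats: bool = True,
):
    # assumed module path for the helper shown in the diff
    from llmcompressor.transformers.compression.compressed_tensors_utils import (
        get_model_compressor,  # avoid circular import
    )

    # build the compressor while the model is still decompressed: sparsity
    # stats cannot be computed once save_pretrained has compressed the weights
    compressor = get_model_compressor(
        model=model,
        save_compressed=save_compressed,
        skip_sparsity_compression_stats=skip_sparsity_compression_stats,
    )

    # saving the model also saves the recipe (and compresses the model in place)
    model.save_pretrained(save_path, save_compressed=save_compressed)
    if processor is not None:
        processor.save_pretrained(save_path)

    # this is only a checkpoint: decompress so future training/oneshot still works
    if compressor is not None:
        compressor.decompress_model(model)
```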

src/llmcompressor/transformers/finetune/session_mixin.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -363,7 +363,7 @@ def save_model(
         self,
         output_dir: str,
         _internal_call: bool = False,
-        skip_sparsity_compression_stats: Optional[bool] = False,
+        skip_sparsity_compression_stats: Optional[bool] = True,
     ):
         """
         Override of the save_model function and expects it to exist in the parent.
@@ -388,6 +388,8 @@ def save_model(
         self.model.prepare_for_save()  # TODO: move to finalize
 
         # save checkpoint
+        # note that skip_sparsity_compression_stats
+        # is True by default to avoid high runtime cost
        self.save_state()
         if self.accelerator.is_main_process:
             processor = getattr(self, "processing_class", self.tokenizer)
```
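Since the default flipped from `False` to `True`, callers that still want sparsity stats in a saved model must now request them explicitly. A hypothetical usage sketch (the `trainer` instance and output path are placeholders, not from this commit):

```python
# Hypothetical caller: checkpoints saved during training now skip sparsity
# stats by default; opt back in for the final save, accepting the runtime cost.
trainer.save_model(
    output_dir="./final-model",  # placeholder path
    skip_sparsity_compression_stats=False,  # compute stats for the final artifact
)
```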
