diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py
index 1e665db43..0ffbd053e 100644
--- a/src/llmcompressor/pytorch/model_load/helpers.py
+++ b/src/llmcompressor/pytorch/model_load/helpers.py
@@ -45,6 +45,16 @@ def save_checkpoint(
         get_model_compressor,  # avoid circular import
     )
 
+    # used for decompression
+    # unfortunately, when sparsity stats are not skipped, they are computed
+    # twice. In the future, track sparsity from the recipe or share the
+    # recipe between compression and decompression
+    compressor = get_model_compressor(
+        model=model,
+        save_compressed=save_compressed,
+        skip_sparsity_compression_stats=skip_sparsity_compression_stats,
+    )
+
     # saving the model also saves the recipe
     model.save_pretrained(
         save_path,
@@ -55,13 +65,8 @@ def save_checkpoint(
     if processor is not None:
         processor.save_pretrained(save_path)
 
-    # saving the model modifies the model strcuture
+    # decompression: saving the model modifies the model structure
     # as this is only a checkpoint, decompress model to enable future training/oneshot
-    compressor = get_model_compressor(
-        model=model,
-        save_compressed=save_compressed,
-        skip_sparsity_compression_stats=skip_sparsity_compression_stats,
-    )
     if compressor is not None:
         compressor.decompress_model(model)
 
diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py
index a6af97b78..07ca385c4 100644
--- a/src/llmcompressor/transformers/finetune/session_mixin.py
+++ b/src/llmcompressor/transformers/finetune/session_mixin.py
@@ -363,7 +363,7 @@ def save_model(
         self,
         output_dir: str,
         _internal_call: bool = False,
-        skip_sparsity_compression_stats: Optional[bool] = False,
+        skip_sparsity_compression_stats: Optional[bool] = True,
     ):
         """
         Override of the save_model function and expects it to exist in the parent.
@@ -388,6 +388,8 @@ def save_model(
             self.model.prepare_for_save()  # TODO: move to finalize
 
         # save checkpoint
+        # note that skip_sparsity_compression_stats is True by default
+        # to avoid the high runtime cost of computing sparsity stats
         self.save_state()
         if self.accelerator.is_main_process:
             processor = getattr(self, "processing_class", self.tokenizer)
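
For orientation, below is a minimal Python sketch (not the library's exact code) of the save_checkpoint flow produced by the first two hunks: the compressor is now created before save_pretrained and reused for decompression afterwards. get_model_compressor is taken as an argument here only because its import module is elided in the hunk, and the signature and keyword arguments are simplified.

from typing import Callable, Optional


def save_checkpoint_sketch(
    model,
    save_path: str,
    processor: Optional[object],
    save_compressed: bool,
    skip_sparsity_compression_stats: bool,
    get_model_compressor: Callable,
) -> None:
    # build the compressor up front so the same object can be reused for
    # decompression after saving (previously it was built after save_pretrained)
    compressor = get_model_compressor(
        model=model,
        save_compressed=save_compressed,
        skip_sparsity_compression_stats=skip_sparsity_compression_stats,
    )

    # saving the model also saves the recipe (extra save kwargs omitted here)
    model.save_pretrained(save_path)
    if processor is not None:
        processor.save_pretrained(save_path)

    # this is only a checkpoint: decompress so the model structure is restored
    # for future training / oneshot
    if compressor is not None:
        compressor.decompress_model(model)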