Commit 421bd61
# Use model compression pathways (#1419)
## Purpose ##

* Use the in-memory model compression pathway in order to reduce memory requirements when saving models
* These changes, along with [postprocessing changes](https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/entrypoints/utils.py#L102), move users towards a pattern where they are aware of the status of the model (frozen/compressed) and call `save_pretrained` manually (see the sketch below)

## Prerequisites ##

* #1449

## Changes ##

* Modify `save_pretrained_wrapper` to use `compress_model(model)` rather than `compress(state_dict)`
* Modify `save_pretrained_wrapper` so that the state dict is only retrieved if not skipping compression stats
* Modify `save_pretrained_wrapper` to save dictionary and python files, even if there is no explicit compressor
* Modify `save_checkpoint` (used by training) to decompress after the checkpoint is saved

## Example/Testing Changes ##

As far as I can tell, the table below lists all of the instances where a model undergoes saving that is not immediately followed by script exit:

File Path | Solution
-- | --
examples/trl_mixin/ex_trl_constant.py <br> test_oneshot_and_finetune.py <br> tests/llmcompressor/transformers/obcq/test_obcq_completion.py | Decompress in between stages
examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py <br> test_oneshot_and_finetune_with_tokenizer.py | Do not save in between stages, to avoid a compressed intermediate state
test_oneshot_then_finetune.py | No work is required, as the model is decompressed upon loading from disk
test_compress_tensor_utils.py | Fix test to use `dispatch_model` (which is actually used by transformers) rather than `cpu_offload`

## Testing ##

State Dict | In Memory
-- | --
![previous](https://github.com/user-attachments/assets/f661a9a9-f546-4196-bb7c-58e48409d86d) | ![now](https://github.com/user-attachments/assets/b5edb8f9-1bfb-4474-83c4-48d1942f7c53)

<details><summary>oneshot_save.py</summary>

```python
import torch
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from pttp import TensorProfiler

# MODEL_ID = "DeepSeek-V3_local_bf16"
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

with TensorProfiler() as prof:
    prof.mark_event("Load model")
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)

    prof.mark_event("Oneshot")
    oneshot(
        model=model,
        recipe=QuantizationModifier(targets="Linear", scheme="W4A16"),
        trust_remote_code_model=True,
    )

    prof.mark_event("Save model")
    model.save_pretrained(
        "sav_testing", save_compressed=True, skip_compression_stats=True
    )

prof.save_memory_timeline("save_timeline.png")
```

</details>

* Nightly: https://github.com/neuralmagic/llm-compressor-testing/actions/runs/15453075963

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
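To make the intended pattern concrete, here is a minimal sketch of the compress/save/decompress flow between stages. It is illustrative, not code from this commit: the model id and recipe are placeholders, while `oneshot`, `save_pretrained(..., save_compressed=True)`, `get_model_compressor`, and `decompress_model` are the entrypoints touched by the diffs below.

```python
# Minimal sketch of the compress -> save -> decompress pattern between stages.
# Model id and recipe are illustrative placeholders; the helper calls mirror
# the diffs below.
import torch
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    get_model_compressor,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16
)

# stage 1: oneshot quantization
oneshot(
    model=model,
    recipe=QuantizationModifier(targets="Linear", scheme="W4A16"),
)

# saving compresses the model in memory before serialization,
# rather than materializing a second, compressed state dict
model.save_pretrained("stage1_output", save_compressed=True)

# the in-memory model is now compressed; decompress it before
# running another stage (training or oneshot)
compressor = get_model_compressor(model=model, save_compressed=True)
if compressor is not None:
    compressor.decompress_model(model)
```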
1 parent b3e728a commit 421bd61

File tree: 6 files changed, +54 −45 lines

examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -68,15 +68,13 @@
     model=model,
     **oneshot_kwargs,
     stage="sparsity_stage",
-    output_dir=output_dir,
 )

 # Sparse finetune
 finetune_applied_model = train(
     model=oneshot_applied_model,
     **oneshot_kwargs,
     **training_kwargs,
-    output_dir=output_dir,
     stage="finetuning_stage",
 )
```

src/llmcompressor/pytorch/model_load/helpers.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -41,6 +41,10 @@ def save_checkpoint(
     :param save_safetensors: save model checkpoint using safetensors file type
     :param save_compressed: save model checkpoint using compressed-tensors format
     """
+    from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
+        get_model_compressor,  # avoid circular import
+    )
+
     # saving the model also saves the recipe
     model.save_pretrained(
         save_path,
@@ -51,6 +55,16 @@
     if processor is not None:
         processor.save_pretrained(save_path)

+    # saving the model modifies the model structure
+    # as this is only a checkpoint, decompress model to enable future training/oneshot
+    compressor = get_model_compressor(
+        model=model,
+        save_compressed=save_compressed,
+        skip_sparsity_compression_stats=skip_sparsity_compression_stats,
+    )
+    if compressor is not None:
+        compressor.decompress_model(model)
+

 def fallback_to_cpu(device: str) -> str:
     """
```

src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py

Lines changed: 15 additions & 36 deletions
```diff
@@ -2,7 +2,7 @@
 import re
 import weakref
 from functools import wraps
-from typing import Dict, Optional
+from typing import Optional

 import torch
 import transformers
@@ -91,45 +91,27 @@ def save_pretrained_wrapper(
         # https://github.com/huggingface/transformers/pull/30488
         transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

-        # state_dict gets passed in as a kwarg for FSDP models
-        state_dict = kwargs.pop("state_dict", None)
-        if state_dict is None:
-            logger.info("Fetching state_dict - this may take some time")
-            state_dict = get_state_dict_offloaded_model(model)
-
-        logger.info("Fetching compressor")
+        # compress model using compressor
         compressor = get_model_compressor(
             model=model,
             sparsity_config=sparsity_config,
             quantization_format=quantization_format,
             save_compressed=save_compressed,
             skip_sparsity_compression_stats=skip_sparsity_compression_stats,
-            state_dict=state_dict,
             disable_sparse_compression=disable_sparse_compression,
         )
+        if compressor is not None:
+            compressor.compress_model(model)
+
+        # save (compressed) model structure
+        original_save_pretrained.__get__(model, model_class)(
+            save_directory,
+            safe_serialization=safe_serialization,
+            **kwargs,
+        )

-        if compressor is None:
-            # model is not compressed or quantized, save as normal
-            original_save_pretrained_func = original_save_pretrained.__get__(
-                model, model_class
-            )
-            original_save_pretrained_func(
-                save_directory, state_dict=state_dict, **kwargs
-            )
-            return
-
-        # make sure we're on the main process when saving
-        if state_dict is not None and len(state_dict) > 0:
-            compressed_state_dict = compressor.compress(
-                model, state_dict, show_progress=True
-            )
-            logger.info("Saving compressed model to disk")
-            original_save_pretrained.__get__(model, model_class)(
-                save_directory,
-                state_dict=compressed_state_dict,
-                safe_serialization=safe_serialization,
-                **kwargs,
-            )
+        # update config to reflect compression
+        if compressor is not None:
             compressor.update_config(save_directory)

         # update existing recipe
@@ -197,7 +179,6 @@ def get_model_compressor(
     quantization_format: Optional[str] = None,
     save_compressed: bool = True,
     skip_sparsity_compression_stats: bool = True,
-    state_dict: Optional[Dict] = None,
     disable_sparse_compression: bool = False,
 ):
     """
@@ -211,12 +192,8 @@
     :param save_compressed: boolean representing to save in a compressed
         format
     :param skip_sparsity_compression_stats: bool allowing compression stats on std out
-    :param state_dict: state_dict of the model
     :param disable_sparse_compression: bool to skip sparse compression
     """
-    # find offloaded state dict if none is provided
-    if state_dict is None:
-        state_dict = get_state_dict_offloaded_model(model)

     if sparsity_config is None:
         """
@@ -244,6 +221,8 @@
         )
         sparsity_config = None
     else:
+        state_dict = get_state_dict_offloaded_model(model)
+
         sparsity_config = SparsityConfigMetadata.from_pretrained(
             model,
             state_dict=state_dict,
```
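Unrolled, the new save path in `save_pretrained_wrapper` amounts to the sketch below. This is illustrative only: the real wrapper invokes the original, unwrapped `save_pretrained` via `original_save_pretrained.__get__`, whereas this standalone helper calls the public method.

```python
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    get_model_compressor,
)


def save_compressed_sketch(model, save_directory: str) -> None:
    """Illustrative unrolling of the patched save path (not the wrapper itself)."""
    compressor = get_model_compressor(model=model, save_compressed=True)

    if compressor is not None:
        # compress weights in place, module by module, instead of building a
        # second, fully compressed state dict next to the original weights
        compressor.compress_model(model)

    # serialize model structure and weights, whether or not they were compressed
    model.save_pretrained(save_directory)

    if compressor is not None:
        # record the compression format in the saved config
        compressor.update_config(save_directory)
```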

tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py

Lines changed: 12 additions & 5 deletions
```diff
@@ -7,6 +7,9 @@
 from parameterized import parameterized_class
 from transformers import AutoConfig

+from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
+    get_model_compressor,
+)
 from tests.testing_utils import parse_params, requires_gpu

 CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/finetune/finetune_oneshot_configs"
@@ -34,17 +37,21 @@ def _test_oneshot_and_finetune(self):
             output_dir=self.output,
         )

-        train_args = dict(
-            num_train_epochs=self.num_train_epochs,
-            precision="bfloat16",
-            bf16=True,
-        )
         oneshot_model = oneshot(
             model=self.model,
             **oneshot_args,
             stage="test_oneshot_stage",
         )

+        compressor = get_model_compressor(model=oneshot_model, save_compressed=True)
+        if compressor is not None:
+            compressor.decompress_model(oneshot_model)
+
+        train_args = dict(
+            num_train_epochs=self.num_train_epochs,
+            precision="bfloat16",
+            bf16=True,
+        )
         train(
             model=oneshot_model,
             **oneshot_args,
```

tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -55,7 +55,6 @@ def test_oneshot_and_finetune_with_tokenizer(self):
             concatenate_data=concatenate_data,
             splits=splits,
             tokenizer=tokenizer,
-            output_dir=self.output,
         )

         oneshot_model = oneshot(
@@ -70,6 +69,7 @@
             max_steps=max_steps,
             stage="test_train_stage",
             **model_and_data_kwargs,
+            output_dir=self.output,
         )

         input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
```

tests/llmcompressor/transformers/obcq/test_obcq_completion.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -35,7 +35,7 @@ def labeled_dataloader(self, dataset_name, model_name):
         dataset_manager = TextGenerationDataset.load_from_registry(
             dataset_args.dataset,
             dataset_args=dataset_args,
-            split="train",
+            split=f"train[:{self.num_samples}]",
             processor=tokenizer,
         )
         calib_dataset = dataset_manager()
@@ -51,10 +51,14 @@ def _test_oneshot_completion(self, model_name: str = None):
         from llmcompressor import oneshot
         from llmcompressor.pytorch.model_load.helpers import get_session_model
         from llmcompressor.pytorch.utils import tensors_to_device
+        from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
+            get_model_compressor,  # avoid circular import
+        )

         oneshot(
             model=self.model,
             dataset=self.dataset,
+            splits={"calibration": f"train[:{self.num_samples}]"},
             oneshot_device=self.device,
             recipe=self.recipe,
             max_seq_length=512,
@@ -65,6 +69,13 @@
         )

         first_tiny_model = get_session_model()
+        compressor = get_model_compressor(
+            model=first_tiny_model,
+            save_compressed=True,
+            skip_sparsity_compression_stats=False,
+        )
+        if compressor is not None:
+            compressor.decompress_model(first_tiny_model)

         dataset = "open_platypus"
```