
Commit f1ec15b

Authored by dsikka, kylesayrs, and brian-dellabetta
[Examples] Add an updated llama3.3 model to examples (#1592)
SUMMARY:
- Add a Llama 3.3 70B example to the `big_models_with_sequential_onloading` folder to illustrate large-model usage with sequential onloading
- Add code details to the README
- Fix a formatting issue

Next Steps:
- Add back an example test exercising the example

Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
1 parent cb8f410 commit f1ec15b

File tree

3 files changed: +121 -5 lines changed

Lines changed: 37 additions & 4 deletions

@@ -1,5 +1,5 @@
-## Big Modeling with Sequential Onloading ##
-### What is Sequential Onloading? ###
+# Big Modeling with Sequential Onloading #
+## What is Sequential Onloading? ##
Sequential onloading is a memory-efficient approach for compressing large language models (LLMs) using only a single GPU. Instead of loading the entire model into memory—which can easily require hundreds of gigabytes—this method loads and compresses one layer at a time. The outputs are offloaded before the next layer is processed, dramatically reducing peak memory usage while maintaining high compression fidelity.
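Conceptually, the flow resembles the following toy sketch. This is illustrative only, not LLM Compressor's actual pipeline: the layer list, calibration batch, and the commented-out compression step are stand-ins added here for clarity.

```python
# Toy illustration of sequential onloading: only one layer is ever resident on
# the GPU, so peak GPU memory stays near the size of a single layer.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
layers = [torch.nn.Linear(1024, 1024) for _ in range(8)]  # stand-in for decoder layers
activations = torch.randn(4, 1024)                        # stand-in calibration batch

for layer in layers:
    layer.to(device)                       # onload a single layer to the GPU
    activations = activations.to(device)
    with torch.no_grad():
        activations = layer(activations)   # run calibration data through it
    # ... quantize/compress this layer here, using the captured activations ...
    layer.to("cpu")                        # offload it before touching the next layer
    activations = activations.to("cpu")    # its outputs feed the next layer's calibration
```
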
<p align="center">
@@ -8,5 +8,38 @@

For more information, see the [RedHat AI blog post](https://developers.redhat.com/articles/2025/05/09/llm-compressor-optimize-llms-low-latency-deployments#generalizing_to_multimodal_and_moe_architectures) or the [LLM Compressor Office Hours Recording](https://www.youtube.com/watch?v=GrhuqQDmBk8).

-### Using Sequential Onloading ###
+## Using Sequential Onloading ##
Sequential onloading is enabled by default within LLM Compressor. To disable sequential onloading, add the `pipeline="basic"` argument to the LLM Compressor `oneshot` function call.
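As a minimal sketch of what that call might look like (reusing the variable names from the example script further below; `pipeline="basic"` is the only change, and disabling onloading assumes the fully dispatched model fits in available memory):

```python
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    pipeline="basic",  # disable sequential onloading, per the note above
)
```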

## Running Llama 3.3 70b ##
The Llama 3.3 70b model is larger than 80 GB, exceeding the memory of a single A100 GPU. With sequential onloading, however, this model can still be quantized seamlessly using a single GPU.
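A rough back-of-the-envelope check of that claim (assuming bfloat16 weights and an approximate 70B parameter count, ignoring activations and KV cache):

```python
num_params = 70e9        # approximate parameter count of Llama 3.3 70B
bytes_per_param = 2      # bfloat16 weights
weights_gb = num_params * bytes_per_param / 1e9
print(f"~{weights_gb:.0f} GB of weights vs. 80 GB on a single A100")  # -> ~140 GB
```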

### Code Walkthrough

```python
model_id = "meta-llama/Llama-3.3-70B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
```

The model is first loaded onto the CPU, since the `device_map` argument is left at its default of `None` in the `from_pretrained` call.
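If in doubt, the placement can be verified directly after loading (a quick optional check, not part of the example script):

```python
print(next(model.parameters()).device)  # expected: cpu
```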

```python
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
```

During `oneshot`, only one GPU is required; it is used to onload each layer for calibration, one layer at a time.
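One optional way to see this in practice (not part of the example) is to inspect peak GPU memory once `oneshot` returns; with sequential onloading it should remain well below the full model size:

```python
import torch

peak_gb = torch.cuda.max_memory_allocated() / 1e9
print(f"Peak GPU memory during calibration: {peak_gb:.1f} GB")
```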

```python
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to("cuda") for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
```

Finally, we call `dispatch_for_generation` to evenly load the model across available devices (potentially offloading the model if required) and run sample generations on the newly quantized model.
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.transformers import oneshot
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "meta-llama/Llama-3.3-70B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
#   * apply SmoothQuant to make the activations easier to quantize
#   * quantize the weights to int8 with GPTQ (static per channel)
#   * quantize the activations to int8 (dynamic per token)
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]
# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to("cuda") for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W8A8"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def initialize_observer(
         maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK),
         patience=observer_kwargs.get("patience", DEFAULT_PATIENCE),
         grid=observer_kwargs.get("grid", DEFAULT_GRID),
-        norm=observer_kwargs.get("norm", DEFAULT_NORM)
+        norm=observer_kwargs.get("norm", DEFAULT_NORM),
     )
     module.register_module(f"{base_name}_observer", observer)
