
Commit 4adb7e6

kylesayrs and dsikka committed
[MoE] DeepSeek-V3/R1 (#1535)
## Purpose ##
* Support DeepSeek-V3 and R1
* Update MoE examples to reflect the current state of MoE models
* Share information about sequential onloading and DeepSeek-V3 in the readme

## Fixes ##
* Fixes #1482
* Fixes #1274
* Fixes #1203

## Changes ##
* Add a readme blurb about sequential onloading and DeepSeek-R1
* Add an example for R1
* Add a `prepare_for_calibration` method which replaces the MoE module with a module that calibrates all experts with all tokens (but still gates expert outputs as the model normally would)
* In the future we can make this method more configurable to support
  * Sending all tokens to all experts
  * Using inference-time activations vs. train-time activations

## Testing ##
* Ran the DeepSeek-R1 example to completion

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
1 parent 31cffb0 commit 4adb7e6
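For readers skimming the diff, the new entry point in use looks like the snippet below, condensed directly from the DeepSeek-R1 example added in this commit (`examples/quantizing_moe/deepseek_r1_example.py`); the model ID and config handling are the example's, not requirements of the API:

```python
# Condensed from examples/quantizing_moe/deepseek_r1_example.py (added below).
from transformers import AutoConfig, AutoModelForCausalLM

from llmcompressor.modeling import prepare_for_calibration

model_id = "unsloth/DeepSeek-R1-0528-BF16"
config = AutoConfig.from_pretrained(model_id)
del config.quantization_config  # the fp8 qconfig no longer applies to the bf16 weights
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", config=config)

# Each DeepseekV3MoE module is swapped for a calibration variant that runs every
# expert on every token while still gating expert outputs as the model normally would
model = prepare_for_calibration(model)
```

Quantization then proceeds through `GPTQModifier` and `oneshot` exactly as in the full example below.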

File tree

5 files changed: +162 -0 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
Big updates have landed in LLM Compressor! Check out these exciting new features:

+* **Large Model Support with Sequential Onloading** As of llm-compressor>=0.6.0, you can now quantize very large language models on a single GPU. Models are broken into disjoint layers which are then onloaded to the GPU one layer at a time. For more information on sequential onloading, see [Big Modeling with Sequential Onloading](examples/big_models_with_sequential_onloading/README.md) as well as the [DeepSeek-R1 Example](examples/quantizing_moe/deepseek_r1_example.py).
* **Preliminary FP4 Quantization Support:** Quantize weights and activations to FP4 and seamlessly run the compressed model in vLLM. Model weights and activations are quantized following the NVFP4 [configuration](https://github.com/neuralmagic/compressed-tensors/blob/f5dbfc336b9c9c361b9fe7ae085d5cb0673e56eb/src/compressed_tensors/quantization/quant_scheme.py#L104). See examples of [weight-only quantization](examples/quantization_w4a16_fp4/llama3_example.py) and [fp4 activation support](examples/quantization_w4a4_fp4/llama3_example.py). Support is currently preliminary and additional support will be added for MoEs.
* **Axolotl Sparse Finetuning Integration:** Seamlessly finetune sparse LLMs with our Axolotl integration. Learn how to create [fast sparse open-source models with Axolotl and LLM Compressor](https://developers.redhat.com/articles/2025/06/17/axolotl-meets-llm-compressor-fast-sparse-open). See also the [Axolotl integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#llmcompressor).
* **AutoAWQ Integration:** Perform low-bit weight-only quantization efficiently using AutoAWQ, now part of LLM Compressor. *Note: This integration should be considered experimental for now. Enhanced support, including for MoE models and improved handling of larger models via layer sequential pipelining, is planned for upcoming releases.* [See the details](https://github.com/vllm-project/llm-compressor/pull/1177).
examples/quantizing_moe/deepseek_r1_example.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# Select model and load it.

# This script takes about 48 hours on 1xA100 to complete.
# Future improvements will reduce this runtime (#1561, #1558).

# For DeepSeek-R1, we require a full precision model in order to properly calibrate.
# `DeepSeek-R1-0528-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16.

model_id = "unsloth/DeepSeek-R1-0528-BF16"
config = AutoConfig.from_pretrained(model_id)
del config.quantization_config  # fp8 qconfig no longer applies to bf16 model
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", config=config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = prepare_for_calibration(model)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# Since the MoE gate layers are sensitive to quantization, we add them to the ignore
# list so they remain at full precision.
recipe = GPTQModifier(
    targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"]
)

# Apply algorithms.
# Due to the large size of DeepSeekV3, we specify sequential targets such that
# only one MLP is loaded into GPU memory at a time.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"],
)

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
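Not part of this commit, but as a quick sanity check after compression, the saved checkpoint should load directly in vLLM. A minimal sketch, assuming a recent vLLM build with compressed-tensors support and enough GPU memory for the W4A16 weights (the `tensor_parallel_size` value is illustrative):

```python
from vllm import LLM, SamplingParams

# Path produced by the example above: "DeepSeek-R1-0528-BF16-W4A16-G128"
llm = LLM(model="DeepSeek-R1-0528-BF16-W4A16-G128", tensor_parallel_size=8)
outputs = llm.generate(["What is quantization?"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```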
src/llmcompressor/modeling/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
# flake8: noqa

from .prepare import *
src/llmcompressor/modeling/deepseek_v3.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
import torch
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE


class DeepseekV3MoECalibrate(torch.nn.Module):
    """
    Patched DeepseekV3MoE which sends all tokens to all experts for calibration
    """

    def __init__(self, config, experts, gate, shared_experts):
        super().__init__()
        self.config = config
        self.experts = experts
        self.gate = gate
        self.shared_experts = shared_experts

    def forward(self, hidden_states):
        residuals = hidden_states
        orig_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # Begin MoE
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        expert_mask = torch.nn.functional.one_hot(
            topk_indices, num_classes=len(self.experts)
        )
        expert_mask = expert_mask.permute(2, 0, 1)

        for expert_idx in range(len(self.experts)):
            expert = self.experts[expert_idx]
            mask = expert_mask[expert_idx]
            token_indices, weight_indices = torch.where(mask)

            expert_weights = topk_weights[token_indices, weight_indices]
            # Calibrate the expert on every token, not just the tokens routed to it
            expert_input = hidden_states
            expert_output = expert(expert_input)
            # Only the top-k routed outputs are accumulated, so the block's output
            # still matches what the unpatched DeepseekV3MoE would produce
            weighted_output = expert_output[token_indices] * expert_weights.unsqueeze(-1)

            if token_indices.numel() > 0:
                final_hidden_states.index_add_(0, token_indices, weighted_output)
        # End MoE

        hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states


def replace(module: DeepseekV3MoE) -> DeepseekV3MoECalibrate:
    return DeepseekV3MoECalibrate(
        module.config, module.experts, module.gate, module.shared_experts
    )
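The design point, restating the commit description: every expert's projections run on the full token stream, so each expert receives complete calibration statistics, while only the top-k gated contributions are accumulated, so the block's output matches the unpatched DeepseekV3MoE. The `replace` helper can also be applied by hand; the sketch below is an illustration only (it assumes `model` is an already-loaded DeepSeek-V3/R1 model, and the `.mlp` attribute path follows the Hugging Face DeepSeek-V3 architecture rather than anything added in this commit):

```python
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE

from llmcompressor.modeling.deepseek_v3 import replace

# Illustrative manual swap; prepare_for_calibration (next file) does this walk for you.
for layer in model.model.layers:
    if isinstance(layer.mlp, DeepseekV3MoE):
        layer.mlp = replace(layer.mlp)  # now calibrates all experts on all tokens
```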

src/llmcompressor/modeling/prepare.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
from compressed_tensors.utils import replace_module
from transformers import PreTrainedModel

from llmcompressor.modeling.deepseek_v3 import replace as replace_DeepseekV3MoE

__all__ = ["prepare_for_calibration"]

replacements = {
    "DeepseekV3MoE": replace_DeepseekV3MoE,
}


def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
    for name, module in model.named_modules():
        cls_name = module.__class__.__name__
        if cls_name in replacements:
            new_module = replacements[cls_name](module)
            replace_module(model, name, new_module)

    return model
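The module-level `replacements` dict is the extension point hinted at in the PR description ("in the future we can make this method more configurable"). A hypothetical sketch of the pattern, using an invented `MyMoE` class name that does not exist in llm-compressor, only to show how another MoE architecture could be wired in:

```python
import torch

from llmcompressor.modeling.prepare import replacements


class MyMoECalibrate(torch.nn.Module):
    """Toy calibration wrapper for a hypothetical 'MyMoE' module (illustrative only)."""

    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.wrapped = module

    def forward(self, hidden_states):
        # A real wrapper would run all tokens through all experts, as DeepseekV3MoECalibrate does
        return self.wrapped(hidden_states)


# prepare_for_calibration keys replacements by class name, so registering "MyMoE"
# here is enough for the named_modules() walk above to pick the wrapper up
replacements["MyMoE"] = MyMoECalibrate
```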
