
Commit b30eade

deepseekv3
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 6800f81 commit b30eade

6 files changed, +195 -0 lines changed
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling import prepare_for_quantization
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# Select model and load it.
model_id = "RedHatAI/DeepSeek-V3-BF16"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = prepare_for_quantization(model)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# * quantize the weights to 4 bits with GPTQ, using a group size of 128
recipe = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=["lm_head"],
    sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"],
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Save the compressed model to disk.
SAVE_DIR = model_id.split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

# Load the model back after saving.
model = AutoModelForCausalLM.from_pretrained(SAVE_DIR, device_map="auto")

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 10 additions & 0 deletions
@@ -2,6 +2,8 @@
 from datetime import datetime
 from typing import TYPE_CHECKING, List, Optional, Union
 
+import torch
+from compressed_tensors.utils import offloaded_dispatch
 from loguru import logger
 from torch.utils.data import DataLoader
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin
@@ -127,6 +129,14 @@ def __init__(
         # initialize the model and processor
         pre_process(model_args)
 
+        # offload to cpu if possible
+        if "cuda" in str(model_args.oneshot_device) and torch.cuda.is_available():
+            offloaded_dispatch(
+                model_args.model, execution_device=model_args.oneshot_device
+            )
+        else:
+            logger.warning("CUDA is not available! Compressing model on CPU instead")
+
         # Set instance attributes
         self.model = self.model_args.model
         self.processor = self.model_args.processor
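For context, a minimal standalone sketch of the dispatch pattern this hunk introduces, using only the call shown above (the model-loading line is an assumption borrowed from the example script; it is not part of this diff):

import torch
from compressed_tensors.utils import offloaded_dispatch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("RedHatAI/DeepSeek-V3-BF16", torch_dtype="auto")

if torch.cuda.is_available():
    # Weights remain offloaded and are moved onto the execution device as needed.
    offloaded_dispatch(model, execution_device=torch.device("cuda:0"))
else:
    # Otherwise calibration runs on CPU, as the warning above indicates.
    pass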
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
# flake8: noqa

from .prepare import *
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
import torch
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE


class DeepseekV3MoECalibrate(torch.nn.Module):
    # Calibration-friendly MoE forward: every expert is executed on every pass,
    # and outputs are accumulated only for the tokens actually routed to each expert.
    def __init__(self, config, experts, gate, shared_experts):
        super().__init__()
        self.config = config
        self.experts = experts
        self.gate = gate
        self.shared_experts = shared_experts

    def forward(self, hidden_states):
        residuals = hidden_states
        orig_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # Begin MoE
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        expert_mask = torch.nn.functional.one_hot(
            topk_indices, num_classes=len(self.experts)
        )
        expert_mask = expert_mask.permute(2, 0, 1)

        for expert_idx in range(len(self.experts)):
            expert = self.experts[expert_idx]
            mask = expert_mask[expert_idx]
            token_indices, weight_indices = torch.where(mask)

            expert_weights = topk_weights[token_indices, weight_indices]
            expert_input = hidden_states[token_indices]
            expert_output = expert(expert_input)
            weighted_output = expert_output * expert_weights.unsqueeze(-1)

            if token_indices.numel() > 0:
                final_hidden_states.index_add_(0, token_indices, weighted_output)
        # End MoE

        hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states


def replace(module: DeepseekV3MoE) -> DeepseekV3MoECalibrate:
    return DeepseekV3MoECalibrate(
        module.config, module.experts, module.gate, module.shared_experts
    )
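A hedged illustration of what replace() does for a single decoder layer. The layer index and the model.model.layers[i].mlp attribute path are assumptions based on the Hugging Face DeepseekV3 layout; prepare_for_quantization in the next file applies the same swap across the whole model:

from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
from llmcompressor.modeling.deepseek_v3 import replace

# Assumed attribute path and layer index, for illustration only.
layer = model.model.layers[3]
if isinstance(layer.mlp, DeepseekV3MoE):
    layer.mlp = replace(layer.mlp)  # swap in the calibration-friendly forward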

src/llmcompressor/modeling/prepare.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
import torch
from transformers import PreTrainedModel
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE

from llmcompressor.modeling.deepseek_v3 import replace as replace_DeepseekV3MoE
from llmcompressor.utils.module import module_bfs

__all__ = ["prepare_for_quantization"]

replacements = {
    DeepseekV3MoE: replace_DeepseekV3MoE,
}


def prepare_for_quantization(model: PreTrainedModel) -> PreTrainedModel:
    def replace(module: torch.nn.Module) -> torch.nn.Module:
        if module.__class__ in replacements:
            return replacements[module.__class__](module)
        else:
            return module

    return module_bfs(model, replace, progress=True)
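For orientation, a minimal usage sketch mirroring the example script at the top of this commit: the model is rewritten before oneshot so that every DeepseekV3MoE block is swapped for its calibration variant.

from transformers import AutoModelForCausalLM
from llmcompressor.modeling import prepare_for_quantization

model = AutoModelForCausalLM.from_pretrained("RedHatAI/DeepSeek-V3-BF16", torch_dtype="auto")
model = prepare_for_quantization(model)  # swaps DeepseekV3MoE -> DeepseekV3MoECalibrate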

src/llmcompressor/utils/module.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
from typing import Callable, Union

import torch
import tqdm

__all__ = ["module_bfs"]


def module_bfs(
    module: torch.nn.Module,
    func: Callable[[torch.nn.Module], torch.nn.Module],
    pre: bool = True,
    progress: Union[bool, tqdm.tqdm] = False,
) -> torch.nn.Module:
    if progress is True:
        total = len(list(module.modules()))
        progress = tqdm.tqdm(total=total)
    if pre:
        module = func(module)
    for name, child in list(module.named_children()):
        module.add_module(name, module_bfs(child, func, pre, progress))
    if not pre:
        module = func(module)
    if isinstance(progress, tqdm.tqdm):
        progress.update(1)

    return module
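A small usage sketch, not part of the commit: module_bfs walks the module tree, applies func to each module (the root before its children when pre=True), and re-attaches whatever func returns. The transform and toy model below are purely illustrative.

import torch
from llmcompressor.utils.module import module_bfs


def upcast_layernorms(module: torch.nn.Module) -> torch.nn.Module:
    # Illustrative transform: return a replacement for modules of interest,
    # pass everything else through unchanged.
    if isinstance(module, torch.nn.LayerNorm):
        return module.float()
    return module


toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
toy = module_bfs(toy, upcast_layernorms, progress=True)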
