
Commit 0047b8c

kylesayrs and dsikka authored
[Llama4] Update MoE Support for Llama4 + Add NVFP4 and W4A16 Examples (#1608)
Summary:
- Update the prepare method to no longer require a replace function; the original module is now passed in directly along with the text config
- Add Llama4 calibration support: swaps `Llama4TextMoe` modules with `SequentialLlama4TextMoe`
- Add Llama4 examples for NVFP4 and W4A16

Testing:
- Tested Llama4 NVFP4 end to end to produce `nm-testing/Llama-4-Scout-17B-16E-Instruct-NVFP4`

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
1 parent b457898 commit 0047b8c
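
For orientation, the reworked flow can be sketched as follows (an illustrative sketch, not code from this diff; it assumes `model` is an already-loaded Llama4ForConditionalGeneration):

from llmcompressor.modeling import prepare_for_calibration

# Each registered replacement is now invoked as replace(config=model.config, module=module)
# instead of the old replace(module); prepare_for_calibration wires this up for every
# matching MoE block so the new module can read sizes from the config and reuse the weights.
model = prepare_for_calibration(model)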

File tree

5 files changed: +268 −12 lines changed

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
import torch
from datasets import load_dataset
from transformers import Llama4ForConditionalGeneration, Llama4Processor

from llmcompressor import oneshot
from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modifiers.quantization import GPTQModifier

# Select model and load it.
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
processor = Llama4Processor.from_pretrained(model_id)
# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`.
# This change allows compatibility with vLLM.
# To apply your own custom module for experimentation, consider updating
# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
model = prepare_for_calibration(model)

DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 8192

ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")


def preprocess_function(example):
    messages = []
    for message in example["messages"]:
        messages.append(
            {
                "role": message["role"],
                "content": [{"type": "text", "text": message["content"]}],
            }
        )

    return processor.apply_chat_template(
        messages,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        tokenize=True,
        add_special_tokens=False,
        return_dict=True,
        add_generation_prompt=False,
    )


ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names)


def data_collator(batch):
    assert len(batch) == 1
    return {
        key: torch.tensor(value)
        if key != "pixel_values"
        else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
        for key, value in batch[0].items()
    }


# Configure the quantization algorithm to run.
recipe = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=[
        "re:.*lm_head",
        "re:.*self_attn",
        "re:.*router",
        "re:vision_model.*",
        "re:multi_modal_projector.*",
        "Llama4TextAttention",
    ],
)

# Apply algorithms.
# Due to the large size of Llama4, we specify sequential targets such that
# only one MLP is loaded into GPU memory at a time.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    data_collator=data_collator,
    sequential_targets=["Llama4TextMLP"],
)

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
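
The script's comment notes that the `SequentialLlama4TextMoe` swap keeps the saved checkpoint compatible with vLLM. A minimal loading sketch, not part of this commit (the local path and tensor-parallel size are assumptions), might look like:

# Hedged sketch: serve the compressed checkpoint produced above with vLLM.
from vllm import LLM, SamplingParams

llm = LLM(model="Llama-4-Scout-17B-16E-Instruct-W4A16-G128", tensor_parallel_size=4)  # assumed path and TP size
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)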
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
import torch
from datasets import load_dataset
from transformers import Llama4ForConditionalGeneration, Llama4Processor

from llmcompressor import oneshot
from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modifiers.quantization import QuantizationModifier

# Select model and load it.
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
processor = Llama4Processor.from_pretrained(model_id)
# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`.
# This change allows compatibility with vLLM.
# To apply your own custom module for experimentation, consider updating
# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
model = prepare_for_calibration(model)

DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 20
MAX_SEQUENCE_LENGTH = 8192

ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")


def preprocess_function(example):
    messages = []
    for message in example["messages"]:
        messages.append(
            {
                "role": message["role"],
                "content": [{"type": "text", "text": message["content"]}],
            }
        )

    return processor.apply_chat_template(
        messages,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        tokenize=True,
        add_special_tokens=False,
        return_dict=True,
        add_generation_prompt=False,
    )


ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names)


def data_collator(batch):
    assert len(batch) == 1
    return {
        key: torch.tensor(value)
        if key != "pixel_values"
        else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
        for key, value in batch[0].items()
    }


# Configure the quantization algorithm to run.
recipe = QuantizationModifier(
    targets="Linear",
    scheme="NVFP4",
    ignore=[
        "re:.*lm_head",
        "re:.*self_attn",
        "re:.*router",
        "re:vision_model.*",
        "re:multi_modal_projector.*",
        "Llama4TextAttention",
    ],
)

# Apply algorithms.
# Due to the large size of Llama4, we specify sequential targets such that
# only one MLP is loaded into GPU memory at a time.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    sequential_targets=["Llama4TextMLP"],
    data_collator=data_collator,
)


# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-NVFP4"
model.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)

src/llmcompressor/modeling/deepseek_v3.py

Lines changed: 10 additions & 9 deletions
@@ -1,20 +1,23 @@
 import torch
+from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
 from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
 
+__all__ = ["DeepseekV3MoECalibrate"]
+
 
 class DeepseekV3MoECalibrate(torch.nn.Module):
     """
     Patched DeepseekV3MoE which sends all tokens to all experts for calibration
     """
 
-    def __init__(self, config, experts, gate, shared_experts):
+    def __init__(self, config: DeepseekV3Config, original: DeepseekV3MoE):
         super().__init__()
         self.config = config
-        self.experts = experts
-        self.gate = gate
-        self.shared_experts = shared_experts
+        self.experts = original.experts
+        self.gate = original.gate
+        self.shared_experts = original.shared_experts
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         residuals = hidden_states
         orig_shape = hidden_states.shape
         topk_indices, topk_weights = self.gate(hidden_states)
@@ -46,7 +49,5 @@ def forward(self, hidden_states):
         return hidden_states
 
 
-def replace(module: DeepseekV3MoE) -> DeepseekV3MoECalibrate:
-    return DeepseekV3MoECalibrate(
-        module.config, module.experts, module.gate, module.shared_experts
-    )
+def replace(config: DeepseekV3Config, module: DeepseekV3MoE):
+    return DeepseekV3MoECalibrate(config=config, original=module)

src/llmcompressor/modeling/llama4.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
from typing import Tuple

import torch
from transformers.models import Llama4Config
from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
from transformers.models.llama4.modeling_llama4 import (
    Llama4TextExperts,
    Llama4TextMLP,
    Llama4TextMoe,
)

from llmcompressor.utils.dev import skip_weights_initialize

__all__ = ["SequentialLlama4TextMoe"]


class SequentialLlama4TextMoe(torch.nn.Module):
    def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_local_experts
        self.experts = SequentialLlama4TextExperts(config, original.experts)
        self.router = original.router
        self.shared_expert = original.shared_expert

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = self.router(hidden_states)

        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)

        router_scores = (
            torch.full_like(router_logits, float("-inf"))
            .scatter_(1, router_indices, router_top_value)
            .transpose(0, 1)
        )
        router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

        out = self.shared_expert(hidden_states)
        for i in range(self.num_experts):
            out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)

        return out, router_scores


class SequentialLlama4TextExperts(torch.nn.ModuleList):
    def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts):
        self.num_experts = original.gate_up_proj.shape[0]
        with skip_weights_initialize():
            super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])

        intermediate_size = original.down_proj.shape[1]

        for i in range(self.num_experts):
            gate_up = original.gate_up_proj[i]
            down = original.down_proj[i]

            gate_proj = gate_up[:, :intermediate_size]
            up_proj = gate_up[:, intermediate_size:]

            self[i].gate_proj.weight.data = gate_proj.t().clone().contiguous()
            self[i].up_proj.weight.data = up_proj.t().clone().contiguous()
            self[i].down_proj.weight.data = down.t().clone().contiguous()


def replace(config: Llama4Config, module: Llama4TextMoe):
    return SequentialLlama4TextMoe(config=config.get_text_config(), original=module)
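
For readers following the expert unpacking in `SequentialLlama4TextExperts`, a small shape sketch may help; the fused shapes are assumptions inferred from the slicing and transposes above, not values asserted by this commit, and `original`/`experts` are assumed to be the fused and converted modules respectively.

# Assumed fused layout: gate_up_proj [E, hidden, 2*intermediate], down_proj [E, intermediate, hidden].
hidden_size = original.gate_up_proj.shape[1]
intermediate_size = original.down_proj.shape[1]
# After splitting and transposing, each per-expert Linear weight should follow
# the [out_features, in_features] convention of torch.nn.Linear.
assert experts[0].gate_proj.weight.shape == (intermediate_size, hidden_size)
assert experts[0].up_proj.weight.shape == (intermediate_size, hidden_size)
assert experts[0].down_proj.weight.shape == (hidden_size, intermediate_size)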

src/llmcompressor/modeling/prepare.py

Lines changed: 5 additions & 3 deletions
@@ -1,20 +1,22 @@
 from compressed_tensors.utils import replace_module
 from transformers import PreTrainedModel
 
-from llmcompressor.modeling.deepseek_v3 import replace as replace_DeepseekV3MoE
+from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
+from llmcompressor.modeling.llama4 import replace as replace_llama4
 
 __all__ = ["prepare_for_calibration"]
 
 replacements = {
-    "DeepseekV3MoE": replace_DeepseekV3MoE,
+    "DeepseekV3MoE": replace_deepseekv3,
+    "Llama4TextMoe": replace_llama4,
 }
 
 
 def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
     for name, module in model.named_modules():
         cls_name = module.__class__.__name__
         if cls_name in replacements:
-            new_module = replacements[cls_name](module)
+            new_module = replacements[cls_name](config=model.config, module=module)
             replace_module(model, name, new_module)
 
     return model
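
As a quick illustration of the effect of the updated `prepare_for_calibration` (a hedged sketch, not part of the diff; `model` is assumed to be a loaded Llama4 or DeepseekV3 model):

from llmcompressor.modeling import prepare_for_calibration

# Collect module class names before and after the swap; only the MoE blocks
# listed in `replacements` should change, everything else keeps its class and weights.
before = {m.__class__.__name__ for m in model.modules()}
model = prepare_for_calibration(model)
after = {m.__class__.__name__ for m in model.modules()}
print("Llama4TextMoe" in before, "SequentialLlama4TextMoe" in after)  # expected: True True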
