
[Calibration] Add MoE Calibration Context #1596


Merged: 23 commits, Jul 22, 2025
4 changes: 2 additions & 2 deletions examples/multimodal_vision/llama4_example.py
@@ -3,7 +3,7 @@
from transformers import Llama4ForConditionalGeneration, Llama4Processor

from llmcompressor import oneshot
from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import GPTQModifier

# Select model and load it.
@@ -14,7 +14,7 @@
# This change allows compatibility with vllm.
# To apply your own custom module for experimentation, consider updating
# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
model = prepare_for_calibration(model)
model = replace_modules_for_calibration(model)

DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 512
4 changes: 2 additions & 2 deletions examples/quantization_w4a4_fp4/llama4_example.py
@@ -3,7 +3,7 @@
from transformers import Llama4ForConditionalGeneration, Llama4Processor

from llmcompressor import oneshot
from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import QuantizationModifier

# Select model and load it.
@@ -14,7 +14,7 @@
# This change allows compatibility with vllm.
# To apply your own custom module for experimentation, consider updating
# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
model = prepare_for_calibration(model)
model = replace_modules_for_calibration(model)

DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 20
86 changes: 86 additions & 0 deletions examples/quantization_w4a4_fp4/qwen_30b_a3b.py
@@ -0,0 +1,86 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

MODEL_ID = "Qwen/Qwen3-30B-A3B"

# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples
NUM_CALIBRATION_SAMPLES = 200
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
return {
"text": tokenizer.apply_chat_template(
example["messages"],
tokenize=False,
)
}


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
return tokenizer(
sample["text"],
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
add_special_tokens=False,
)


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights to fp4 with a group size of 16 via PTQ
# * calibrate a global_scale for activations, which will be used to
# quantize activations to fp4 on the fly
recipe = QuantizationModifier(
targets="Linear", scheme="NVFP4", ignore=["lm_head", "re:.*mlp.gate$"]
)

# Apply quantization.
# We set `calibrate_moe_context` to True to temporarily update all
# `Qwen3MoeSparseMoeBlock` modules during calibration
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
calibrate_moe_context=True,
)


print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")


# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
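As a quick sanity check of the export (not part of this PR), the saved directory can be loaded back, for example with vLLM, assuming an installed vLLM build that supports compressed-tensors NVFP4 checkpoints:

```python
# Hedged sketch: verify the saved NVFP4 checkpoint loads and generates.
# Assumes a vLLM build with compressed-tensors / NVFP4 support.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen3-30B-A3B-NVFP4")  # SAVE_DIR from the example above
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```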
4 changes: 2 additions & 2 deletions examples/quantizing_moe/deepseek_r1_example.py
@@ -1,7 +1,7 @@
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

@@ -20,7 +20,7 @@
model_id, torch_dtype="auto", config=config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = prepare_for_calibration(model)
model = replace_modules_for_calibration(model)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
10 changes: 10 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
@@ -117,6 +117,16 @@ class DatasetArguments(CustomDatasetArguments):
default=512,
metadata={"help": "Number of samples to use for one-shot calibration"},
)
calibrate_moe_context: bool = field(
default=False,
metadata={
"help": "If during calibration, the MoE context should be enabled "
"for the given model. This usually involves updating all MoE modules "
"in the model for the duration of calibration. See moe_context under "
"modeling/prepare.py for a list of supported MoEs and their updated "
"module definitions"
},
)
shuffle_calibration_samples: Optional[bool] = field(
default=True,
metadata={
7 changes: 6 additions & 1 deletion src/llmcompressor/entrypoints/oneshot.py
@@ -189,7 +189,11 @@ def apply_recipe_modifiers(
user_pipeline = self.dataset_args.pipeline
modifiers = session.lifecycle.recipe.modifiers
pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline)
pipeline(self.model, calibration_dataloader, self.dataset_args)
pipeline(
self.model,
calibration_dataloader,
self.dataset_args,
)

session.finalize()

@@ -227,6 +231,7 @@ def oneshot(
overwrite_cache: bool = False,
preprocessing_num_workers: Optional[int] = None,
min_tokens_per_module: Optional[float] = None,
calibrate_moe_context: bool = False,
# Miscellaneous arguments
output_dir: Optional[str] = None,
log_dir: Optional[str] = "sparse_logs",
10 changes: 5 additions & 5 deletions src/llmcompressor/modeling/deepseek_v3.py
@@ -1,16 +1,16 @@
import torch
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE

__all__ = ["DeepseekV3MoECalibrate"]
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
DeepseekV3MoE as OriginalDeepseekV3MoE,
)


class DeepseekV3MoECalibrate(torch.nn.Module):
"""
Patched DeepseekV3MoE which sends all tokens to all experts for calibration
"""

def __init__(self, config: DeepseekV3Config, original: DeepseekV3MoE):
def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE):
super().__init__()
self.config = config
self.experts = original.experts
@@ -49,5 +49,5 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


def replace(config: DeepseekV3Config, module: DeepseekV3MoE):
def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE):
return DeepseekV3MoECalibrate(config=config, original=module)
2 changes: 0 additions & 2 deletions src/llmcompressor/modeling/llama4.py
@@ -11,8 +11,6 @@

from llmcompressor.utils.dev import skip_weights_initialize

__all__ = ["SequentialLlama4TextMoe"]


class SequentialLlama4TextMoe(torch.nn.Module):
def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):
Expand Down
37 changes: 35 additions & 2 deletions src/llmcompressor/modeling/prepare.py
@@ -3,20 +3,53 @@

from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
from llmcompressor.modeling.llama4 import replace as replace_llama4
from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
from llmcompressor.utils.helpers import patch_attr

__all__ = ["prepare_for_calibration"]
__all__ = ["replace_modules_for_calibration"]

# ---------------------- module replacements; permanent -------------------------
replacements = {
"DeepseekV3MoE": replace_deepseekv3,
"Llama4TextMoe": replace_llama4,
}


def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
for name, module in model.named_modules():
cls_name = module.__class__.__name__
if cls_name in replacements:
new_module = replacements[cls_name](config=model.config, module=module)
replace_module(model, name, new_module)

return model


# ------------------- module replacements; during calibration --------------------


def update_qwen3_moe(model, stack):
for module in model.modules():
cls_name = module.__class__.__name__
if cls_name == "Qwen3MoeDecoderLayer":
# Optionally update the model.config to pass in other arguments
stack.enter_context(
patch_attr(
module,
"mlp",
replace_Qwen3MoE(config=model.config, module=module.mlp),
)
)


moe_context = {
"Qwen3MoeForCausalLM": update_qwen3_moe,
}


def moe_calibration_context(model: PreTrainedModel, stack):
# Temporarily updates the MoE modules within the context.
# Once the context exits, parameter updates persist.
cls_name = model.__class__.__name__
if cls_name in moe_context:
moe_context.get(cls_name)(model, stack)
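The `moe_context` registry above maps a model class name to a function that temporarily patches its MoE blocks for the duration of calibration. As a hypothetical sketch of how another architecture could hook into the same mechanism (the `MyMoe*` names and the trivial replacement factory below are placeholders, not part of this PR):

```python
from llmcompressor.modeling.prepare import moe_context
from llmcompressor.utils.helpers import patch_attr


def replace_my_moe(config, module):
    # Placeholder factory: return a calibration-friendly MoE block
    # (analogous to Qwen3MoeSparseMoeBlock defined in qwen3_moe.py).
    return module


def update_my_moe_model(model, stack):
    for module in model.modules():
        if module.__class__.__name__ == "MyMoeDecoderLayer":
            # patch_attr swaps the block for calibration and restores the
            # original attribute when the ExitStack unwinds.
            stack.enter_context(
                patch_attr(module, "mlp", replace_my_moe(model.config, module.mlp))
            )


moe_context["MyMoeForCausalLM"] = update_my_moe_model
```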
87 changes: 87 additions & 0 deletions src/llmcompressor/modeling/qwen3_moe.py
@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers.models import Qwen3MoeConfig
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
)


class Qwen3MoeSparseMoeBlock(torch.nn.Module):
def __init__(
self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock
):
super().__init__()
self.num_experts = config.num_experts
self.top_k = config.top_k
self.norm_topk_prob = config.norm_topk_prob

# gating
self.gate = original.gate
self.experts = original.experts

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

routing_weights = torch.nn.functional.softmax(
router_logits, dim=1, dtype=torch.float
)
routing_weights, selected_experts = torch.topk(
routing_weights, self.top_k, dim=-1
)
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_experts
).permute(2, 1, 0)

for expert_idx in range(len(self.experts)):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
expert_output = expert_layer(current_state)
current_hidden_states = expert_output * routing_weights[top_x, idx, None]
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(hidden_states.dtype)
)

final_hidden_states = final_hidden_states.reshape(
batch_size, sequence_length, hidden_dim
)
return final_hidden_states, router_logits


def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock):
return Qwen3MoeSparseMoeBlock(config=config, original=module)
9 changes: 8 additions & 1 deletion src/llmcompressor/pipelines/basic/pipeline.py
@@ -1,3 +1,4 @@
import contextlib
from typing import TYPE_CHECKING, Union

import torch
@@ -6,6 +7,7 @@
from torch.utils.data.dataloader import DataLoader

from llmcompressor.core import LifecycleCallbacks
from llmcompressor.modeling.prepare import moe_calibration_context
from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
from llmcompressor.pipelines.registry import CalibrationPipeline
from llmcompressor.pytorch.utils.helpers import tensors_to_device
@@ -42,7 +44,12 @@ def __call__(

LifecycleCallbacks.calibration_epoch_start()

with calibration_forward_context(model):
with contextlib.ExitStack() as stack:
stack.enter_context(calibration_forward_context(model))

if dataset_args is not None and dataset_args.calibrate_moe_context:
moe_calibration_context(model, stack)

for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
batch = apply_pad_mask_to_batch(batch)
batch = tensors_to_device(batch, model_device)
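In the basic pipeline, both contexts are pushed onto one `contextlib.ExitStack`, so the MoE module swaps are rolled back together with the calibration forward context when the `with` block exits, while parameter updates made during calibration persist on the model. A minimal sketch of the same pattern in isolation, assuming `model` is a loaded Qwen3-MoE model and `batch` an already-tokenized input dict:

```python
import contextlib

from llmcompressor.modeling.prepare import moe_calibration_context

# Minimal sketch: run one forward pass with the MoE calibration context active.
# Assumes `model` is a loaded Qwen3-MoE model and `batch` a tokenized input dict.
with contextlib.ExitStack() as stack:
    moe_calibration_context(model, stack)  # temporarily swaps Qwen3MoeSparseMoeBlock
    model(**batch)  # every expert participates in the forward pass
# On exit the original MoE modules are restored; calibrated parameters persist.
```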