-
Notifications
You must be signed in to change notification settings - Fork 181
[Calibration] Add MoE Calibration Context #1596
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+264
−20
Merged
Changes from 22 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
e3021ef
update
dsikka 7310fed
style; update name
dsikka fe6f316
fix
dsikka 04324d8
update
dsikka 423c94f
update
dsikka 6af3ffc
update entrypoint
dsikka 1a3dd30
update
dsikka 079c71f
clean-up
dsikka aee670f
update prepare
dsikka 3b9e2c2
add comment
dsikka a7af9ca
fix check
dsikka ea93089
PR comments
dsikka 945007c
fix typing
dsikka 528cdc8
rebase; fix
dsikka d7039a1
update
dsikka 9c96183
update
dsikka a5a42bd
Update qwen3_moe.py
dsikka 8fc840f
fix
dsikka 732b2ea
quality
dsikka 3e860d6
Alternate moe calib (#1645)
dsikka 331cca3
Merge branch 'main' into provide_moe_calibration_mode
dsikka fc08e41
Merge branch 'main' into provide_moe_calibration_mode
dsikka 4a951b2
Merge branch 'main' into provide_moe_calibration_mode
dsikka File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from datasets import load_dataset | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
from llmcompressor import oneshot | ||
from llmcompressor.modifiers.quantization import QuantizationModifier | ||
from llmcompressor.utils import dispatch_for_generation | ||
|
||
MODEL_ID = "Qwen/Qwen3-30B-A3B" | ||
|
||
# Load model. | ||
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") | ||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | ||
|
||
|
||
DATASET_ID = "HuggingFaceH4/ultrachat_200k" | ||
DATASET_SPLIT = "train_sft" | ||
|
||
# Select number of samples | ||
NUM_CALIBRATION_SAMPLES = 200 | ||
MAX_SEQUENCE_LENGTH = 2048 | ||
|
||
# Load dataset and preprocess. | ||
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") | ||
ds = ds.shuffle(seed=42) | ||
|
||
|
||
def preprocess(example): | ||
return { | ||
"text": tokenizer.apply_chat_template( | ||
example["messages"], | ||
tokenize=False, | ||
) | ||
} | ||
|
||
|
||
ds = ds.map(preprocess) | ||
|
||
|
||
# Tokenize inputs. | ||
def tokenize(sample): | ||
return tokenizer( | ||
sample["text"], | ||
padding=False, | ||
max_length=MAX_SEQUENCE_LENGTH, | ||
truncation=True, | ||
add_special_tokens=False, | ||
) | ||
|
||
|
||
ds = ds.map(tokenize, remove_columns=ds.column_names) | ||
|
||
# Configure the quantization algorithm and scheme. | ||
# In this case, we: | ||
# * quantize the weights to fp4 with per group 16 via ptq | ||
# * calibrate a global_scale for activations, which will be used to | ||
# quantize activations to fp4 on the fly | ||
recipe = QuantizationModifier( | ||
targets="Linear", scheme="NVFP4", ignore=["lm_head", "re:.*mlp.gate$"] | ||
) | ||
|
||
# Apply quantization. | ||
# We see `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock` | ||
# during calibration | ||
oneshot( | ||
model=model, | ||
dataset=ds, | ||
recipe=recipe, | ||
max_seq_length=MAX_SEQUENCE_LENGTH, | ||
num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
calibrate_moe_context=True, | ||
) | ||
|
||
|
||
print("\n\n") | ||
print("========== SAMPLE GENERATION ==============") | ||
dispatch_for_generation(model) | ||
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") | ||
output = model.generate(input_ids, max_new_tokens=100) | ||
print(tokenizer.decode(output[0])) | ||
print("==========================================\n\n") | ||
|
||
|
||
# Save to disk in compressed-tensors format. | ||
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4" | ||
model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
tokenizer.save_pretrained(SAVE_DIR) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# coding=utf-8 | ||
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. | ||
# All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
dsikka marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from transformers.models import Qwen3MoeConfig | ||
from transformers.models.qwen3_moe.modeling_qwen3_moe import ( | ||
Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock, | ||
) | ||
|
||
|
||
class Qwen3MoeSparseMoeBlock(torch.nn.Module): | ||
def __init__( | ||
self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock | ||
): | ||
super().__init__() | ||
self.num_experts = config.num_experts | ||
self.top_k = config.top_k | ||
self.norm_topk_prob = config.norm_topk_prob | ||
|
||
# gating | ||
self.gate = original.gate | ||
self.experts = original.experts | ||
|
||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||
batch_size, sequence_length, hidden_dim = hidden_states.shape | ||
hidden_states = hidden_states.view(-1, hidden_dim) | ||
# router_logits: (batch * sequence_length, n_experts) | ||
router_logits = self.gate(hidden_states) | ||
|
||
routing_weights = torch.nn.functional.softmax( | ||
router_logits, dim=1, dtype=torch.float | ||
) | ||
routing_weights, selected_experts = torch.topk( | ||
routing_weights, self.top_k, dim=-1 | ||
) | ||
if self.norm_topk_prob: # only diff with mixtral sparse moe block! | ||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True) | ||
# we cast back to the input dtype | ||
routing_weights = routing_weights.to(hidden_states.dtype) | ||
final_hidden_states = torch.zeros( | ||
(batch_size * sequence_length, hidden_dim), | ||
dtype=hidden_states.dtype, | ||
device=hidden_states.device, | ||
) | ||
|
||
# One hot encode the selected experts to create an expert mask | ||
# this will be used to easily index which expert is going to be sollicitated | ||
expert_mask = torch.nn.functional.one_hot( | ||
selected_experts, num_classes=self.num_experts | ||
).permute(2, 1, 0) | ||
|
||
for expert_idx in range(len(self.experts)): | ||
expert_layer = self.experts[expert_idx] | ||
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||
# Index the correct hidden states and compute the expert hidden state for | ||
# the current expert. We need to make sure to multiply the output hidden | ||
# states by `routing_weights` on the corresponding tokens (top-1 and top-2) | ||
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) | ||
expert_output = expert_layer(current_state) | ||
current_hidden_states = expert_output * routing_weights[top_x, idx, None] | ||
# However `index_add_` only support torch tensors for indexing so we'll use | ||
# the `top_x` tensor here. | ||
final_hidden_states.index_add_( | ||
0, top_x, current_hidden_states.to(hidden_states.dtype) | ||
) | ||
|
||
final_hidden_states = final_hidden_states.reshape( | ||
batch_size, sequence_length, hidden_dim | ||
) | ||
return final_hidden_states, router_logits | ||
|
||
|
||
def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock): | ||
return Qwen3MoeSparseMoeBlock(config=config, original=module) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.