From e3021ef73c1fcd6047a2e46e642b36c108dfe8fe Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 25 Jun 2025 19:01:07 +0000 Subject: [PATCH 01/20] update --- .../quantization_w4a4_fp4/qwen_30b_a2b.py | 86 +++++++++++++++++++ src/llmcompressor/entrypoints/oneshot.py | 3 +- src/llmcompressor/modeling/prepare.py | 17 +++- .../pipelines/independent/pipeline.py | 3 +- .../pipelines/sequential/pipeline.py | 12 ++- 5 files changed, 116 insertions(+), 5 deletions(-) create mode 100644 examples/quantization_w4a4_fp4/qwen_30b_a2b.py diff --git a/examples/quantization_w4a4_fp4/qwen_30b_a2b.py b/examples/quantization_w4a4_fp4/qwen_30b_a2b.py new file mode 100644 index 000000000..02694f532 --- /dev/null +++ b/examples/quantization_w4a4_fp4/qwen_30b_a2b.py @@ -0,0 +1,86 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.utils import dispatch_for_generation + +MODEL_ID = "Qwen/Qwen3-30B-A3B" + +# Load model. +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, torch_dtype="auto" +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + + +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 5 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm and scheme. +# In this case, we: +# * quantize the weights to fp4 with per group 16 via ptq +# * calibrate a global_scale for activations, which will be used to +# quantize activations to fp4 on the fly +recipe = QuantizationModifier( + targets="Linear", scheme="NVFP4", ignore=["lm_head", "re:.*mlp.gate$"] +) + +# Apply quantization. +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, +) + + +print("\n\n") +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + + +# Save to disk in compressed-tensors format. 
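+# (With MODEL_ID = "Qwen/Qwen3-30B-A3B", SAVE_DIR below resolves to "Qwen3-30B-A3B-NVFP4".)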
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index 9219f21fb..db8980b9c 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -189,7 +189,8 @@ def apply_recipe_modifiers( user_pipeline = self.dataset_args.pipeline modifiers = session.lifecycle.recipe.modifiers pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline) - pipeline(self.model, calibration_dataloader, self.dataset_args) + # ToDo: wrap moe_calibrate_all_experts in some set of args + pipeline(self.model, calibration_dataloader, self.dataset_args, calibrate_moe_context=True) session.finalize() diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 0ef627db4..bd9673e51 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -3,6 +3,7 @@ from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3 from llmcompressor.modeling.llama4 import replace as replace_llama4 +from llmcompressor.utils.helpers import patch_attr __all__ = ["prepare_for_calibration"] @@ -11,7 +12,6 @@ "Llama4TextMoe": replace_llama4, } - def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel: for name, module in model.named_modules(): cls_name = module.__class__.__name__ @@ -20,3 +20,18 @@ def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel: replace_module(model, name, new_module) return model + +def update_qwen3_moe(model, stack): + for module in model.model.layers: + stack.enter_context( + patch_attr(module.mlp, "top_k", model.config.num_experts) + ) + + +moe_context = { + "Qwen3MoeForCausalLM": update_qwen3_moe, +} + +def calibrate_moe_context(model: PreTrainedModel, stack): + cls_name = model.__class__.__name__ + moe_context.get(cls_name)(model, stack) diff --git a/src/llmcompressor/pipelines/independent/pipeline.py b/src/llmcompressor/pipelines/independent/pipeline.py index c204e012e..8e9eb0a35 100644 --- a/src/llmcompressor/pipelines/independent/pipeline.py +++ b/src/llmcompressor/pipelines/independent/pipeline.py @@ -21,6 +21,7 @@ def __call__( model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments", + calibrate_moe_context: bool = False ): """ Data pipeline where each modifier is assigned its own calibration epoch and data @@ -42,6 +43,6 @@ def __call__( pipeline_name = pipeline.__class__.__name__ _logger.info(f"Inferred `{pipeline_name}` for `{mod_type}`") - pipeline(model, dataloader, dataset_args) + pipeline(model, dataloader, dataset_args, calibrate_moe_context) # restore modifiers on exit so model can be compressed based on recipe diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index c9af8f0fc..48dbf7cdf 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -1,5 +1,5 @@ from typing import TYPE_CHECKING - +import contextlib import torch from compressed_tensors.utils import disable_offloading, get_execution_device from torch.utils.data.dataloader import DataLoader @@ -15,6 +15,7 @@ trace_subgraphs, ) from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context +from llmcompressor.modeling.prepare import calibrate_moe_context if TYPE_CHECKING: from 
llmcompressor.args.dataset_arguments import DatasetArguments @@ -26,7 +27,7 @@ class SequentialPipeline(CalibrationPipeline): @staticmethod def __call__( - model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments" + model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments", calibrate_moe_context: bool = False, ): """ Run a sequential data pipeline according to the following steps: @@ -69,6 +70,13 @@ def __call__( LifecycleCallbacks.calibration_epoch_start() + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(DisableQuantization(model)) + + if calibrate_moe_context: + stack.enter_context((calibrate_moe_context(model))) + with calibration_forward_context(model), DisableQuantization(model): # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader, model_device) From 7310fed2fd797bd146d759fc98620f60e80b6230 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 25 Jun 2025 19:53:09 +0000 Subject: [PATCH 02/20] style; update name --- examples/quantization_w4a4_fp4/qwen_30b_a2b.py | 4 +--- src/llmcompressor/entrypoints/oneshot.py | 7 ++++++- src/llmcompressor/modeling/prepare.py | 8 ++++---- src/llmcompressor/pipelines/independent/pipeline.py | 2 +- src/llmcompressor/pipelines/sequential/pipeline.py | 12 ++++++++---- 5 files changed, 20 insertions(+), 13 deletions(-) diff --git a/examples/quantization_w4a4_fp4/qwen_30b_a2b.py b/examples/quantization_w4a4_fp4/qwen_30b_a2b.py index 02694f532..6ba63fb4d 100644 --- a/examples/quantization_w4a4_fp4/qwen_30b_a2b.py +++ b/examples/quantization_w4a4_fp4/qwen_30b_a2b.py @@ -8,9 +8,7 @@ MODEL_ID = "Qwen/Qwen3-30B-A3B" # Load model. -model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, torch_dtype="auto" -) +model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index db8980b9c..60428ad91 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -190,7 +190,12 @@ def apply_recipe_modifiers( modifiers = session.lifecycle.recipe.modifiers pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline) # ToDo: wrap moe_calibrate_all_experts in some set of args - pipeline(self.model, calibration_dataloader, self.dataset_args, calibrate_moe_context=True) + pipeline( + self.model, + calibration_dataloader, + self.dataset_args, + calibrate_moe_context=True, + ) session.finalize() diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index bd9673e51..8f09e0e17 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -12,6 +12,7 @@ "Llama4TextMoe": replace_llama4, } + def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel: for name, module in model.named_modules(): cls_name = module.__class__.__name__ @@ -21,17 +22,16 @@ def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel: return model + def update_qwen3_moe(model, stack): for module in model.model.layers: - stack.enter_context( - patch_attr(module.mlp, "top_k", model.config.num_experts) - ) + stack.enter_context(patch_attr(module.mlp, "top_k", model.config.num_experts)) moe_context = { "Qwen3MoeForCausalLM": update_qwen3_moe, } -def calibrate_moe_context(model: PreTrainedModel, stack): +def moe_calibration_context(model: 
PreTrainedModel, stack): cls_name = model.__class__.__name__ moe_context.get(cls_name)(model, stack) diff --git a/src/llmcompressor/pipelines/independent/pipeline.py b/src/llmcompressor/pipelines/independent/pipeline.py index 8e9eb0a35..4ee19d33b 100644 --- a/src/llmcompressor/pipelines/independent/pipeline.py +++ b/src/llmcompressor/pipelines/independent/pipeline.py @@ -21,7 +21,7 @@ def __call__( model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments", - calibrate_moe_context: bool = False + calibrate_moe_context: bool = False, ): """ Data pipeline where each modifier is assigned its own calibration epoch and data diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 48dbf7cdf..d49f27460 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -1,11 +1,13 @@ -from typing import TYPE_CHECKING import contextlib +from typing import TYPE_CHECKING + import torch from compressed_tensors.utils import disable_offloading, get_execution_device from torch.utils.data.dataloader import DataLoader from tqdm import tqdm from llmcompressor.core import LifecycleCallbacks, active_session +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.pipelines.registry import CalibrationPipeline @@ -15,7 +17,6 @@ trace_subgraphs, ) from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context -from llmcompressor.modeling.prepare import calibrate_moe_context if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments @@ -27,7 +28,10 @@ class SequentialPipeline(CalibrationPipeline): @staticmethod def __call__( - model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments", calibrate_moe_context: bool = False, + model: torch.nn.Module, + dataloader: DataLoader, + dataset_args: "DatasetArguments", + calibrate_moe_context: bool = False, ): """ Run a sequential data pipeline according to the following steps: @@ -75,7 +79,7 @@ def __call__( stack.enter_context(DisableQuantization(model)) if calibrate_moe_context: - stack.enter_context((calibrate_moe_context(model))) + stack.enter_context((moe_calibration_context(model))) with calibration_forward_context(model), DisableQuantization(model): # prepare intermediates cache From fe6f316cc483d297fc9bc201a9165131e0514962 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 25 Jun 2025 20:53:59 +0000 Subject: [PATCH 03/20] fix --- src/llmcompressor/pipelines/sequential/pipeline.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index d49f27460..6ecc1ee50 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -79,9 +79,8 @@ def __call__( stack.enter_context(DisableQuantization(model)) if calibrate_moe_context: - stack.enter_context((moe_calibration_context(model))) + moe_calibration_context(model, stack) - with calibration_forward_context(model), DisableQuantization(model): # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader, model_device) From 04324d85a2616838701f5e0064d336244845c6ce Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 2 Jul 2025 14:32:21 -0400 Subject: [PATCH 
04/20] update

---
 src/llmcompressor/modeling/prepare.py   |  7 ++-
 src/llmcompressor/modeling/qwen3_moe.py | 77 +++++++++++++++++++++++++
 2 files changed, 82 insertions(+), 2 deletions(-)
 create mode 100644 src/llmcompressor/modeling/qwen3_moe.py

diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py
index 8f09e0e17..b6e0e838b 100644
--- a/src/llmcompressor/modeling/prepare.py
+++ b/src/llmcompressor/modeling/prepare.py
@@ -3,6 +3,7 @@
 
 from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
 from llmcompressor.modeling.llama4 import replace as replace_llama4
+from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
 from llmcompressor.utils.helpers import patch_attr
 
 __all__ = ["prepare_for_calibration"]
@@ -24,8 +25,10 @@ def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
 
 
 def update_qwen3_moe(model, stack):
-    for module in model.model.layers:
-        stack.enter_context(patch_attr(module.mlp, "top_k", model.config.num_experts))
+    for _, module in model.named_modules():
+        cls_name = module.__class__.__name__
+        if cls_name == "Qwen3MoeDecoderLayer":
+            stack.enter_context(patch_attr(module, "mlp", replace_Qwen3MoE()))
 
 
 moe_context = {
diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py
new file mode 100644
index 000000000..80332a22e
--- /dev/null
+++ b/src/llmcompressor/modeling/qwen3_moe.py
@@ -0,0 +1,77 @@
+import torch
+from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeMLP
+
+
+class Qwen3MoeSparseMoeBlock(torch.nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts
+        self.norm_topk_prob = config.norm_topk_prob
+
+        # gating
+        self.gate = torch.nn.Linear(config.hidden_size, config.num_experts, bias=False)
+        self.experts = torch.nn.ModuleList(
+            [
+                Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size)
+                for _ in range(self.num_experts)
+            ]
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = torch.functional.softmax(
+            router_logits, dim=1, dtype=torch.float
+        )
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be solicited
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+        for expert_idx in expert_hitted:
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. 
We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = (
+                expert_layer(current_state) * routing_weights[top_x, idx, None]
+            )
+
+            # However `index_add_` only support torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+        return final_hidden_states, router_logits
+
+
+def replace(module):
+    return Qwen3MoeSparseMoeBlock(module.config)

From 423c94fa2588613a8e74a2c0e160c16a4490e64b Mon Sep 17 00:00:00 2001
From: Dipika
Date: Wed, 2 Jul 2025 22:15:41 -0400
Subject: [PATCH 05/20] update

---
 .../quantization_w4a4_fp4/qwen_30b_a2b.py     |  5 ++++-
 .../quantizing_moe/deepseek_r1_example.py     |  4 ++--
 src/llmcompressor/args/dataset_arguments.py   | 10 ++++++++++
 src/llmcompressor/entrypoints/oneshot.py      |  1 -
 src/llmcompressor/modeling/prepare.py         | 11 ++++++----
 src/llmcompressor/modeling/qwen3_moe.py       | 20 ++++++-------------
 src/llmcompressor/pipelines/basic/pipeline.py |  9 ++++++++-
 .../pipelines/independent/pipeline.py         |  3 +--
 .../pipelines/layer_sequential/pipeline.py    | 10 +++++++++-
 .../pipelines/sequential/pipeline.py          |  2 +-
 10 files changed, 48 insertions(+), 27 deletions(-)

diff --git a/examples/quantization_w4a4_fp4/qwen_30b_a2b.py b/examples/quantization_w4a4_fp4/qwen_30b_a2b.py
index 6ba63fb4d..a0b71b131 100644
--- a/examples/quantization_w4a4_fp4/qwen_30b_a2b.py
+++ b/examples/quantization_w4a4_fp4/qwen_30b_a2b.py
@@ -17,7 +17,7 @@
 
 # Select number of samples. 512 samples is a good place to start.
 # Increasing the number of samples can improve accuracy.
-NUM_CALIBRATION_SAMPLES = 5
+NUM_CALIBRATION_SAMPLES = 20
 MAX_SEQUENCE_LENGTH = 2048
 
 # Load dataset and preprocess.
@@ -60,12 +60,15 @@
 )
 
 # Apply quantization.
+# We set `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock`
+# during calibration
 oneshot(
     model=model,
     dataset=ds,
     recipe=recipe,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    calibrate_moe_context=True,
 )
 
 
diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py
index 6cb6937e8..0944e7b79 100644
--- a/examples/quantizing_moe/deepseek_r1_example.py
+++ b/examples/quantizing_moe/deepseek_r1_example.py
@@ -1,7 +1,7 @@
 from datasets import load_dataset
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
-from llmcompressor.modeling import prepare_for_calibration
+from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 
@@ -20,7 +20,7 @@
     model_id, torch_dtype="auto", config=config
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = prepare_for_calibration(model)
+model = replace_modules_for_calibration(model)
 
 # Select calibration dataset. 
DATASET_ID = "HuggingFaceH4/ultrachat_200k" diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index 3f1bdcf6a..4fef6efb8 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -117,6 +117,16 @@ class DatasetArguments(CustomDatasetArguments): default=512, metadata={"help": "Number of samples to use for one-shot calibration"}, ) + calibrate_moe_context: bool = field( + default=False, + metadata={ + "help": "If during calibration, the MoE context should be enabled " + "for the given model. This usually involves updating all MoE modules " + "in the model for the duration of calibration. See moe_context under " + "modeling/prepare.py for a list of supported MoE the updated module " + "definitions" + }, + ) shuffle_calibration_samples: Optional[bool] = field( default=True, metadata={ diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index 60428ad91..e12463099 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -194,7 +194,6 @@ def apply_recipe_modifiers( self.model, calibration_dataloader, self.dataset_args, - calibrate_moe_context=True, ) session.finalize() diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index b6e0e838b..078245255 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -6,7 +6,7 @@ from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE from llmcompressor.utils.helpers import patch_attr -__all__ = ["prepare_for_calibration"] +__all__ = ["replace_modules_for_calibration"] replacements = { "DeepseekV3MoE": replace_deepseekv3, @@ -14,7 +14,7 @@ } -def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel: +def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel: for name, module in model.named_modules(): cls_name = module.__class__.__name__ if cls_name in replacements: @@ -28,7 +28,9 @@ def update_qwen3_moe(model, stack): for _, module in model.named_modules(): cls_name = module.__class__.__name__ if cls_name == "Qwen3MoeDecoderLayer": - stack.enter_context(patch_attr(module, "mlp", replace_Qwen3MoE())) + stack.enter_context( + patch_attr(module, "mlp", replace_Qwen3MoE(model.config, module.mlp)) + ) moe_context = { @@ -37,4 +39,5 @@ def update_qwen3_moe(model, stack): def moe_calibration_context(model: PreTrainedModel, stack): cls_name = model.__class__.__name__ - moe_context.get(cls_name)(model, stack) + if cls_name in moe_context: + moe_context.get(cls_name)(model, stack) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index 80332a22e..23c22ec6f 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -1,22 +1,16 @@ import torch -from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeMLP class Qwen3MoeSparseMoeBlock(torch.nn.Module): - def __init__(self, config): + def __init__(self, config, gate, experts): super().__init__() self.num_experts = config.num_experts self.top_k = config.num_experts self.norm_topk_prob = config.norm_topk_prob # gating - self.gate = torch.nn.Linear(config.hidden_size, config.num_experts, bias=False) - self.experts = torch.nn.ModuleList( - [ - Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) - for _ in range(self.num_experts) - ] - ) + self.gate = gate + self.experts = experts def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ @@ -25,7 +19,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) - routing_weights = torch.functional.softmax( + routing_weights = torch.nn.functional.softmax( router_logits, dim=1, dtype=torch.float ) routing_weights, selected_experts = torch.topk( @@ -35,7 +29,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) - final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, @@ -48,7 +41,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: selected_experts, num_classes=self.num_experts ).permute(2, 1, 0) - # Loop over all available experts in the model and perform the computation on each expert expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hitted: expert_layer = self.experts[expert_idx] @@ -73,5 +65,5 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits -def replace(module): - return Qwen3MoeSparseMoeBlock(module.config) +def replace(config, module): + return Qwen3MoeSparseMoeBlock(config, module.gate, module.experts) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 605358ae9..bd87de655 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -1,3 +1,4 @@ +import contextlib from typing import TYPE_CHECKING, Union import torch @@ -6,6 +7,7 @@ from torch.utils.data.dataloader import DataLoader from llmcompressor.core import LifecycleCallbacks +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch from llmcompressor.pipelines.registry import CalibrationPipeline from llmcompressor.pytorch.utils.helpers import tensors_to_device @@ -42,7 +44,12 @@ def __call__( LifecycleCallbacks.calibration_epoch_start() - with calibration_forward_context(model): + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + + if dataset_args.calibrate_moe_context: + moe_calibration_context(model, stack) + for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) batch = tensors_to_device(batch, model_device) diff --git a/src/llmcompressor/pipelines/independent/pipeline.py b/src/llmcompressor/pipelines/independent/pipeline.py index 4ee19d33b..c204e012e 100644 --- a/src/llmcompressor/pipelines/independent/pipeline.py +++ b/src/llmcompressor/pipelines/independent/pipeline.py @@ -21,7 +21,6 @@ def __call__( model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments", - calibrate_moe_context: bool = False, ): """ Data pipeline where each modifier is assigned its own calibration epoch and data @@ -43,6 +42,6 @@ def __call__( pipeline_name = pipeline.__class__.__name__ _logger.info(f"Inferred `{pipeline_name}` for `{mod_type}`") - pipeline(model, dataloader, dataset_args, calibrate_moe_context) + pipeline(model, dataloader, dataset_args) # restore modifiers on exit so model can be compressed based on recipe diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py 
b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index b8e74a279..51734ed41 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -1,3 +1,4 @@ +import contextlib from typing import TYPE_CHECKING import torch @@ -6,6 +7,7 @@ from torch.utils.data.dataloader import DataLoader from llmcompressor.core import LifecycleCallbacks, active_session +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.pipelines.layer_sequential.helpers import ( @@ -69,7 +71,13 @@ def __call__( LifecycleCallbacks.calibration_epoch_start() - with calibration_forward_context(model), DisableQuantization(model): + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(DisableQuantization(model)) + + if dataset_args.calibrate_moe_context: + moe_calibration_context(model, stack) + # prepare intermediates cache intermediates: IntermediatesCache = capture_first_layer_intermediates( model, layers[0], dataloader, model_device diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 6ecc1ee50..2cab5c25e 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -78,7 +78,7 @@ def __call__( stack.enter_context(calibration_forward_context(model)) stack.enter_context(DisableQuantization(model)) - if calibrate_moe_context: + if dataset_args.calibrate_moe_context: moe_calibration_context(model, stack) # prepare intermediates cache From 6af3ffc8d9c0269b948b595df81f599828a18a41 Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 2 Jul 2025 22:19:21 -0400 Subject: [PATCH 06/20] update entrypoint --- src/llmcompressor/entrypoints/oneshot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index e12463099..b049f6b3b 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -232,6 +232,7 @@ def oneshot( overwrite_cache: bool = False, preprocessing_num_workers: Optional[int] = None, min_tokens_per_module: Optional[float] = None, + calibrate_moe_context: bool = False, # Miscellaneous arguments output_dir: Optional[str] = None, log_dir: Optional[str] = "sparse_logs", From 1a3dd300d2739b92083e16e4c7d83bcfa43eed9a Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 2 Jul 2025 22:34:14 -0400 Subject: [PATCH 07/20] update --- src/llmcompressor/modeling/deepseek_v3.py | 2 +- src/llmcompressor/modeling/prepare.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py index 60436cdc9..9f83f2e7f 100644 --- a/src/llmcompressor/modeling/deepseek_v3.py +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -5,7 +5,7 @@ __all__ = ["DeepseekV3MoECalibrate"] -class DeepseekV3MoECalibrate(torch.nn.Module): +class DeepseekV3MoE(torch.nn.Module): """ Patched DeepseekV3MoE which sends all tokens to all experts for calibration """ diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 078245255..fc531ab71 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -8,6 +8,7 @@ __all__ = ["replace_modules_for_calibration"] 
+# ---------------------- module replacements; permanent ------------------------- replacements = { "DeepseekV3MoE": replace_deepseekv3, "Llama4TextMoe": replace_llama4, @@ -24,6 +25,9 @@ def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel: return model +# ------------------- module replacements; during calibration -------------------- + + def update_qwen3_moe(model, stack): for _, module in model.named_modules(): cls_name = module.__class__.__name__ @@ -33,8 +37,18 @@ def update_qwen3_moe(model, stack): ) +def update_deepseek3_moe(model, stack): + for _, module in model.named_modules(): + cls_name = module.__class__.__name__ + if cls_name == "DeepseekV3MoE": + stack.enter_context( # ToDo - verify + patch_attr(module, "mlp", replace_DeepseekV3MoE(module)) + ) + + moe_context = { "Qwen3MoeForCausalLM": update_qwen3_moe, + "DeepseekV3ForCausalLM": update_deepseek3_moe, } def moe_calibration_context(model: PreTrainedModel, stack): From 079c71fcc817e5840c332a5adc8d2ea1d2e35875 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 3 Jul 2025 11:27:55 -0400 Subject: [PATCH 08/20] clean-up --- src/llmcompressor/args/dataset_arguments.py | 4 ++-- src/llmcompressor/entrypoints/oneshot.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index 4fef6efb8..e19850c80 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -123,8 +123,8 @@ class DatasetArguments(CustomDatasetArguments): "help": "If during calibration, the MoE context should be enabled " "for the given model. This usually involves updating all MoE modules " "in the model for the duration of calibration. See moe_context under " - "modeling/prepare.py for a list of supported MoE the updated module " - "definitions" + "modeling/prepare.py for a list of supported MoEs and their updated " + "module definitions" }, ) shuffle_calibration_samples: Optional[bool] = field( diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index b049f6b3b..eeef4932f 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -189,7 +189,6 @@ def apply_recipe_modifiers( user_pipeline = self.dataset_args.pipeline modifiers = session.lifecycle.recipe.modifiers pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline) - # ToDo: wrap moe_calibrate_all_experts in some set of args pipeline( self.model, calibration_dataloader, From aee670fa004f848ff20ae957f9412a9766213c0c Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 3 Jul 2025 16:11:15 +0000 Subject: [PATCH 09/20] update prepare --- src/llmcompressor/modeling/prepare.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index fc531ab71..8f9868c5b 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -40,9 +40,12 @@ def update_qwen3_moe(model, stack): def update_deepseek3_moe(model, stack): for _, module in model.named_modules(): cls_name = module.__class__.__name__ - if cls_name == "DeepseekV3MoE": - stack.enter_context( # ToDo - verify - patch_attr(module, "mlp", replace_DeepseekV3MoE(module)) + if ( + cls_name == "DeepseekV3DecoderLayer" + and module.mlp.__class__.__name__ == "DeepseekV3MoE" + ): + stack.enter_context( + patch_attr(module, "mlp", 
replace_DeepseekV3MoE(module.mlp))
+            )
 

From 3b9e2c247af30ee8d53b4376a25a990fa972ea26 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Thu, 3 Jul 2025 16:25:55 +0000
Subject: [PATCH 10/20] add comment

---
 src/llmcompressor/modeling/prepare.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py
index 8f9868c5b..f90846be3 100644
--- a/src/llmcompressor/modeling/prepare.py
+++ b/src/llmcompressor/modeling/prepare.py
@@ -51,7 +51,7 @@ def update_deepseek3_moe(model, stack):
 
 moe_context = {
     "Qwen3MoeForCausalLM": update_qwen3_moe,
-    "DeepseekV3ForCausalLM": update_deepseek3_moe,
+    # "DeepseekV3ForCausalLM": update_deepseek3_moe, TODO: uncomment when tested
 }
 
 def moe_calibration_context(model: PreTrainedModel, stack):

From a7af9ca69e6d2855dc92ee345d3da0da7db14cc7 Mon Sep 17 00:00:00 2001
From: Dipika
Date: Fri, 4 Jul 2025 10:47:46 -0400
Subject: [PATCH 11/20] fix check

---
 src/llmcompressor/pipelines/basic/pipeline.py      | 2 +-
 src/llmcompressor/pipelines/sequential/pipeline.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py
index bd87de655..db4b15305 100644
--- a/src/llmcompressor/pipelines/basic/pipeline.py
+++ b/src/llmcompressor/pipelines/basic/pipeline.py
@@ -47,7 +47,7 @@ def __call__(
         with contextlib.ExitStack() as stack:
             stack.enter_context(calibration_forward_context(model))
 
-            if dataset_args.calibrate_moe_context:
+            if dataset_args is not None and dataset_args.calibrate_moe_context:
                 moe_calibration_context(model, stack)
 
             for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py
index 2cab5c25e..901283252 100644
--- a/src/llmcompressor/pipelines/sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/sequential/pipeline.py
@@ -31,7 +31,6 @@ def __call__(
         model: torch.nn.Module,
         dataloader: DataLoader,
         dataset_args: "DatasetArguments",
-        calibrate_moe_context: bool = False,
     ):
         """
         Run a sequential data pipeline according to the following steps:

From ea930897fd9f7bea02d1baf070c844ead3620bb8 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Wed, 9 Jul 2025 18:25:53 +0000
Subject: [PATCH 12/20] PR comments

---
 src/llmcompressor/modeling/prepare.py   |  6 ++++--
 src/llmcompressor/modeling/qwen3_moe.py | 16 ++++++++++++++++
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py
index f90846be3..67d44ee81 100644
--- a/src/llmcompressor/modeling/prepare.py
+++ b/src/llmcompressor/modeling/prepare.py
@@ -29,7 +29,7 @@ def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
 
 
 def update_qwen3_moe(model, stack):
-    for _, module in model.named_modules():
+    for module in model.modules():
         cls_name = module.__class__.__name__
         if cls_name == "Qwen3MoeDecoderLayer":
             stack.enter_context(
@@ -38,7 +38,7 @@ def update_qwen3_moe(model, stack):
 
 
 def update_deepseek3_moe(model, stack):
-    for _, module in model.named_modules():
+    for module in model.modules():
         cls_name = module.__class__.__name__
         if (
             cls_name == "DeepseekV3DecoderLayer"
@@ -55,6 +55,8 @@ def update_deepseek3_moe(model, stack):
 }
 
 def moe_calibration_context(model: PreTrainedModel, stack):
+    # Temporarily updates the MoE modules within the context
+    # Once the context exits, parameter updates persist
     cls_name = 
model.__class__.__name__ if cls_name in moe_context: moe_context.get(cls_name)(model, stack) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index 23c22ec6f..a573778d0 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch From 945007ce64a7d10cb9e774f42732da28e9b6cc2e Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 9 Jul 2025 19:35:19 +0000 Subject: [PATCH 13/20] fix typing --- src/llmcompressor/modeling/deepseek_v3.py | 11 +++++------ src/llmcompressor/modeling/prepare.py | 1 + 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py index 9f83f2e7f..18a40ab15 100644 --- a/src/llmcompressor/modeling/deepseek_v3.py +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -1,16 +1,15 @@ import torch from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config -from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE - -__all__ = ["DeepseekV3MoECalibrate"] - +from transformers.models.deepseek_v3.modeling_deepseek_v3 import ( + DeepseekV3MoE as OriginalDeepseekV3MoE +) class DeepseekV3MoE(torch.nn.Module): """ Patched DeepseekV3MoE which sends all tokens to all experts for calibration """ - def __init__(self, config: DeepseekV3Config, original: DeepseekV3MoE): + def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE): super().__init__() self.config = config self.experts = original.experts @@ -49,5 +48,5 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -def replace(config: DeepseekV3Config, module: DeepseekV3MoE): +def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE): return DeepseekV3MoECalibrate(config=config, original=module) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 67d44ee81..aa0bdab65 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -32,6 +32,7 @@ def update_qwen3_moe(model, stack): for module in model.modules(): cls_name = module.__class__.__name__ if cls_name == "Qwen3MoeDecoderLayer": + # Optionally update the model.config to pass in other arguments stack.enter_context( patch_attr(module, "mlp", replace_Qwen3MoE(model.config, module.mlp)) ) From 528cdc83350657fd7d31f62e9f5d9e07fac6bb54 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 14 Jul 2025 20:56:40 +0000 Subject: [PATCH 14/20] rebase; fix --- .../{qwen_30b_a2b.py => qwen_30b_a3b.py} | 3 +-- src/llmcompressor/modeling/deepseek_v3.py | 5 +++-- src/llmcompressor/modeling/llama4.py | 2 -- src/llmcompressor/modeling/prepare.py | 20 ++++++------------- src/llmcompressor/modeling/qwen3_moe.py | 19 +++++++++++------- 5 files changed, 22 
insertions(+), 27 deletions(-) rename examples/quantization_w4a4_fp4/{qwen_30b_a2b.py => qwen_30b_a3b.py} (95%) diff --git a/examples/quantization_w4a4_fp4/qwen_30b_a2b.py b/examples/quantization_w4a4_fp4/qwen_30b_a3b.py similarity index 95% rename from examples/quantization_w4a4_fp4/qwen_30b_a2b.py rename to examples/quantization_w4a4_fp4/qwen_30b_a3b.py index a0b71b131..ef7bc73bf 100644 --- a/examples/quantization_w4a4_fp4/qwen_30b_a2b.py +++ b/examples/quantization_w4a4_fp4/qwen_30b_a3b.py @@ -15,8 +15,7 @@ DATASET_ID = "HuggingFaceH4/ultrachat_200k" DATASET_SPLIT = "train_sft" -# Select number of samples. 512 samples is a good place to start. -# Increasing the number of samples can improve accuracy. +# Select number of samples NUM_CALIBRATION_SAMPLES = 20 MAX_SEQUENCE_LENGTH = 2048 diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py index 18a40ab15..287b343bd 100644 --- a/src/llmcompressor/modeling/deepseek_v3.py +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -1,10 +1,11 @@ import torch from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config from transformers.models.deepseek_v3.modeling_deepseek_v3 import ( - DeepseekV3MoE as OriginalDeepseekV3MoE + DeepseekV3MoE as OriginalDeepseekV3MoE, ) -class DeepseekV3MoE(torch.nn.Module): + +class DeepseekV3MoECalibrate(torch.nn.Module): """ Patched DeepseekV3MoE which sends all tokens to all experts for calibration """ diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py index 02e3dc8fc..1d98ca57b 100644 --- a/src/llmcompressor/modeling/llama4.py +++ b/src/llmcompressor/modeling/llama4.py @@ -11,8 +11,6 @@ from llmcompressor.utils.dev import skip_weights_initialize -__all__ = ["SequentialLlama4TextMoe"] - class SequentialLlama4TextMoe(torch.nn.Module): def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe): diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index aa0bdab65..cb61f5fad 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -34,27 +34,19 @@ def update_qwen3_moe(model, stack): if cls_name == "Qwen3MoeDecoderLayer": # Optionally update the model.config to pass in other arguments stack.enter_context( - patch_attr(module, "mlp", replace_Qwen3MoE(model.config, module.mlp)) - ) - - -def update_deepseek3_moe(model, stack): - for module in model.modules(): - cls_name = module.__class__.__name__ - if ( - cls_name == "DeepseekV3DecoderLayer" - and module.mlp.__class__.__name__ == "DeepseekV3MoE" - ): - stack.enter_context( - patch_attr(module, "mlp", replace_DeepseekV3MoE(module.mlp)) + patch_attr( + module, + "mlp", + replace_Qwen3MoE(config=model.config, module=module.mlp), + ) ) moe_context = { "Qwen3MoeForCausalLM": update_qwen3_moe, - # "DeepseekV3ForCausalLM": update_deepseek3_moe, TODO: uncomment when tested } + def moe_calibration_context(model: PreTrainedModel, stack): # Temporarily updates the MoE modules within the context # Once the context exists, parameter updates persist diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index a573778d0..15ab062b7 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -15,21 +15,26 @@ # limitations under the License. 
import torch +from transformers.models import Qwen3MoeConfig +from transformers.models.qwen3_moe.modeling_qwen3_moe import ( + Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock, +) class Qwen3MoeSparseMoeBlock(torch.nn.Module): - def __init__(self, config, gate, experts): + def __init__( + self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock + ): super().__init__() self.num_experts = config.num_experts - self.top_k = config.num_experts + self.top_k = config.top_k self.norm_topk_prob = config.norm_topk_prob # gating - self.gate = gate - self.experts = experts + self.gate = original.gate + self.experts = original.experts def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) @@ -81,5 +86,5 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits -def replace(config, module): - return Qwen3MoeSparseMoeBlock(config, module.gate, module.experts) +def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock): + return Qwen3MoeSparseMoeBlock(config=config, original=module) From d7039a19ef32c92cd1a27a177e6793bb3489e9af Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 14 Jul 2025 21:05:59 +0000 Subject: [PATCH 15/20] update --- examples/multimodal_vision/llama4_example.py | 4 ++-- examples/quantization_w4a4_fp4/llama4_example.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index a28d89ced..292fa8300 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -3,7 +3,7 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import prepare_for_calibration +from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. @@ -14,7 +14,7 @@ # This change allows compatibility with vllm. # To apply your own custom module for experimentation, consider updating # `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = prepare_for_calibration(model) +model = replace_modules_for_calibration(model) DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py index 7b56928e8..28b57dda9 100644 --- a/examples/quantization_w4a4_fp4/llama4_example.py +++ b/examples/quantization_w4a4_fp4/llama4_example.py @@ -3,7 +3,7 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import prepare_for_calibration +from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import QuantizationModifier # Select model and load it. @@ -14,7 +14,7 @@ # This change allows compatibility with vllm. 
# To apply your own custom module for experimentation, consider updating # `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = prepare_for_calibration(model) +model = replace_modules_for_calibration(model) DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 20 From 9c9618378b84fac8ed37f058a3c5fe2c939e2167 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 14 Jul 2025 23:52:32 +0000 Subject: [PATCH 16/20] update --- src/llmcompressor/modeling/qwen3_moe.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index 15ab062b7..85843b4ae 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -62,24 +62,23 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: selected_experts, num_classes=self.num_experts ).permute(2, 1, 0) - expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hitted: + for expert_idx in range(len(self.experts)): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + has_tokens = idx.numel() > 0 # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = ( - expert_layer(current_state) * routing_weights[top_x, idx, None] - ) - + expert_output = expert_layer(current_state) + current_hidden_states = expert_output * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. 
- final_hidden_states.index_add_( - 0, top_x, current_hidden_states.to(hidden_states.dtype) - ) + if has_tokens: + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) final_hidden_states = final_hidden_states.reshape( batch_size, sequence_length, hidden_dim ) From a5a42bd43abdb1c5d71f346cdb8585e75a553276 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 14 Jul 2025 20:19:19 -0400 Subject: [PATCH 17/20] Update qwen3_moe.py --- src/llmcompressor/modeling/qwen3_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index 85843b4ae..7f5ec7e22 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -27,7 +27,7 @@ def __init__( ): super().__init__() self.num_experts = config.num_experts - self.top_k = config.top_k + self.top_k = config.num_experts self.norm_topk_prob = config.norm_topk_prob # gating From 8fc840fdbac00a1e66143a0881c1af9c27e7261d Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 15 Jul 2025 00:42:37 +0000 Subject: [PATCH 18/20] fix --- src/llmcompressor/modeling/qwen3_moe.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index 7f5ec7e22..f9f436516 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -62,23 +62,26 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: selected_experts, num_classes=self.num_experts ).permute(2, 1, 0) - for expert_idx in range(len(self.experts)): + # Loop over all available experts in the model and perform the computation on each expert + expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hitted: expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - has_tokens = idx.numel() > 0 # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - expert_output = expert_layer(current_state) - current_hidden_states = expert_output * routing_weights[top_x, idx, None] + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) + # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. 
- if has_tokens: - final_hidden_states.index_add_( - 0, top_x, current_hidden_states.to(hidden_states.dtype) - ) + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + final_hidden_states = final_hidden_states.reshape( batch_size, sequence_length, hidden_dim ) From 732b2ea0d40da3d0cc83a41213d7a4665d237003 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 15 Jul 2025 00:52:53 +0000 Subject: [PATCH 19/20] quality --- src/llmcompressor/modeling/qwen3_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index f9f436516..188f3b270 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -62,7 +62,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: selected_experts, num_classes=self.num_experts ).permute(2, 1, 0) - # Loop over all available experts in the model and perform the computation on each expert expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hitted: expert_layer = self.experts[expert_idx] From 3e860d66d5e5cc303936c3f69ec96a1bd755623d Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 15 Jul 2025 10:05:01 -0400 Subject: [PATCH 20/20] Alternate moe calib (#1645) SUMMARY: "please provide a brief summary" TEST PLAN: "please outline how the changes were tested" --- examples/quantization_w4a4_fp4/qwen_30b_a3b.py | 2 +- src/llmcompressor/modeling/qwen3_moe.py | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/examples/quantization_w4a4_fp4/qwen_30b_a3b.py b/examples/quantization_w4a4_fp4/qwen_30b_a3b.py index ef7bc73bf..0ba85de30 100644 --- a/examples/quantization_w4a4_fp4/qwen_30b_a3b.py +++ b/examples/quantization_w4a4_fp4/qwen_30b_a3b.py @@ -16,7 +16,7 @@ DATASET_SPLIT = "train_sft" # Select number of samples -NUM_CALIBRATION_SAMPLES = 20 +NUM_CALIBRATION_SAMPLES = 200 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index 188f3b270..fcd5d9925 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -27,7 +27,7 @@ def __init__( ): super().__init__() self.num_experts = config.num_experts - self.top_k = config.num_experts + self.top_k = config.top_k self.norm_topk_prob = config.norm_topk_prob # gating @@ -62,19 +62,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: selected_experts, num_classes=self.num_experts ).permute(2, 1, 0) - expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hitted: + for expert_idx in range(len(self.experts)): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = ( - expert_layer(current_state) * routing_weights[top_x, idx, None] - ) - + expert_output = expert_layer(current_state) + current_hidden_states = expert_output * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(
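

Taken together, the series exposes one user-facing switch: `calibrate_moe_context=True`,
passed through `oneshot` and stored on `DatasetArguments` (PATCH 05). When set, each
calibration pipeline opens a `contextlib.ExitStack`, calls
`moe_calibration_context(model, stack)`, and runs calibration with every supported MoE
block temporarily replaced by the patched definition above. The following is a minimal
sketch of that flow, not a drop-in replacement for the pipelines; the Qwen3 checkpoint
and the placeholder forward passes are assumptions:

    import contextlib

    from transformers import AutoModelForCausalLM

    from llmcompressor.modeling.prepare import moe_calibration_context
    from llmcompressor.utils.helpers import calibration_forward_context

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B", torch_dtype="auto")

    with contextlib.ExitStack() as stack:
        # enter the same forward-only calibration context the pipelines use
        stack.enter_context(calibration_forward_context(model))

        # registers one patch_attr() context per decoder layer: each
        # Qwen3MoeDecoderLayer.mlp is swapped for the patched
        # Qwen3MoeSparseMoeBlock for the lifetime of the stack
        moe_calibration_context(model, stack)

        # ... calibration forward passes go here ...

    # on exit, patch_attr restores the original modules, while parameter
    # updates made during calibration persist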