diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py
index a28d89ced..292fa8300 100644
--- a/examples/multimodal_vision/llama4_example.py
+++ b/examples/multimodal_vision/llama4_example.py
@@ -3,7 +3,7 @@
 from transformers import Llama4ForConditionalGeneration, Llama4Processor
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import prepare_for_calibration
+from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import GPTQModifier
 
 # Select model and load it.
@@ -14,7 +14,7 @@
 # This change allows compatibility with vllm.
 # To apply your own custom module for experimentation, consider updating
 # `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
-model = prepare_for_calibration(model)
+model = replace_modules_for_calibration(model)
 
 DATASET_ID = "neuralmagic/calibration"
 NUM_CALIBRATION_SAMPLES = 512
diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py
index 7b56928e8..28b57dda9 100644
--- a/examples/quantization_w4a4_fp4/llama4_example.py
+++ b/examples/quantization_w4a4_fp4/llama4_example.py
@@ -3,7 +3,7 @@
 from transformers import Llama4ForConditionalGeneration, Llama4Processor
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import prepare_for_calibration
+from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import QuantizationModifier
 
 # Select model and load it.
@@ -14,7 +14,7 @@
 # This change allows compatibility with vllm.
 # To apply your own custom module for experimentation, consider updating
 # `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
-model = prepare_for_calibration(model)
+model = replace_modules_for_calibration(model)
 
 DATASET_ID = "neuralmagic/calibration"
 NUM_CALIBRATION_SAMPLES = 20
diff --git a/examples/quantization_w4a4_fp4/qwen_30b_a3b.py b/examples/quantization_w4a4_fp4/qwen_30b_a3b.py
new file mode 100644
index 000000000..0ba85de30
--- /dev/null
+++ b/examples/quantization_w4a4_fp4/qwen_30b_a3b.py
@@ -0,0 +1,86 @@
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.utils import dispatch_for_generation
+
+MODEL_ID = "Qwen/Qwen3-30B-A3B"
+
+# Load model.
+model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+
+# Select number of samples
+NUM_CALIBRATION_SAMPLES = 200
+MAX_SEQUENCE_LENGTH = 2048
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+# Configure the quantization algorithm and scheme.
+# In this case, we:
+#   * quantize the weights to fp4 with a group size of 16 via PTQ
+#   * calibrate a global_scale for activations, which will be used to
+#       quantize activations to fp4 on the fly
+recipe = QuantizationModifier(
+    targets="Linear", scheme="NVFP4", ignore=["lm_head", "re:.*mlp.gate$"]
+)
+
+# Apply quantization.
+# We set `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock`
+# modules during calibration.
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    calibrate_moe_context=True,
+)
+
+
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+
+# Save to disk in compressed-tensors format.
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py
index 6cb6937e8..0944e7b79 100644
--- a/examples/quantizing_moe/deepseek_r1_example.py
+++ b/examples/quantizing_moe/deepseek_r1_example.py
@@ -1,7 +1,7 @@
 from datasets import load_dataset
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
-from llmcompressor.modeling import prepare_for_calibration
+from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 
@@ -20,7 +20,7 @@
     model_id, torch_dtype="auto", config=config
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = prepare_for_calibration(model)
+model = replace_modules_for_calibration(model)
 
 # Select calibration dataset.
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
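For reference (not part of the patch itself): the examples above rely on two different mechanisms. A minimal sketch of both entry points, assuming the loaded MoE `model` plus the `ds`/`recipe` objects and `oneshot` import defined in the examples:

# 1) Permanent module replacement, applied once before calibration and kept
#    afterwards (DeepseekV3MoE / Llama4TextMoe, as in the Llama-4 and
#    DeepSeek-R1 examples):
from llmcompressor.modeling import replace_modules_for_calibration

model = replace_modules_for_calibration(model)

# 2) Temporary replacement that only exists for the duration of calibration
#    (Qwen3 MoE, as in the Qwen3-30B-A3B example):
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    calibrate_moe_context=True,
)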
diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index 3f1bdcf6a..e19850c80 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -117,6 +117,16 @@ class DatasetArguments(CustomDatasetArguments):
         default=512,
         metadata={"help": "Number of samples to use for one-shot calibration"},
     )
+    calibrate_moe_context: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to enable the MoE context for the given model "
+            "during calibration. This usually involves updating all MoE modules "
+            "in the model for the duration of calibration. See moe_context under "
+            "modeling/prepare.py for a list of supported MoEs and their updated "
+            "module definitions"
+        },
+    )
     shuffle_calibration_samples: Optional[bool] = field(
         default=True,
         metadata={
diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py
index 9219f21fb..eeef4932f 100644
--- a/src/llmcompressor/entrypoints/oneshot.py
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -189,7 +189,11 @@ def apply_recipe_modifiers(
         user_pipeline = self.dataset_args.pipeline
         modifiers = session.lifecycle.recipe.modifiers
         pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline)
-        pipeline(self.model, calibration_dataloader, self.dataset_args)
+        pipeline(
+            self.model,
+            calibration_dataloader,
+            self.dataset_args,
+        )
 
         session.finalize()
 
@@ -227,6 +231,7 @@
     overwrite_cache: bool = False,
     preprocessing_num_workers: Optional[int] = None,
     min_tokens_per_module: Optional[float] = None,
+    calibrate_moe_context: bool = False,
     # Miscellaneous arguments
     output_dir: Optional[str] = None,
     log_dir: Optional[str] = "sparse_logs",
diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py
index 60436cdc9..287b343bd 100644
--- a/src/llmcompressor/modeling/deepseek_v3.py
+++ b/src/llmcompressor/modeling/deepseek_v3.py
@@ -1,8 +1,8 @@
 import torch
 from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
-from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
-
-__all__ = ["DeepseekV3MoECalibrate"]
+from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
+    DeepseekV3MoE as OriginalDeepseekV3MoE,
+)
 
 
 class DeepseekV3MoECalibrate(torch.nn.Module):
@@ -10,7 +10,7 @@ class DeepseekV3MoECalibrate(torch.nn.Module):
     Patched DeepseekV3MoE which sends all tokens to all experts for calibration
     """
 
-    def __init__(self, config: DeepseekV3Config, original: DeepseekV3MoE):
+    def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE):
         super().__init__()
         self.config = config
         self.experts = original.experts
@@ -49,5 +49,5 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
-def replace(config: DeepseekV3Config, module: DeepseekV3MoE):
+def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE):
     return DeepseekV3MoECalibrate(config=config, original=module)
diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py
index 02e3dc8fc..1d98ca57b 100644
--- a/src/llmcompressor/modeling/llama4.py
+++ b/src/llmcompressor/modeling/llama4.py
@@ -11,8 +11,6 @@
 
 from llmcompressor.utils.dev import skip_weights_initialize
 
-__all__ = ["SequentialLlama4TextMoe"]
-
 
 class SequentialLlama4TextMoe(torch.nn.Module):
     def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):
diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py
index 0ef627db4..cb61f5fad 100644
--- a/src/llmcompressor/modeling/prepare.py
+++ b/src/llmcompressor/modeling/prepare.py
@@ -3,16 +3,19 @@
 
 from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
 from llmcompressor.modeling.llama4 import replace as replace_llama4
+from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
+from llmcompressor.utils.helpers import patch_attr
 
-__all__ = ["prepare_for_calibration"]
+__all__ = ["replace_modules_for_calibration"]
 
+# ---------------------- module replacements; permanent -------------------------
 replacements = {
     "DeepseekV3MoE": replace_deepseekv3,
     "Llama4TextMoe": replace_llama4,
 }
 
 
-def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
+def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
     for name, module in model.named_modules():
         cls_name = module.__class__.__name__
         if cls_name in replacements:
@@ -20,3 +23,33 @@ def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
             replace_module(model, name, new_module)
 
     return model
+
+
+# ------------------- module replacements; during calibration --------------------
+
+
+def update_qwen3_moe(model, stack):
+    for module in model.modules():
+        cls_name = module.__class__.__name__
+        if cls_name == "Qwen3MoeDecoderLayer":
+            # Optionally update the model.config to pass in other arguments
+            stack.enter_context(
+                patch_attr(
+                    module,
+                    "mlp",
+                    replace_Qwen3MoE(config=model.config, module=module.mlp),
+                )
+            )
+
+
+moe_context = {
+    "Qwen3MoeForCausalLM": update_qwen3_moe,
+}
+
+
+def moe_calibration_context(model: PreTrainedModel, stack):
+    # Temporarily updates the MoE modules within the context
+    # Once the context exits, parameter updates persist
+    cls_name = model.__class__.__name__
+    if cls_name in moe_context:
+        moe_context.get(cls_name)(model, stack)
diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py
new file mode 100644
index 000000000..fcd5d9925
--- /dev/null
+++ b/src/llmcompressor/modeling/qwen3_moe.py
@@ -0,0 +1,87 @@
+# coding=utf-8
+# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from transformers.models import Qwen3MoeConfig
+from transformers.models.qwen3_moe.modeling_qwen3_moe import (
+    Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
+)
+
+
+class Qwen3MoeSparseMoeBlock(torch.nn.Module):
+    def __init__(
+        self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock
+    ):
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.norm_topk_prob = config.norm_topk_prob
+
+        # gating
+        self.gate = original.gate
+        self.experts = original.experts
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = torch.nn.functional.softmax(
+            router_logits, dim=1, dtype=torch.float
+        )
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be solicited
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)
+
+        for expert_idx in range(len(self.experts)):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            expert_output = expert_layer(current_state)
+            current_hidden_states = expert_output * routing_weights[top_x, idx, None]
+            # However `index_add_` only supports torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+        return final_hidden_states, router_logits
+
+
+def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock):
+    return Qwen3MoeSparseMoeBlock(config=config, original=module)
diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py
index 605358ae9..db4b15305 100644
--- a/src/llmcompressor/pipelines/basic/pipeline.py
+++ b/src/llmcompressor/pipelines/basic/pipeline.py
@@ -1,3 +1,4 @@
+import contextlib
 from typing import TYPE_CHECKING, Union
 
 import torch
@@ -6,6 +7,7 @@
 from torch.utils.data.dataloader import DataLoader
 
 from llmcompressor.core import LifecycleCallbacks
+from llmcompressor.modeling.prepare import moe_calibration_context
 from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pytorch.utils.helpers import tensors_to_device
@@ -42,7 +44,12 @@
 
         LifecycleCallbacks.calibration_epoch_start()
 
-        with calibration_forward_context(model):
+        with contextlib.ExitStack() as stack:
+            stack.enter_context(calibration_forward_context(model))
+
+            if dataset_args is not None and dataset_args.calibrate_moe_context:
+                moe_calibration_context(model, stack)
+
             for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
                 batch = apply_pad_mask_to_batch(batch)
                 batch = tensors_to_device(batch, model_device)
diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
index b8e74a279..51734ed41 100644
--- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
@@ -1,3 +1,4 @@
+import contextlib
 from typing import TYPE_CHECKING
 
 import torch
@@ -6,6 +7,7 @@
 from torch.utils.data.dataloader import DataLoader
 
 from llmcompressor.core import LifecycleCallbacks, active_session
+from llmcompressor.modeling.prepare import moe_calibration_context
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.layer_sequential.helpers import (
@@ -69,7 +71,13 @@
 
         LifecycleCallbacks.calibration_epoch_start()
 
-        with calibration_forward_context(model), DisableQuantization(model):
+        with contextlib.ExitStack() as stack:
+            stack.enter_context(calibration_forward_context(model))
+            stack.enter_context(DisableQuantization(model))
+
+            if dataset_args.calibrate_moe_context:
+                moe_calibration_context(model, stack)
+
             # prepare intermediates cache
             intermediates: IntermediatesCache = capture_first_layer_intermediates(
                 model, layers[0], dataloader, model_device
diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py
index c9af8f0fc..901283252 100644
--- a/src/llmcompressor/pipelines/sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/sequential/pipeline.py
@@ -1,3 +1,4 @@
+import contextlib
 from typing import TYPE_CHECKING
 
 import torch
@@ -6,6 +7,7 @@
 from tqdm import tqdm
 
 from llmcompressor.core import LifecycleCallbacks, active_session
+from llmcompressor.modeling.prepare import moe_calibration_context
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.registry import CalibrationPipeline
@@ -26,7 +28,9 @@
 class SequentialPipeline(CalibrationPipeline):
     @staticmethod
     def __call__(
-        model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments"
+        model: torch.nn.Module,
+        dataloader: DataLoader,
+        dataset_args: "DatasetArguments",
     ):
         """
         Run a sequential data pipeline according to the following steps:
@@ -69,7 +73,13 @@
 
         LifecycleCallbacks.calibration_epoch_start()
 
-        with calibration_forward_context(model), DisableQuantization(model):
+        with contextlib.ExitStack() as stack:
+            stack.enter_context(calibration_forward_context(model))
+            stack.enter_context(DisableQuantization(model))
+
+            if dataset_args.calibrate_moe_context:
+                moe_calibration_context(model, stack)
+
             # prepare intermediates cache
             activations = IntermediatesCache.from_dataloader(dataloader, model_device)
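For reference (not part of the patch itself): the three pipeline changes above all follow the same pattern. A minimal sketch of that pattern, using only names introduced in this diff; the batch loop is a hypothetical stand-in for each pipeline's own calibration loop:

import contextlib

from llmcompressor.modeling.prepare import moe_calibration_context


def calibrate_with_moe_context(model, dataloader):
    # Mirrors what the updated pipelines do when dataset_args.calibrate_moe_context
    # is True: supported MoE submodules (e.g. Qwen3MoeDecoderLayer.mlp) are swapped
    # via patch_attr for the lifetime of the ExitStack.
    with contextlib.ExitStack() as stack:
        moe_calibration_context(model, stack)
        for batch in dataloader:
            model(**batch)  # hypothetical stand-in for the pipeline forward passes
    # When the stack unwinds, the original MoE modules are restored; parameter
    # updates made during calibration persist.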