
Commit 528cdc8

rebase; fix
1 parent 945007c commit 528cdc8

5 files changed (+22 / -27 lines)


examples/quantization_w4a4_fp4/qwen_30b_a2b.py renamed to examples/quantization_w4a4_fp4/qwen_30b_a3b.py

Lines changed: 1 addition & 2 deletions
@@ -15,8 +15,7 @@
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
 DATASET_SPLIT = "train_sft"
 
-# Select number of samples. 512 samples is a good place to start.
-# Increasing the number of samples can improve accuracy.
+# Select number of samples
 NUM_CALIBRATION_SAMPLES = 20
 MAX_SEQUENCE_LENGTH = 2048
 
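For orientation, a hedged sketch of how constants like these are typically consumed when building a calibration set in llmcompressor examples; the actual preprocessing in the renamed script is not shown in this diff, and the `tokenizer` object below is an assumption.

```python
# Sketch only: typical use of the constants above to build a calibration set.
# `tokenizer` is assumed to be an already-loaded HF tokenizer; the real example
# script may preprocess differently.
from datasets import load_dataset

ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))

def preprocess(sample):
    # Render the chat sample to text, then truncate to MAX_SEQUENCE_LENGTH tokens.
    text = tokenizer.apply_chat_template(sample["messages"], tokenize=False)
    return tokenizer(text, max_length=MAX_SEQUENCE_LENGTH, truncation=True)

ds = ds.map(preprocess, remove_columns=ds.column_names)
```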

src/llmcompressor/modeling/deepseek_v3.py

Lines changed: 3 additions & 2 deletions
@@ -1,10 +1,11 @@
 import torch
 from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
 from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
-    DeepseekV3MoE as OriginalDeepseekV3MoE
+    DeepseekV3MoE as OriginalDeepseekV3MoE,
 )
 
-class DeepseekV3MoE(torch.nn.Module):
+
+class DeepseekV3MoECalibrate(torch.nn.Module):
     """
     Patched DeepseekV3MoE which sends all tokens to all experts for calibration
     """

src/llmcompressor/modeling/llama4.py

Lines changed: 0 additions & 2 deletions
@@ -11,8 +11,6 @@
 
 from llmcompressor.utils.dev import skip_weights_initialize
 
-__all__ = ["SequentialLlama4TextMoe"]
-
 
 class SequentialLlama4TextMoe(torch.nn.Module):
     def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):

src/llmcompressor/modeling/prepare.py

Lines changed: 6 additions & 14 deletions
@@ -34,27 +34,19 @@ def update_qwen3_moe(model, stack):
         if cls_name == "Qwen3MoeDecoderLayer":
             # Optionally update the model.config to pass in other arguments
             stack.enter_context(
-                patch_attr(module, "mlp", replace_Qwen3MoE(model.config, module.mlp))
-            )
-
-
-def update_deepseek3_moe(model, stack):
-    for module in model.modules():
-        cls_name = module.__class__.__name__
-        if (
-            cls_name == "DeepseekV3DecoderLayer"
-            and module.mlp.__class__.__name__ == "DeepseekV3MoE"
-        ):
-            stack.enter_context(
-                patch_attr(module, "mlp", replace_DeepseekV3MoE(module.mlp))
+                patch_attr(
+                    module,
+                    "mlp",
+                    replace_Qwen3MoE(config=model.config, module=module.mlp),
+                )
             )
 
 
 moe_context = {
     "Qwen3MoeForCausalLM": update_qwen3_moe,
-    # "DeepseekV3ForCausalLM": update_deepseek3_moe, TODO: uncomment when tested
 }
 
+
 def moe_calibration_context(model: PreTrainedModel, stack):
     # Temporarily updates the MoE modules within the context
     # Once the context exists, parameter updates persist
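For readers unfamiliar with the patching pattern above, this is a minimal sketch of how an attribute-patching context manager composes with `ExitStack`; it is illustrative only (the real `patch_attr` in llmcompressor may differ), and `model` is assumed to be an already-loaded Qwen3-MoE model.

```python
# Illustrative sketch of the patch-and-restore pattern, not the library's code.
from contextlib import ExitStack, contextmanager

@contextmanager
def patch_attr_sketch(obj, name, value):
    """Temporarily replace obj.<name> with value, restoring the original on exit."""
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)

# Usage mirroring moe_calibration_context: each MoE layer's `mlp` is swapped for its
# calibration variant while the stack is open and restored when the stack closes;
# weight updates persist because the wrapper reuses the original gate/experts modules.
with ExitStack() as stack:
    moe_calibration_context(model, stack)
    # ... run calibration forward passes here ...
```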

src/llmcompressor/modeling/qwen3_moe.py

Lines changed: 12 additions & 7 deletions
@@ -15,21 +15,26 @@
 # limitations under the License.
 
 import torch
+from transformers.models import Qwen3MoeConfig
+from transformers.models.qwen3_moe.modeling_qwen3_moe import (
+    Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
+)
 
 
 class Qwen3MoeSparseMoeBlock(torch.nn.Module):
-    def __init__(self, config, gate, experts):
+    def __init__(
+        self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock
+    ):
         super().__init__()
         self.num_experts = config.num_experts
-        self.top_k = config.num_experts
+        self.top_k = config.top_k
         self.norm_topk_prob = config.norm_topk_prob
 
         # gating
-        self.gate = gate
-        self.experts = experts
+        self.gate = original.gate
+        self.experts = original.experts
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
@@ -81,5 +86,5 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return final_hidden_states, router_logits
 
 
-def replace(config, module):
-    return Qwen3MoeSparseMoeBlock(config, module.gate, module.experts)
+def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock):
+    return Qwen3MoeSparseMoeBlock(config=config, original=module)
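The `self.top_k = config.top_k` change above controls how many experts each token is routed to in `forward`. Below is a simplified sketch of that standard top-k routing step, assuming the same attribute names as the class in this diff; the helper name is hypothetical and not part of the commit.

```python
import torch

def route_tokens_sketch(router_logits: torch.Tensor, top_k: int, norm_topk_prob: bool):
    """router_logits: (tokens, n_experts) -> per-token expert weights and indices."""
    routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    if norm_topk_prob:
        # Re-normalize so each token's selected expert weights sum to 1.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    return routing_weights, selected_experts
```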
