Skip to content

Commit 945007c

Browse files
committed
fix typing
1 parent ea93089 commit 945007c

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/llmcompressor/modeling/deepseek_v3.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
import torch
22
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
3-
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
4-
5-
__all__ = ["DeepseekV3MoECalibrate"]
6-
3+
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
4+
DeepseekV3MoE as OriginalDeepseekV3MoE
5+
)
76

87
class DeepseekV3MoE(torch.nn.Module):
98
"""
109
Patched DeepseekV3MoE which sends all tokens to all experts for calibration
1110
"""
1211

13-
def __init__(self, config: DeepseekV3Config, original: DeepseekV3MoE):
12+
def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE):
1413
super().__init__()
1514
self.config = config
1615
self.experts = original.experts
@@ -49,5 +48,5 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
4948
return hidden_states
5049

5150

52-
def replace(config: DeepseekV3Config, module: DeepseekV3MoE):
51+
def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE):
5352
return DeepseekV3MoECalibrate(config=config, original=module)

src/llmcompressor/modeling/prepare.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def update_qwen3_moe(model, stack):
3232
for module in model.modules():
3333
cls_name = module.__class__.__name__
3434
if cls_name == "Qwen3MoeDecoderLayer":
35+
# Optionally update the model.config to pass in other arguments
3536
stack.enter_context(
3637
patch_attr(module, "mlp", replace_Qwen3MoE(model.config, module.mlp))
3738
)

0 commit comments

Comments (0)