We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2c41df7 commit 394283eCopy full SHA for 394283e
src/llmcompressor/modifiers/awq/mappings.py
@@ -116,6 +116,21 @@ class AWQMapping:
116
),
117
]
118
119
+# DeepseekV3
120
+_deepseek_mappings = [
121
+ AWQMapping(
122
+ "re:.*input_layernorm$",
123
+ ["re:.*q_a_proj$", "re:.*kv_a_proj_with_mqa$"],
124
+ ),
125
+ AWQMapping("re:.*q_a_layernorm$", ["re:.*q_b_proj$"]),
126
+ AWQMapping("re:.*kv_a_layernorm$", ["re:.*kv_b_proj$"]),
127
128
+ "re:.*post_attention_layernorm$",
129
+ ["re:.*gate_proj$", "re:.*up_proj$"],
130
131
+ AWQMapping("re:.*up_proj$", ["re:.*down_proj$"]),
132
+]
133
+
134
AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
135
"CohereForCausalLM": _cohere_mappings,
136
"Cohere2ForCausalLM": _cohere_mappings,
@@ -131,6 +146,7 @@ class AWQMapping:
146
"Qwen2MoeForCausalLM": _moe_default_mappings,
147
"Qwen3ForCausalLM": _default_mappings,
148
"Qwen3MoeForCausalLM": _moe_default_mappings,
149
+ "DeepseekV3ForCausalLM": _deepseek_mappings,
150
}
151
152
0 commit comments