
Commit 95e8bdc

AWQ Cohere, Mistral & Gemma mappings (#1570)
SUMMARY: Add AWQ mappings for `CohereForCausalLM` models, which don't have a post_attention_layernorm and instead run the MLP and self_attn computations in parallel, along with Gemma and Mistral3 mappings.

Resolves #1566 (in addition to changes landed since the 0.5.1 release)
Resolves #1587

TODOs:
- [x] Don't land until after the 0.6.0 release
- [x] Validate a full run on `CohereLabs/c4ai-command-r-plus` (ran to completion, failed on generate because the model couldn't fit on a single H100)

TEST PLAN: On this branch, AWQ ran on `CohereLabs/c4ai-command-r-plus`.

Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent e279a96 commit 95e8bdc
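
For context on how these mappings get exercised, here is a minimal, hedged sketch of an AWQ one-shot run on the Cohere checkpoint validated in this PR. It assumes the `oneshot`/`AWQModifier` recipe style used in the repo's AWQ examples; the dataset name, scheme, and calibration settings below are illustrative choices, not part of this commit.

```python
# Sketch only: assumes llm-compressor's example-style AWQ recipe and oneshot API.
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

MODEL_ID = "CohereLabs/c4ai-command-r-plus"  # checkpoint validated in this PR (very large)

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# With this PR, AWQModifier resolves the Cohere-specific mappings automatically
# from AWQ_MAPPING_REGISTRY, so no explicit mappings argument should be needed.
recipe = AWQModifier(ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"])

oneshot(
    model=model,
    dataset="open_platypus",        # assumed calibration dataset name
    recipe=recipe,
    max_seq_length=512,             # illustrative calibration settings
    num_calibration_samples=256,
)

model.save_pretrained("c4ai-command-r-plus-awq-w4a16-asym")
tokenizer.save_pretrained("c4ai-command-r-plus-awq-w4a16-asym")
```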

File tree

3 files changed (+95, -11 lines)


src/llmcompressor/modifiers/awq/base.py

Lines changed: 5 additions & 4 deletions
@@ -304,13 +304,13 @@ def _set_resolved_mappings(self, model: Module) -> None:
         """
         resolved_mappings: list[ResolvedMapping] = []
         for mapping_idx, mapping in enumerate(self.mappings):
-            smooth_layers = get_layers(mapping.smooth_layer, model)
+            smooth_layers = get_layers(
+                mapping.smooth_layer, model, exclude_internal_modules=True
+            )
             smooth_names = [
                 smooth_name
                 for smooth_name in smooth_layers
-                if not find_name_or_class_matches(
-                    smooth_name, model, self.ignore + ["re:.*_observer$"]
-                )
+                if not find_name_or_class_matches(smooth_name, model, self.ignore)
             ]

             num_skipped_mappings = 0
@@ -331,6 +331,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
             for balance_suffix, balance_layer in get_layers(
                 balance_regex,
                 smooth_parent,
+                exclude_internal_modules=True,
             ).items():
                 balance_name = f"{smooth_parent_name}.{balance_suffix}"

src/llmcompressor/modifiers/awq/mappings.py

Lines changed: 48 additions & 0 deletions
@@ -74,8 +74,56 @@ class AWQMapping:
     ),
 ]

+# Gemma includes a pre_feedforward_layernorm in between
+# post_attention_layernorm and the mlp down/gate proj layers
+# use that instead of post_attention_layernorm in 3rd mapping:
+_gemma_mappings = [
+    AWQMapping(
+        "re:.*input_layernorm$",
+        ["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
+    ),
+    AWQMapping("re:.*v_proj$", ["re:.*o_proj$"]),
+    AWQMapping(
+        "re:.*pre_feedforward_layernorm$",
+        ["re:.*gate_proj$", "re:.*up_proj$"],
+    ),
+    AWQMapping(
+        "re:.*up_proj$",
+        ["re:.*down_proj$"],
+    ),
+]
+
+
+# Cohere architecture is similar to default, with a very fundamental difference.
+# The MLP block is executed in parallel to the attention. So the tensor goes
+# through input_layernorm and then from there it goes directly to the attention
+# module and to the MLP module.
+_cohere_mappings = [
+    AWQMapping(
+        "re:.*input_layernorm$",
+        [
+            "re:.*self_attn.q_proj$",
+            "re:.*self_attn.k_proj$",
+            "re:.*self_attn.v_proj$",
+            "re:.*mlp.gate_proj$",
+            "re:.*mlp.up_proj$",
+        ],
+    ),
+    AWQMapping("re:.*v_proj$", ["re:.*o_proj$"]),
+    AWQMapping(
+        "re:.*up_proj$",
+        ["re:.*down_proj$"],
+    ),
+]
+
 AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
+    "CohereForCausalLM": _cohere_mappings,
+    "Cohere2ForCausalLM": _cohere_mappings,
+    "Gemma2ForCausalLM": _gemma_mappings,
+    "Gemma3ForCausalLM": _gemma_mappings,
+    "Gemma3ForConditionalGeneration": _gemma_mappings,
     "LlamaForCausalLM": _default_mappings,
+    "Mistral3ForConditionalGeneration": _default_mappings,
     "MistralForCausalLM": _default_mappings,
     "Phi3ForCausalLM": _phi_mappings,
     "Phi3VForCausalLM": _phi_mappings,

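If an architecture with the same parallel attention/MLP layout is not yet listed in `AWQ_MAPPING_REGISTRY`, the Cohere-style mappings above can presumably be supplied directly: `AWQModifier` iterates `self.mappings` in `base.py`, so the sketch below assumes a `mappings=` override is accepted. This is hypothetical usage, not part of this diff.

```python
# Assumption-labeled sketch: passing Cohere-style mappings explicitly for an
# architecture that is not in AWQ_MAPPING_REGISTRY.
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.awq.mappings import AWQMapping

parallel_block_mappings = [
    # input_layernorm feeds attention and MLP in parallel, so one smoothing
    # mapping balances the q/k/v projections and the MLP gate/up projections.
    AWQMapping(
        "re:.*input_layernorm$",
        [
            "re:.*self_attn.q_proj$",
            "re:.*self_attn.k_proj$",
            "re:.*self_attn.v_proj$",
            "re:.*mlp.gate_proj$",
            "re:.*mlp.up_proj$",
        ],
    ),
    AWQMapping("re:.*v_proj$", ["re:.*o_proj$"]),
    AWQMapping("re:.*up_proj$", ["re:.*down_proj$"]),
]

# Hypothetical override; registry lookup by architecture name is the default path.
modifier = AWQModifier(
    mappings=parallel_block_mappings,
    ignore=["lm_head"],
    scheme="W4A16_ASYM",
    targets=["Linear"],
)
```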
src/llmcompressor/utils/pytorch/module.py

Lines changed: 42 additions & 7 deletions
@@ -9,12 +9,13 @@

 import torch
 from compressed_tensors.quantization.utils import is_module_quantized
-from packaging import version
+from compressed_tensors.transform import TransformBase
 from torch.nn import Linear, Module, Parameter
 from torch.nn.modules.conv import _ConvNd
 from transformers import PreTrainedModel

 from llmcompressor.core import ModelParameterizedLayer
+from llmcompressor.observers import Observer
 from llmcompressor.utils.fsdp.context import (
     fix_fsdp_module_name,
     summon_full_params_context,
@@ -64,10 +65,6 @@
     "get_layer_by_name",
 ]

-
-_PARSED_TORCH_VERSION = version.parse(torch.__version__)
-
-
 ALL_TARGET = "__ALL__"
 ALL_PRUNABLE_TARGET = "__ALL_PRUNABLE__"
 ALL_QUANTIZABLE_TARGET = "__ALL_QUANTIZABLE__"
@@ -164,8 +161,46 @@ def match_layers_params(
     return resolved


-def get_layers(targets: Union[str, List[str]], module: Module) -> Dict[str, Module]:
-    return match_layers_params(targets, module)
+def is_internal_module(module: Module) -> bool:
+    """
+    llm-compressor adds additional modules to a model, like observers
+    and transforms, as part of its normal operation
+
+    :param name: name of module
+    :return: True if name indicates a module internally instantiated by
+        llm-compressor, otherwise False
+    """
+    return isinstance(module, (TransformBase, Observer))
+
+
+def get_layers(
+    targets: Union[str, List[str]],
+    module: Module,
+    exclude_internal_modules: bool = False,
+) -> Dict[str, Module]:
+    """
+    Get layers (also known as submodules) of module based on targets
+
+    :param targets: names or regexes to search for
+        Can be regex, e.g. "re:.*input_layernorm$" to find all layers
+        in module whose names end in string "input_layernorm"
+    :param module: Parent module in which to search for targets
+    :param exclude_internal_modules: If True, don't include internal
+        modules added by llm-compressor, e.g. Observers and Transforms.
+        Defaults to False to maintain backward compatibility
+
+    :return: dict of {layer name -> module} of all layers in module
+        that match targets
+    """
+    layer_dict = match_layers_params(targets, module)
+    if exclude_internal_modules:
+        layer_dict = {
+            name: layer
+            for name, layer in layer_dict.items()
+            if not is_internal_module(layer)
+        }
+
+    return layer_dict


 def get_layer(target: str, module: Module) -> Tuple[str, Module]:
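
A short usage sketch of the new flag, assuming `get_layers` is imported from the module above: the default keeps prior behavior, while `exclude_internal_modules=True` drops `Observer`/`TransformBase` instances from the resolved dict. The model ID is the one from this PR; any loaded `transformers` model works.

```python
# Usage sketch for the new exclude_internal_modules flag.
from transformers import AutoModelForCausalLM

from llmcompressor.utils.pytorch.module import get_layers

model = AutoModelForCausalLM.from_pretrained(
    "CohereLabs/c4ai-command-r-plus",  # checkpoint from this PR; very large
    torch_dtype="auto",
)

# Default (False): unchanged behavior, internal submodules may appear if they match.
all_matches = get_layers("re:.*input_layernorm$", model)

# True: entries whose module is an Observer or TransformBase are filtered out,
# which is how AWQModifier now resolves smooth and balance layers (see base.py above).
user_layers_only = get_layers(
    "re:.*input_layernorm$",
    model,
    exclude_internal_modules=True,
)
print(len(all_matches), len(user_layers_only))
```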
