
Commit ceffa64

AWQModifier fast resolve mappings, better logging, MoE support (#1444)
SUMMARY:
In AWQ, resolving mappings can take a while because the search traverses the entire model tree, rather than just the parent module, to find the balance layers. This change scopes the search to the smooth layer's parent module.

For MoE models, the previous implementation found only a single layer for each regex string provided in the mappings. This change finds as many matches as it can, which is necessary for mappings like

```python
AWQMapping(
    "re:.*post_attention_layernorm$",
    ["re:.*mlp.experts.*.gate_proj$", "re:.*mlp.experts.*.up_proj$"],
)
```

which have multiple gate_proj and up_proj layers, one for each expert.

gsm8k results with the `Qwen/Qwen3-30B-A3B` MoE model after AWQ W4A16 symmetric quantization:

| Tasks | Version | Filter           | n-shot | Metric      |   | Value  |   | Stderr |
|-------|--------:|------------------|-------:|-------------|---|-------:|---|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑ | 0.3813 | ± | 0.0134 |
|       |         | strict-match     |      5 | exact_match | ↑ | 0.8810 | ± | 0.0089 |

TEST PLAN:
- [x] Working with `Qwen/Qwen3-30B-A3B` using the [same set of mappings used in AutoAWQ](https://github.com/casper-hansen/AutoAWQ/blob/main/awq/models/qwen3_moe.py#L24). An example is included in this PR in `examples/awq/qwen3_moe_example.py`. It ran successfully in ~2 hours on a single H100, using ~70GB of the 80GB available (additional memory is needed during saving).
- [x] Same wikitext PPL (14.0814) as on `main` for `meta-llama/Llama-3.2-3B-Instruct`.

---------

Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent e7c8fab commit ceffa64
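To make the resolution change concrete, the sketch below walks through one mapping by hand. It is illustrative only and not part of the commit: it assumes a Qwen3-MoE model loaded as in `examples/awq/qwen3_moe_example.py` below, and it uses the `get_layers`/`get_layer_by_name` helpers the same way the `base.py` diff does; the layer name and expert count are examples.

```python
from transformers import AutoModelForCausalLM

from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers

# Same load call as in examples/awq/qwen3_moe_example.py below
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-30B-A3B", device_map="auto", torch_dtype="auto"
)

# One smooth layer matched by "re:.*post_attention_layernorm$"
smooth_name = "model.layers.0.post_attention_layernorm"

# Scope the balance-layer search to this layer's parent ("model.layers.0")
# instead of traversing the whole model tree.
smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
smooth_parent = get_layer_by_name(smooth_parent_name, model)

# A single balance regex now resolves to every matching expert projection
# (mlp.experts.0.gate_proj, mlp.experts.1.gate_proj, ...), not just one.
balance_layers = get_layers("re:.*mlp.experts.*.gate_proj$", smooth_parent)
print(len(balance_layers))  # one gate_proj per expert, e.g. 128 for Qwen3-30B-A3B
```

The key points are that `get_layers` is called on `smooth_parent` rather than on `model`, and that every regex match is kept rather than only the first.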

6 files changed, +336 -148 lines changed


examples/awq/qwen3_moe_example.py

Lines changed: 82 additions & 0 deletions
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

# Select model and load it.
MODEL_ID = "Qwen/Qwen3-30B-A3B"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Select calibration dataset.
DATASET_ID = "mit-han-lab/pile-val-backup"
DATASET_SPLIT = "validation"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            [{"role": "user", "content": example["text"]}],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


# Configure the quantization algorithm to run.
# NOTE: vllm currently does not support asym MoE, using symmetric here
recipe = [
    AWQModifier(
        ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
        scheme="W4A16",
        targets=["Linear"],
    ),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-sym"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

src/llmcompressor/modifiers/awq/base.py

Lines changed: 116 additions & 70 deletions
```diff
@@ -2,7 +2,10 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.quantization import disable_quantization
+from compressed_tensors.quantization import (
+    disable_quantization,
+    find_name_or_class_matches,
+)
 from compressed_tensors.utils import (
     align_module_device,
     get_execution_device,
@@ -26,11 +29,7 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_matching_layer,
-    get_parent_by_name,
-)
+from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers
 
 __all__ = ["AWQModifier"]
 
@@ -307,77 +306,82 @@ def _set_resolved_mappings(self, model: Module) -> None:
         repeat for model.layer.1 and so on
         """
         resolved_mappings: list[ResolvedMapping] = []
-        num_skipped_oproj_mappings = 0
-        for mapping in self.mappings:
-            to_smooth_layers = get_layers(mapping.smooth_layer, model)
-            for layer_name, smooth_layer in to_smooth_layers.items():
-                # always exclude `.weight_observer`, only want `.weight`
-                if layer_name not in self.ignore and not layer_name.endswith(
-                    "_observer"
-                ):
-                    balance_layers, balance_names = [], []
-                    for balance_suffix in mapping.balance_layers:
-                        # find the submodule that matches the activation layer
-                        balance_name, balance_layer = get_matching_layer(
-                            balance_suffix, layer_name, model
-                        )
-                        if not balance_layer:
-                            continue
+        for mapping_idx, mapping in enumerate(self.mappings):
+            smooth_layers = get_layers(mapping.smooth_layer, model)
+            smooth_names = [
+                smooth_name
+                for smooth_name in smooth_layers
+                if not find_name_or_class_matches(
+                    smooth_name, model, self.ignore + ["re:.*_observer$"]
+                )
+            ]
+
+            num_skipped_mappings = 0
+            pbar = tqdm(smooth_names)
+            for smooth_name in pbar:
+                pbar.set_description(
+                    f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
+                    f" ({num_skipped_mappings} skipped)"
+                )
+                smooth_layer = smooth_layers[smooth_name]
+
+                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
+                smooth_parent = get_layer_by_name(smooth_parent_name, model)
+
+                balance_layers, balance_names = [], []
+                for balance_regex in mapping.balance_layers:
+                    # find the submodules that match the activation layer
+                    for balance_suffix, balance_layer in get_layers(
+                        balance_regex,
+                        smooth_parent,
+                    ).items():
+                        balance_name = f"{smooth_parent_name}.{balance_suffix}"
 
                         # exclude v_proj->o_proj mappings whose shapes are incompatible
                         # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
                         if (
                             isinstance(smooth_layer, torch.nn.Linear)
                             and isinstance(balance_layer, torch.nn.Linear)
-                            and ".o_proj" in balance_name
+                            and balance_name.endswith(".o_proj")
                             and (
                                 (
-                                    ".v_proj" in layer_name
+                                    smooth_name.endswith(".v_proj")
                                     and smooth_layer.out_features
                                     != balance_layer.in_features
                                 )
                                 or (
-                                    ".qkv_proj" in layer_name
+                                    smooth_name.endswith(".qkv_proj")
                                     and smooth_layer.out_features
                                     != 3 * balance_layer.in_features
                                 )
                             )
                         ):
-                            num_skipped_oproj_mappings += 1
+                            num_skipped_mappings += 1
                             continue
 
                         balance_layers.append(balance_layer)
                         balance_names.append(balance_name)
 
-                    if len(balance_layers) == 0:
-                        continue
-
-                    # each mapping can contain multiple layers to balance, but only
-                    # one layer to smooth
-                    if len(balance_layers) == 1:
-                        # for single balance layer, parent is the balance layer
-                        parent_name, parent = balance_name, balance_layer
-                    else:
-                        # for multiple balance layers,
-                        # parent of any balance layer is the parent
-                        parent_name, parent = get_parent_by_name(
-                            layer_name=balance_name, model=model
-                        )
-                    resolved_mappings.append(
-                        ResolvedMapping(
-                            layer_name,
-                            smooth_layer,
-                            balance_layers,
-                            balance_names=balance_names,
-                            parent=parent,
-                            parent_name=parent_name,
-                        )
+                if len(balance_layers) == 0:
+                    continue
+
+                elif len(balance_layers) == 1:
+                    # for single balance layer, parent is the balance layer
+                    parent_name, parent = balance_name, balance_layer
+                else:
+                    # for multiple balance layers, find lowest common parent
+                    parent_name, parent = get_lowest_common_parent(balance_names, model)
+
+                resolved_mappings.append(
+                    ResolvedMapping(
+                        smooth_name,
+                        smooth_layer,
+                        balance_layers,
+                        balance_names=balance_names,
+                        parent=parent,
+                        parent_name=parent_name,
                     )
-        if num_skipped_oproj_mappings > 0:
-            logger.info(
-                f"Excluded {num_skipped_oproj_mappings} from resolved "
-                "mappings due to shape mismatch"
-            )
+                )
         self._resolved_mappings = resolved_mappings
         return
 
@@ -401,11 +405,9 @@ def cache_smooth_activations_hook(
             args: Tuple[torch.Tensor, ...],
             _output: torch.Tensor,
         ):
-            # Assume that first argument is the input
-            inp = args[0].cpu().detach().squeeze()
-
             self._smooth_activation_means[smooth_name] = _accumulate_mean(
-                inp,
+                # Assume that first argument is the input
+                args[0].cpu().detach().squeeze(),
                 self._smooth_activation_means.get(smooth_name, None),
             )
 
@@ -444,12 +446,14 @@ def _apply_smoothing(self, model: Module) -> None:
 
         :param model: model to apply smoothing to
         """
-        for mapping in tqdm(self._resolved_mappings, desc="Smoothing"):
-            # NOTE: When using SequentialPipeline, not all the mappings
-            # will have cached activations in the segment being udpated
-            if mapping.smooth_name not in self._smooth_activation_means:
-                continue
-
+        # NOTE: When using SequentialPipeline, not all the mappings
+        # will have cached activations in the segment being udpated
+        mappings_to_smooth = [
+            mapping
+            for mapping in self._resolved_mappings
+            if mapping.smooth_name in self._smooth_activation_means
+        ]
+        for mapping in tqdm(mappings_to_smooth, desc="Smoothing"):
            smooth_layer = mapping.smooth_layer
            balance_layers = mapping.balance_layers
            parent_module = mapping.parent
@@ -473,10 +477,15 @@ def _apply_smoothing(self, model: Module) -> None:
                 # [STEP 3]: Compute output of module
                 # could cache from hook, rather than recomputing here
                 fp16_output = self._run_samples(parent_module)
-                fp16_output = fp16_output.clip(
-                    torch.finfo(fp16_output.dtype).min,
-                    torch.finfo(fp16_output.dtype).max,
-                )
+                if fp16_output.numel() == 0:
+                    logger.info(
+                        f"Skipping smooth_layer {mapping.smooth_name}, no activations "
+                        "found to scale. This can occasionally occur in MoE models "
+                        "when certain experts are not activated by calibration samples."
+                    )
+                    del self._smooth_activation_means[mapping.smooth_name]
+                    continue
+
                 x_mean = self._smooth_activation_means[mapping.smooth_name][0]
 
                 # [STEP 4]: Compute loss
@@ -536,10 +545,15 @@ def smooth(module):
 
     def _run_samples(self, module: Module) -> torch.Tensor:
         with align_module_device(module):
+            outputs = [
+                module(**batch_kwargs)
+                for batch_kwargs in self._parent_args_cache[module]
+            ]
             return torch.cat(
                 [
-                    module(**batch_kwargs)[0]
-                    for batch_kwargs in self._parent_args_cache[module]
+                    # If Tuple, assume that first argument is the input
+                    output[0] if isinstance(output, Tuple) else output
+                    for output in outputs
                 ],
                 dim=0,
             )
@@ -736,3 +750,35 @@ def _accumulate_mean(
     new_count = prev_count + num_added
 
     return (prev_sum + sum_added) / new_count, new_count
+
+
+def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
+    """
+    Given a list of names, returns the lowest-scope common parent.
+
+    NOTE: function excludes parents of type ModuleList, which don't play
+    nicely with hooks because their forward method is never directly
+    called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
+    are selected based on router output and their forward method is called.
+    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
+
+    Returns name of parent and pointer to parent module
+
+    Implementation is a small alteration of os.path.commonprefix
+    https://docs.python.org/3/library/os.path.html#os.path.commonprefix
+    """
+    s1 = min(names)
+    s2 = max(names)
+    parent_name = ""
+    for i, c in enumerate(s1):
+        if c != s2[i]:
+            parent_name = s1[:i].rstrip(".")
+            break
+
+    while True:
+        if parent_name == "":
+            return "", module
+        parent = get_layer_by_name(parent_name, module)
+        if not isinstance(parent, torch.nn.ModuleList):
+            return parent_name, parent
+        parent_name = ".".join(parent_name.split(".")[:-1])
```
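As a quick illustration of the helper added above, the toy sketch below shows the ModuleList-skipping behavior described in its docstring. It is not part of the commit and assumes `get_lowest_common_parent` is importable from `llmcompressor.modifiers.awq.base`, the module changed in this diff; the toy module structure and names are made up.

```python
from torch import nn

from llmcompressor.modifiers.awq.base import get_lowest_common_parent


class ToyMoEBlock(nn.Module):
    """Minimal stand-in for an MoE block: experts live in a ModuleList."""

    def __init__(self):
        super().__init__()
        self.experts = nn.ModuleList(
            nn.ModuleDict(
                {"gate_proj": nn.Linear(8, 16), "up_proj": nn.Linear(8, 16)}
            )
            for _ in range(2)
        )


model = nn.ModuleDict({"mlp": ToyMoEBlock()})

balance_names = [
    "mlp.experts.0.gate_proj",
    "mlp.experts.1.up_proj",
]
# The common dotted prefix is "mlp.experts", but that module is an
# nn.ModuleList, so the helper walks up one level and returns the MoE block.
parent_name, parent = get_lowest_common_parent(balance_names, model)
assert parent_name == "mlp"
assert isinstance(parent, ToyMoEBlock)
```

Skipping the `experts` ModuleList matters because, as the docstring notes, a ModuleList's forward method is never called directly in MoE models, so hooks attached to it would never fire.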
