@@ -2,7 +2,10 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
-from compressed_tensors.quantization import disable_quantization
+from compressed_tensors.quantization import (
+    disable_quantization,
+    find_name_or_class_matches,
+)
from compressed_tensors.utils import (
    align_module_device,
    get_execution_device,
@@ -26,11 +29,7 @@
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_matching_layer,
-    get_parent_by_name,
-)
+from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers

__all__ = ["AWQModifier"]

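The switch to `find_name_or_class_matches` replaces the ad-hoc `endswith("_observer")` check with compressed-tensors' target matching, where a `re:` prefix marks a regex target. A minimal standalone sketch of that matching convention (an illustration, not the library's implementation):

import re

def matches_target(name: str, targets: list[str]) -> bool:
    # targets are exact module names or "re:"-prefixed regex patterns
    for target in targets:
        if target.startswith("re:"):
            if re.match(target[3:], name):
                return True
        elif name == target:
            return True
    return False

# the observer regex used below excludes e.g. weight observers
assert matches_target("model.layers.0.self_attn.q_proj.weight_observer", ["re:.*_observer$"])
assert not matches_target("model.layers.0.self_attn.q_proj", ["re:.*_observer$"])
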
@@ -307,77 +306,82 @@ def _set_resolved_mappings(self, model: Module) -> None:
        repeat for model.layer.1 and so on
        """
        resolved_mappings: list[ResolvedMapping] = []
-        num_skipped_oproj_mappings = 0
-        for mapping in self.mappings:
-            to_smooth_layers = get_layers(mapping.smooth_layer, model)
-            for layer_name, smooth_layer in to_smooth_layers.items():
-                # always exclude `.weight_observer`, only want `.weight`
-                if layer_name not in self.ignore and not layer_name.endswith(
-                    "_observer"
-                ):
-                    balance_layers, balance_names = [], []
-                    for balance_suffix in mapping.balance_layers:
-                        # find the submodule that matches the activation layer
-                        balance_name, balance_layer = get_matching_layer(
-                            balance_suffix, layer_name, model
-                        )
-                        if not balance_layer:
-                            continue
+        for mapping_idx, mapping in enumerate(self.mappings):
+            smooth_layers = get_layers(mapping.smooth_layer, model)
+            smooth_names = [
+                smooth_name
+                for smooth_name in smooth_layers
+                if not find_name_or_class_matches(
+                    smooth_name, model, self.ignore + ["re:.*_observer$"]
+                )
+            ]
+
+            num_skipped_mappings = 0
+            pbar = tqdm(smooth_names)
+            for smooth_name in pbar:
+                pbar.set_description(
+                    f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
+                    f" ({num_skipped_mappings} skipped)"
+                )
+                smooth_layer = smooth_layers[smooth_name]
+
+                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
+                smooth_parent = get_layer_by_name(smooth_parent_name, model)
+
+                balance_layers, balance_names = [], []
+                for balance_regex in mapping.balance_layers:
+                    # find the submodules that match the activation layer
+                    for balance_suffix, balance_layer in get_layers(
+                        balance_regex,
+                        smooth_parent,
+                    ).items():
+                        balance_name = f"{smooth_parent_name}.{balance_suffix}"

                        # exclude v_proj->o_proj mappings whose shapes are incompatible
                        # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
                        if (
                            isinstance(smooth_layer, torch.nn.Linear)
                            and isinstance(balance_layer, torch.nn.Linear)
-                            and ".o_proj" in balance_name
+                            and balance_name.endswith(".o_proj")
                            and (
                                (
-                                    ".v_proj" in layer_name
+                                    smooth_name.endswith(".v_proj")
                                    and smooth_layer.out_features
                                    != balance_layer.in_features
                                )
                                or (
-                                    ".qkv_proj" in layer_name
+                                    smooth_name.endswith(".qkv_proj")
                                    and smooth_layer.out_features
                                    != 3 * balance_layer.in_features
                                )
                            )
                        ):
-                            num_skipped_oproj_mappings += 1
+                            num_skipped_mappings += 1
                            continue

                        balance_layers.append(balance_layer)
                        balance_names.append(balance_name)

-                    if len(balance_layers) == 0:
-                        continue
-
-                    # each mapping can contain multiple layers to balance, but only
-                    # one layer to smooth
-                    if len(balance_layers) == 1:
-                        # for single balance layer, parent is the balance layer
-                        parent_name, parent = balance_name, balance_layer
-                    else:
-                        # for multiple balance layers,
-                        # parent of any balance layer is the parent
-                        parent_name, parent = get_parent_by_name(
-                            layer_name=balance_name, model=model
-                        )
-                    resolved_mappings.append(
-                        ResolvedMapping(
-                            layer_name,
-                            smooth_layer,
-                            balance_layers,
-                            balance_names=balance_names,
-                            parent=parent,
-                            parent_name=parent_name,
-                        )
+                if len(balance_layers) == 0:
+                    continue
+
+                elif len(balance_layers) == 1:
+                    # for single balance layer, parent is the balance layer
+                    parent_name, parent = balance_name, balance_layer
+                else:
+                    # for multiple balance layers, find lowest common parent
+                    parent_name, parent = get_lowest_common_parent(balance_names, model)
+
+                resolved_mappings.append(
+                    ResolvedMapping(
+                        smooth_name,
+                        smooth_layer,
+                        balance_layers,
+                        balance_names=balance_names,
+                        parent=parent,
+                        parent_name=parent_name,
                    )
-        if num_skipped_oproj_mappings > 0:
-            logger.info(
-                f"Excluded {num_skipped_oproj_mappings} from resolved "
-                "mappings due to shape mismatch"
-            )
+                )
        self._resolved_mappings = resolved_mappings
        return
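The key behavioral change above: balance layers are now matched under the smooth layer's parent rather than globally, so a regex like `re:.*q_proj$` can only resolve to siblings of the smooth layer. A toy demonstration of that scoping (hypothetical module names, using torch's `get_submodule` in place of the project's helpers):

import re
from torch import nn

class ToyAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)

class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(8)
        self.self_attn = ToyAttn()

model = nn.ModuleDict({"layers": nn.ModuleList([ToyLayer(), ToyLayer()])})

smooth_name = "layers.0.input_layernorm"
smooth_parent_name = ".".join(smooth_name.split(".")[:-1])  # "layers.0"
smooth_parent = model.get_submodule(smooth_parent_name)

# search for balance layers only within the smooth layer's parent
matches = [
    f"{smooth_parent_name}.{name}"
    for name, _ in smooth_parent.named_modules()
    if re.match(r".*q_proj$", name)
]
print(matches)  # ['layers.0.self_attn.q_proj'] -- layer 1's q_proj is out of scope
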
@@ -401,11 +405,9 @@ def cache_smooth_activations_hook(
            args: Tuple[torch.Tensor, ...],
            _output: torch.Tensor,
        ):
-            # Assume that first argument is the input
-            inp = args[0].cpu().detach().squeeze()
-
            self._smooth_activation_means[smooth_name] = _accumulate_mean(
-                inp,
+                # Assume that first argument is the input
+                args[0].cpu().detach().squeeze(),
                self._smooth_activation_means.get(smooth_name, None),
            )
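The hook feeds each batch's squeezed input straight into `_accumulate_mean`, whose tail appears in the final hunk. A self-contained version of that running-mean bookkeeping, reconstructed from the `(mean, count)` tuple it stores (an approximation, not necessarily the file's exact code):

from typing import Optional, Tuple

import torch

def accumulate_mean(
    inp: torch.Tensor,
    prev_mean_and_count: Optional[Tuple[torch.Tensor, int]],
) -> Tuple[torch.Tensor, int]:
    # per-channel sum over the token dimension for this batch
    sum_added = inp.sum(dim=0)
    num_added = inp.size(0)
    if prev_mean_and_count is None:
        return sum_added / num_added, num_added
    prev_mean, prev_count = prev_mean_and_count
    # recover the previous sum, then fold in the new batch
    new_count = prev_count + num_added
    return (prev_mean * prev_count + sum_added) / new_count, new_count
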
@@ -444,12 +446,14 @@ def _apply_smoothing(self, model: Module) -> None:
        :param model: model to apply smoothing to
        """
-        for mapping in tqdm(self._resolved_mappings, desc="Smoothing"):
-            # NOTE: When using SequentialPipeline, not all the mappings
-            # will have cached activations in the segment being updated
-            if mapping.smooth_name not in self._smooth_activation_means:
-                continue
-
+        # NOTE: When using SequentialPipeline, not all the mappings
+        # will have cached activations in the segment being updated
+        mappings_to_smooth = [
+            mapping
+            for mapping in self._resolved_mappings
+            if mapping.smooth_name in self._smooth_activation_means
+        ]
+        for mapping in tqdm(mappings_to_smooth, desc="Smoothing"):
            smooth_layer = mapping.smooth_layer
            balance_layers = mapping.balance_layers
            parent_module = mapping.parent
@@ -473,10 +477,15 @@ def _apply_smoothing(self, model: Module) -> None:
            # [STEP 3]: Compute output of module
            # could cache from hook, rather than recomputing here
            fp16_output = self._run_samples(parent_module)
-            fp16_output = fp16_output.clip(
-                torch.finfo(fp16_output.dtype).min,
-                torch.finfo(fp16_output.dtype).max,
-            )
+            if fp16_output.numel() == 0:
+                logger.info(
+                    f"Skipping smooth_layer {mapping.smooth_name}, no activations "
+                    "found to scale. This can occasionally occur in MoE models "
+                    "when certain experts are not activated by calibration samples."
+                )
+                del self._smooth_activation_means[mapping.smooth_name]
+                continue
+
            x_mean = self._smooth_activation_means[mapping.smooth_name][0]

            # [STEP 4]: Compute loss
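The `[STEP 4]` loss that follows (elided in this diff) is AWQ's grid search over per-channel scale ratios, comparing `fp16_output` against the output produced with quantized weights. A condensed, self-contained sketch of that idea, not this file's implementation; `pseudo_quantize` is a hypothetical stand-in for the real fake-quantization path:

import torch

def pseudo_quantize(w: torch.Tensor, n_bits: int = 4) -> torch.Tensor:
    # symmetric round-to-nearest quantization, per tensor (illustrative only)
    qmax = 2 ** (n_bits - 1) - 1
    scale = w.abs().max() / qmax
    return (w / scale).round().clamp(-qmax - 1, qmax) * scale

x = torch.randn(64, 16)
w = torch.randn(16, 16)
x_mean = x.abs().mean(dim=0)
fp16_output = x @ w

best_ratio, best_loss = -1.0, float("inf")
for i in range(20):
    ratio = i / 20
    scales = x_mean.pow(ratio).clamp(min=1e-4)
    scales = scales / (scales.max() * scales.min()).sqrt()
    # scale weights up where activations are large, fake-quantize, scale back
    w_q = pseudo_quantize(w * scales[:, None]) / scales[:, None]
    loss = (fp16_output - x @ w_q).pow(2).mean().item()
    if loss < best_loss:
        best_ratio, best_loss = ratio, loss
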
@@ -536,10 +545,15 @@ def smooth(module):
    def _run_samples(self, module: Module) -> torch.Tensor:
        with align_module_device(module):
+            outputs = [
+                module(**batch_kwargs)
+                for batch_kwargs in self._parent_args_cache[module]
+            ]
            return torch.cat(
                [
-                    module(**batch_kwargs)[0]
-                    for batch_kwargs in self._parent_args_cache[module]
+                    # If Tuple, assume that first element is the output
+                    output[0] if isinstance(output, Tuple) else output
+                    for output in outputs
                ],
                dim=0,
            )
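Since parent modules may now be whole attention or MoE blocks that return tuples (hidden states plus extras) rather than bare tensors, outputs are normalized before concatenation. A quick check of the pattern with a toy module:

import torch
from torch import nn

class TupleOutput(nn.Module):
    def forward(self, hidden_states):
        # HF-style blocks often return (hidden_states, *extras)
        return hidden_states + 1, None

module = TupleOutput()
outputs = [module(hidden_states=torch.ones(2, 4)) for _ in range(3)]
stacked = torch.cat(
    [out[0] if isinstance(out, tuple) else out for out in outputs],
    dim=0,
)
print(stacked.shape)  # torch.Size([6, 4])
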
@@ -736,3 +750,35 @@ def _accumulate_mean(
    new_count = prev_count + num_added

    return (prev_sum + sum_added) / new_count, new_count
+
+
+def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
+    """
+    Given a list of names, returns the lowest-scope common parent.
+
+    NOTE: this function excludes parents of type ModuleList, which don't play
+    nicely with hooks because their forward method is never directly called for
+    MoE models. See Qwen3MoeSparseMoeBlock for an example: experts are selected
+    based on the router output, and each expert's forward method is called directly.
+    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
+
+    Returns the name of the parent and a pointer to the parent module.
+
+    Implementation is a small alteration of os.path.commonprefix
+    https://docs.python.org/3/library/os.path.html#os.path.commonprefix
+    """
+    s1 = min(names)
+    s2 = max(names)
+    parent_name = ""
+    for i, c in enumerate(s1):
+        if c != s2[i]:
+            parent_name = s1[:i].rstrip(".")
+            break
+
+    while True:
+        if parent_name == "":
+            return "", module
+        parent = get_layer_by_name(parent_name, module)
+        if not isinstance(parent, torch.nn.ModuleList):
+            return parent_name, parent
+        parent_name = ".".join(parent_name.split(".")[:-1])
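For example, on a toy MoE-style tree the raw common prefix lands on a ModuleList, and the while-loop climbs one level higher. A hypothetical sketch using torch's `get_submodule` in place of `get_layer_by_name`:

from torch import nn

# "mlp.experts" is a ModuleList, like an MoE expert bank
model = nn.ModuleDict(
    {"mlp": nn.ModuleDict({"experts": nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])})}
)

names = ["mlp.experts.0", "mlp.experts.1"]
s1, s2 = min(names), max(names)
prefix = ""
for i, c in enumerate(s1):
    if c != s2[i]:
        prefix = s1[:i].rstrip(".")
        break
print(prefix)  # "mlp.experts" -- a ModuleList, so get_lowest_common_parent
               # would climb up and return ("mlp", model.get_submodule("mlp"))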