Commit 810d66d

Deprecate iter_named_leaf_modules and iter_named_quantizable_modules (#1628)
## Purpose ##
* Refactor module targeting to be cleaner and easier to maintain
* Support skipping `TransformBase` modules

## Prerequisites ##
* neuralmagic/compressed-tensors#381

## Changes ##
* Remove all uses of `iter_named_leaf_modules` and `iter_named_quantizable_modules`
* Make `Observer` inherit from `InternalModule`

## Testing ##
* https://github.com/neuralmagic/llm-compressor-testing/actions/runs/16123598340 ✅

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 70f93d3 commit 810d66d
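
In short, every deprecated iterator call becomes a plain `named_modules()` / `modules()` traversal, with llm-compressor's own bookkeeping modules filtered out by type. A minimal sketch of the pattern, assuming only that `InternalModule` is the marker base class from compressed-tensors (the helper name below is illustrative, not part of this commit):

import torch
from compressed_tensors import InternalModule

def iter_user_modules(model: torch.nn.Module):
    # Replacement for iter_named_leaf_modules: walk every submodule and
    # skip modules that llm-compressor instantiated internally (observers,
    # transforms), which now all inherit from InternalModule.
    for name, module in model.named_modules():
        if isinstance(module, InternalModule):
            continue
        yield name, module

Note that `named_modules()` yields every submodule, not just leaves, so call sites that previously relied on leaf-only iteration now filter with predicates such as `is_module_quantized` or `is_attention_module` instead.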

File tree

10 files changed: +57 -92 lines changed


src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 1 addition & 4 deletions
@@ -127,10 +127,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             f"Expected start to be None or -1, got {self.end}"
         )
 
-        if (
-            not hasattr(state, 'data') or
-            state.data.calib is None
-        ):
+        if not hasattr(state, "data") or state.data.calib is None:
             raise ValueError(
                 f"{self.__class__.__name__} requires a calibration dataset to be "
                 "provided"

src/llmcompressor/observers/base.py

Lines changed: 2 additions & 2 deletions
@@ -2,6 +2,7 @@
 from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
+from compressed_tensors import InternalModule
 from compressed_tensors.quantization.quant_args import (
     FP8_E4M3_DATA,
     QuantizationArgs,
@@ -12,12 +13,11 @@
 from compressed_tensors.utils import safe_permute
 from loguru import logger
 from torch import FloatTensor, IntTensor, Tensor
-from torch.nn import Module
 
 __all__ = ["Observer"]
 
 
-class Observer(Module, RegistryMixin):
+class Observer(InternalModule, RegistryMixin):
     """
     Base Observer class to be subclassed for specific implementation.
     Subclasses should override `calculate_qparams` to return a scale, zero_point
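
Since `Observer` now inherits from `InternalModule`, a single isinstance check excludes observers (and transforms) from layer targeting. A toy sketch under that assumption; `DummyObserver` is purely illustrative and not part of this commit:

import torch
from compressed_tensors import InternalModule

class DummyObserver(InternalModule):
    # Stand-in for a real Observer subclass; InternalModule is itself an
    # nn.Module, so it can be attached to a model like any other submodule.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
model[0].add_module("weight_observer", DummyObserver())

user_facing = [
    name
    for name, module in model.named_modules()
    if not isinstance(module, InternalModule)
]
# "0.weight_observer" is filtered out; "" (root) and "0" remain.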

src/llmcompressor/transformers/compression/helpers.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 
 import torch
 from accelerate.accelerator import get_state_dict_offloaded_model
-from compressed_tensors.quantization.utils import iter_named_leaf_modules, module_type
+from compressed_tensors.quantization.utils import module_type
 from compressed_tensors.utils import align_module_device
 from tqdm import tqdm
 
@@ -163,7 +163,7 @@ def _get_sparse_targets_ignore_dicts(
     exhaustive_targets = defaultdict(list)
     exhaustive_ignore = defaultdict(list)
 
-    for name, submodule in iter_named_leaf_modules(module):
+    for name, submodule in module.named_modules():
         submodule_type = module_type(submodule)
         is_target = is_sparse_compression_target(
             module=submodule,
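
Note the two flavors of the replacement: call sites that key results on the module name (like this one) move to `named_modules()`, while call sites that only inspect the module move to `modules()`. A small sketch of the named form; `group_submodules_by_type` is a hypothetical helper, not code from this commit:

from collections import defaultdict

import torch
from compressed_tensors.quantization.utils import module_type

def group_submodules_by_type(module: torch.nn.Module) -> dict:
    # Same traversal shape as _get_sparse_targets_ignore_dicts:
    # named_modules() preserves the dotted submodule name that the
    # target/ignore dictionaries are keyed on.
    groups = defaultdict(list)
    for name, submodule in module.named_modules():
        groups[module_type(submodule)].append(name)
    return groups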

src/llmcompressor/transformers/compression/quantization_format.py

Lines changed: 1 addition & 2 deletions
@@ -6,7 +6,6 @@
 from compressed_tensors.quantization.utils import (
     is_model_quantized,
     is_module_quantized,
-    iter_named_leaf_modules,
 )
 
 __all__ = ["infer_quantization_format"]
@@ -107,7 +106,7 @@ def _get_unique_quant_args(model):
     """
     quant_info_weight = []
     quant_info_inputs = []
-    for _, submodule in iter_named_leaf_modules(model):
+    for submodule in model.modules():
         if is_module_quantized(submodule):
             weight_scheme = submodule.quantization_scheme.weights
             input_scheme = submodule.quantization_scheme.input_activations
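
Because `modules()` visits every submodule rather than only leaves, `is_module_quantized` becomes the effective filter. A rough sketch of collecting distinct weight schemes in that style; the helper name is illustrative:

import torch
from compressed_tensors.quantization.utils import is_module_quantized

def unique_weight_schemes(model: torch.nn.Module) -> list:
    # Mirrors _get_unique_quant_args: scan all modules, keep the quantized
    # ones, and deduplicate their weight quantization args.
    schemes = []
    for submodule in model.modules():
        if not is_module_quantized(submodule):
            continue
        weights = submodule.quantization_scheme.weights
        if weights is not None and weights not in schemes:
            schemes.append(weights)
    return schemes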

src/llmcompressor/transformers/compression/sparsity_metadata_config.py

Lines changed: 28 additions & 28 deletions
@@ -6,7 +6,6 @@
 from compressed_tensors.quantization.utils import (
     is_model_quantized,
     is_module_quantized,
-    iter_named_leaf_modules,
 )
 from loguru import logger
 from torch import Tensor
@@ -208,33 +207,34 @@ def is_sparse24_bitmask_supported(
         QuantizationType.FLOAT.value,
     ]
 
-    for _, submodule in iter_named_leaf_modules(model):
-        if is_module_quantized(submodule):
-            weight_scheme = submodule.quantization_scheme.weights
-            input_scheme = submodule.quantization_scheme.input_activations
-
-            if weight_scheme and input_scheme:
-                # weight and activation quantization
-                # check schemes are supported
-                for scheme in [weight_scheme, input_scheme]:
-                    scheme_supported = (
-                        scheme.num_bits == 8
-                        and scheme.type in supported_scheme_types
-                    )
-                    if not scheme_supported:
-                        logger.info(
-                            "Quantization scheme not supported,"
-                            " turning off sparse 24 compression."
-                            f" Invalid Scheme: {scheme}"
-                        )
-                        return False
-
-            elif weight_scheme or input_scheme:
-                # weight only quantization
-                logger.info(
-                    "Weight only quantization detected, "
-                    "turning off sparse 24 compression."
+    for submodule in model.modules():
+        if not is_module_quantized(submodule):
+            continue
+
+        weight_scheme = submodule.quantization_scheme.weights
+        input_scheme = submodule.quantization_scheme.input_activations
+
+        if weight_scheme and input_scheme:
+            # weight and activation quantization
+            # check schemes are supported
+            for scheme in [weight_scheme, input_scheme]:
+                scheme_supported = (
+                    scheme.num_bits == 8 and scheme.type in supported_scheme_types
                 )
-                return False
+                if not scheme_supported:
+                    logger.info(
+                        "Quantization scheme not supported,"
+                        " turning off sparse 24 compression."
+                        f" Invalid Scheme: {scheme}"
+                    )
+                    return False
+
+        elif weight_scheme or input_scheme:
+            # weight only quantization
+            logger.info(
+                "Weight only quantization detected, "
+                "turning off sparse 24 compression."
+            )
+            return False
 
     return True
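
The early-continue rewrite keeps the same support rule for 2:4 sparse bitmask compression: every attached scheme must be 8-bit and of a supported type. Pulled out as a standalone predicate for clarity (the function name is illustrative, and the supported types are assumed to be the int and float members of QuantizationType, as the surrounding context suggests):

from compressed_tensors.quantization import QuantizationType

def scheme_is_sparse24_compatible(scheme) -> bool:
    # Same condition the refactored loop applies to weight and
    # activation schemes before allowing sparse 24 compression.
    supported_scheme_types = [
        QuantizationType.INT.value,
        QuantizationType.FLOAT.value,
    ]
    return scheme.num_bits == 8 and scheme.type in supported_scheme_types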

src/llmcompressor/utils/pytorch/module.py

Lines changed: 2 additions & 15 deletions
@@ -8,14 +8,13 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
+from compressed_tensors import InternalModule
 from compressed_tensors.quantization.utils import is_module_quantized
-from compressed_tensors.transform import TransformBase
 from torch.nn import Linear, Module, Parameter
 from torch.nn.modules.conv import _ConvNd
 from transformers import PreTrainedModel
 
 from llmcompressor.core import ModelParameterizedLayer
-from llmcompressor.observers import Observer
 from llmcompressor.utils.fsdp.context import (
     fix_fsdp_module_name,
     summon_full_params_context,
@@ -161,18 +160,6 @@ def match_layers_params(
     return resolved
 
 
-def is_internal_module(module: Module) -> bool:
-    """
-    llm-compressor adds additional modules to a model, like observers
-    and transforms, as part of its normal operation
-
-    :param name: name of module
-    :return: True if name indicates a module internally instantiated by
-        llm-compressor, otherwise False
-    """
-    return isinstance(module, (TransformBase, Observer))
-
-
 def get_layers(
     targets: Union[str, List[str]],
     module: Module,
@@ -197,7 +184,7 @@ def get_layers(
     layer_dict = {
         name: layer
         for name, layer in layer_dict.items()
-        if not is_internal_module(layer)
+        if not isinstance(layer, InternalModule)
     }
 
     return layer_dict
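
With observers and transforms both marked as `InternalModule`, the bespoke `is_internal_module` helper is no longer needed; `get_layers` can filter matched layers with one isinstance check. A compact sketch of that filter over a pre-matched layer dict; `drop_internal_layers` is a hypothetical name:

from typing import Dict

from compressed_tensors import InternalModule
from torch.nn import Module

def drop_internal_layers(layer_dict: Dict[str, Module]) -> Dict[str, Module]:
    # Same filter get_layers now applies: one isinstance check covers
    # observers, transforms, and any future internal module types.
    return {
        name: layer
        for name, layer in layer_dict.items()
        if not isinstance(layer, InternalModule)
    }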

tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py

Lines changed: 4 additions & 3 deletions
@@ -120,6 +120,7 @@ def forward(x, y):
     wrapped_fn = wrapper._wrapper_fn_defs[0]
     arg_names = {arg.arg for arg in wrapped_fn.args.args}
 
-    assert arg_names == {"x", "y"}, (
-        f"Expected arguments {{'x', 'y'}}, but got {arg_names}"
-    )
+    assert arg_names == {
+        "x",
+        "y",
+    }, f"Expected arguments {{'x', 'y'}}, but got {arg_names}"

tests/llmcompressor/transformers/compression/test_run_compressed.py

Lines changed: 1 addition & 4 deletions
@@ -5,7 +5,6 @@
 
 import torch
 from compressed_tensors.linear.compressed_linear import CompressedLinear
-from compressed_tensors.quantization.utils import iter_named_leaf_modules
 from parameterized import parameterized_class
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.utils.quantization_config import CompressedTensorsConfig
@@ -132,9 +131,7 @@ def setUpClass(cls):
 
     def test_compressed_linear_modules_exist(self):
        compressed_linear_counts = 0
-        for _, submodule in iter_named_leaf_modules(
-            self.compressed_model,
-        ):
+        for submodule in self.compressed_model.modules():
            if isinstance(submodule, CompressedLinear):
                compressed_linear_counts += 1
 

tests/llmcompressor/transformers/compression/test_sparsity_metadata_config.py

Lines changed: 0 additions & 9 deletions
@@ -32,11 +32,6 @@ def mock_is_model_quantized(model):
     return model.is_quantized
 
 
-def mock_iter_named_leaf_modules(model):
-    for name, module in model.named_modules():
-        yield name, module
-
-
 # Mock model class
 class MockModel(Module):
     def __init__(
@@ -99,10 +94,6 @@ def setup_mocks(self, request):
                 f"{SPARSITY_CONFIG_LOCATION}.is_model_quantized",
                 side_effect=mock_is_model_quantized,
             ),
-            patch(
-                f"{SPARSITY_CONFIG_LOCATION}.iter_named_leaf_modules",
-                side_effect=mock_iter_named_leaf_modules,
-            ),
         ]
         for patcher in patchers:
             patcher.start()

tests/llmcompressor/transformers/kv_cache/test_kv_cache.py

Lines changed: 16 additions & 23 deletions
@@ -3,8 +3,7 @@
 
 import pytest
 from accelerate import init_empty_weights
-from compressed_tensors.quantization.lifecycle import KVCacheScaleType
-from compressed_tensors.quantization.utils.helpers import iter_named_quantizable_modules
+from compressed_tensors.quantization import KVCacheScaleType, is_attention_module
 from datasets import load_dataset
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.utils.quantization_config import CompressedTensorsConfig
@@ -159,13 +158,11 @@ def test_kv_cache_model_state_dict_attr(oneshot_fixture, tmp_path):
     model = AutoModelForCausalLM.from_pretrained(str(output_dir))
 
     counts = 0
-    for name, submodule in iter_named_quantizable_modules(
-        model, include_children=False, include_attn=True
-    ):
-        counts += 1
-        assert "self_attn" in name
-        assert hasattr(submodule, KVCacheScaleType.VALUE.value)
-        assert hasattr(submodule, KVCacheScaleType.KEY.value)
+    for name, submodule in model.named_modules():
+        if is_attention_module(submodule):
+            counts += 1
+            assert hasattr(submodule, KVCacheScaleType.VALUE.value)
+            assert hasattr(submodule, KVCacheScaleType.KEY.value)
     assert counts > 0
 
 
@@ -200,13 +197,11 @@ def test_kv_cache_gptq_config_format(kv_cache_fixture, tmp_path):
     model = AutoModelForCausalLM.from_pretrained(output_dir)
 
     counts = 0
-    for name, submodule in iter_named_quantizable_modules(
-        model, include_children=False, include_attn=True
-    ):
-        counts += 1
-        assert "self_attn" in name
-        assert hasattr(submodule, KVCacheScaleType.VALUE.value)
-        assert hasattr(submodule, KVCacheScaleType.KEY.value)
+    for name, submodule in model.named_modules():
+        if is_attention_module(submodule):
+            counts += 1
+            assert hasattr(submodule, KVCacheScaleType.VALUE.value)
+            assert hasattr(submodule, KVCacheScaleType.KEY.value)
 
     assert counts > 0
 
@@ -246,12 +241,10 @@ def test_kv_cache_gptq_model_state_dict_attr(kv_cache_fixture, tmp_path):
     )
 
     counts = 0
-    for name, submodule in iter_named_quantizable_modules(
-        model, include_children=False, include_attn=True
-    ):
-        counts += 1
-        assert "self_attn" in name
-        assert hasattr(submodule, KVCacheScaleType.VALUE.value)
-        assert hasattr(submodule, KVCacheScaleType.KEY.value)
+    for name, submodule in model.named_modules():
+        if is_attention_module(submodule):
+            counts += 1
+            assert hasattr(submodule, KVCacheScaleType.VALUE.value)
+            assert hasattr(submodule, KVCacheScaleType.KEY.value)
 
     assert counts > 0
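
The KV-cache tests now locate attention blocks directly with `is_attention_module` instead of `iter_named_quantizable_modules(..., include_attn=True)`. A trimmed sketch of the shared assertion loop; the helper name is illustrative:

import torch
from compressed_tensors.quantization import KVCacheScaleType, is_attention_module

def assert_kv_cache_scales(model: torch.nn.Module) -> int:
    # Count attention modules and check that each carries the key/value
    # scale attributes attached by KV-cache quantization.
    counts = 0
    for _name, submodule in model.named_modules():
        if is_attention_module(submodule):
            counts += 1
            assert hasattr(submodule, KVCacheScaleType.VALUE.value)
            assert hasattr(submodule, KVCacheScaleType.KEY.value)
    assert counts > 0
    return counts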
