Commit 40fa3c5

Deprecate iter_named_leaf_modules and iter_named_quantizable_modules (#381)
* remove iter helper functions
* use internal module
* add docstring
* fix import cycle
* keep as deprecated
* rename to Untargetable
* Revert "rename to Untargetable" (reverts commit 9b23a62)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent a37df54 commit 40fa3c5

File tree

8 files changed: +56 −58 lines

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 2 additions & 5 deletions

@@ -42,10 +42,7 @@
     load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
-from compressed_tensors.quantization.utils import (
-    is_module_quantized,
-    iter_named_leaf_modules,
-)
+from compressed_tensors.quantization.utils import is_module_quantized
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
@@ -747,7 +744,7 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
     return {
         fix_fsdp_module_name(name): module.quantization_scheme
-        for name, module in iter_named_leaf_modules(model)
+        for name, module in model.named_modules()
         if is_module_quantized(module)
     }

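For callers migrating off the deprecated helper, the replacement is a plain named_modules() sweep with an explicit predicate, as in the new map_module_to_scheme above. A minimal standalone sketch, assuming (as compressed-tensors does) that quantized modules carry a quantization_scheme attribute:

import torch


def map_module_to_scheme(model: torch.nn.Module) -> dict:
    # named_modules() walks every submodule; the predicate makes explicit
    # the filtering that iter_named_leaf_modules used to do implicitly.
    return {
        name: module.quantization_scheme
        for name, module in model.named_modules()
        if getattr(module, "quantization_scheme", None) is not None
    }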
src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 8 additions & 9 deletions

@@ -38,8 +38,6 @@
     KV_CACHE_TARGETS,
     infer_quantization_status,
     is_kv_cache_quant_scheme,
-    iter_named_leaf_modules,
-    iter_named_quantizable_modules,
 )
 from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
@@ -87,7 +85,7 @@ def load_pretrained_quantization_parameters(
     model_path = get_safetensors_folder(model_name_or_path)
     mapping = get_quantization_parameter_to_path_mapping(model_path)

-    for name, submodule in iter_named_leaf_modules(model):
+    for name, submodule in model.named_modules():
         if not is_module_quantized(submodule):
             continue
         if submodule.quantization_scheme.input_activations is not None:
@@ -152,11 +150,7 @@ def apply_quantization_config(
     # list of submodules to ignore
     ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in iter_named_quantizable_modules(
-        model,
-        include_children=True,
-        include_attn=True,
-    ):  # child modules and attention modules
+    for name, submodule in model.named_modules():
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
         if matches := find_name_or_class_matches(name, submodule, config.ignore):
@@ -287,7 +281,7 @@ def expand_target_names(
     """
     return {
         name
-        for name, module in iter_named_leaf_modules(model)
+        for name, module in model.named_modules()
         if is_target(name, module, targets, ignore)
     }
@@ -328,6 +322,11 @@ def find_name_or_class_matches(
     2. matches on regex patterns
     3. matches on module names
     """
+    from compressed_tensors import InternalModule
+
+    if isinstance(module, InternalModule):
+        return []
+
     targets = sorted(targets, key=lambda x: ("re:" in x, x))
     if isinstance(targets, Iterable):
         matches = _find_matches(name, targets) + _find_matches(

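The early return added to find_name_or_class_matches is what makes InternalModule subclasses invisible to targeting. A hedged demo, assuming the function is exported from compressed_tensors.quantization.lifecycle as the imports elsewhere in this diff suggest (the Bookkeeping class is an invented stand-in, not code from this commit):

import torch
from compressed_tensors import InternalModule
from compressed_tensors.quantization.lifecycle import find_name_or_class_matches


class Bookkeeping(InternalModule):
    # invented stand-in for an observer or transform attached to a layer
    pass


# A blanket regex target matches a real layer...
print(find_name_or_class_matches("proj", torch.nn.Linear(4, 4), ["re:.*"]))
# ...but always returns [] for an InternalModule, due to the new guard.
print(find_name_or_class_matches("proj.obs", Bookkeeping(), ["re:.*"]))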
src/compressed_tensors/quantization/quant_config.py

Lines changed: 1 addition & 5 deletions

@@ -22,9 +22,7 @@
     preset_name_to_scheme,
 )
 from compressed_tensors.quantization.utils import (
-    calculate_compression_ratio,
     is_module_quantized,
-    iter_named_quantizable_modules,
     module_type,
     parse_out_kv_cache_args,
 )
@@ -177,9 +175,7 @@ def from_pretrained(
         quantization_status = None
         ignore = {}
         quantization_type_names = set()
-        for name, submodule in iter_named_quantizable_modules(
-            model, include_children=True, include_attn=True
-        ):
+        for name, submodule in model.named_modules():
             layer_type = module_type(submodule)
             if not is_module_quantized(submodule):
                 if layer_type not in ignore:

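from_pretrained now reconstructs the ignore list from the same named_modules() sweep, bucketing each submodule by its class name. A minimal sketch of that bucketing (module_type, shown in helpers.py below, just returns type(module).__name__; bucket_by_type is an illustrative helper, not part of this diff):

from collections import defaultdict

import torch


def bucket_by_type(model: torch.nn.Module) -> dict:
    # Mirrors the from_pretrained sweep: group submodule names by class name.
    buckets = defaultdict(list)
    for name, submodule in model.named_modules():
        buckets[type(submodule).__name__].append(name)
    return dict(buckets)


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
print(bucket_by_type(model))  # {'Sequential': [''], 'Linear': ['0'], 'ReLU': ['1']}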
src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 12 additions & 36 deletions

@@ -26,6 +26,7 @@
     QuantizationType,
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import deprecated
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 from tqdm import tqdm
@@ -36,7 +37,6 @@
     "is_module_quantized",
     "is_model_quantized",
     "module_type",
-    "calculate_compression_ratio",
     "get_torch_bit_depth",
     "can_quantize",
     "parse_out_kv_cache_args",
@@ -276,12 +276,7 @@ def is_model_quantized(model: Module) -> bool:
     :param model: pytorch model
     :return: True if model is quantized, False otherwise
     """
-
-    for _, submodule in iter_named_leaf_modules(model):
-        if is_module_quantized(submodule):
-            return True
-
-    return False
+    return any(is_module_quantized(submodule) for submodule in model.modules())


 def module_type(module: Module) -> str:
@@ -294,6 +289,11 @@ def module_type(module: Module) -> str:
     return type(module).__name__


+@deprecated(
+    message="This function will be removed in a future release. "
+    "Please use `model.named_modules()` and filter by "
+    "compressed_tensors.InternalModule if necessary"
+)
 def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
     """
     Yields modules that do not have any submodules except observers. The observers
@@ -320,6 +320,11 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
             yield name, submodule


+@deprecated(
+    message="This function will be removed in a future release. "
+    "Please use `model.named_modules()` and filter by "
+    "compressed_tensors.InternalModule if necessary"
+)
 def iter_named_quantizable_modules(
     model: Module,
     include_children: bool = True,
@@ -330,7 +335,6 @@ def iter_named_quantizable_modules(
     Yield name and submodule of
     - leaf modules, set by include_children
     - attention modules, set by include_attn
-
     :param model: model to get leaf modules of
     :param include_children: flag to get the leaf modules
     :param include_attn: flag to get the attention modules
@@ -397,34 +401,6 @@ def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool:
     return bit_depth > quant_args.num_bits


-def calculate_compression_ratio(model: Module) -> float:
-    """
-    Calculates the quantization compression ratio of a pytorch model, based on the
-    number of bits needed to represent the total weights in compressed form. Does not
-    take into account activation quantizations.
-
-    :param model: pytorch module to calculate compression ratio for
-    :return: compression ratio of the whole model
-    """
-    total_compressed = 0.0
-    total_uncompressed = 0.0
-    for name, submodule in tqdm(
-        iter_named_leaf_modules(model),
-        desc="Calculating quantization compression ratio",
-    ):
-        for parameter in model.parameters():
-            uncompressed_bits = get_torch_bit_depth(parameter)
-            compressed_bits = uncompressed_bits
-            if is_module_quantized(submodule) and submodule.quantization_scheme.weights:
-                compressed_bits = submodule.quantization_scheme.weights.num_bits
-
-            num_weights = parameter.numel()
-            total_compressed += compressed_bits * num_weights
-            total_uncompressed += uncompressed_bits * num_weights
-
-    return total_uncompressed / total_compressed
-
-
 def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
     """
     Check whether the QuantizationScheme targets the kv cache.

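The deprecated decorator imported at the top of this file is not shown in the diff; a minimal illustration of what such a decorator typically looks like (an assumption for readability, not the actual compressed_tensors.utils.deprecated implementation):

import functools
import warnings


def deprecated(message: str = ""):
    # Illustrative stand-in for compressed_tensors.utils.deprecated.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Warn at the caller's stack frame so the deprecation points at
            # the call site rather than at this wrapper.
            warnings.warn(
                f"{func.__name__} is deprecated. {message}",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator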
src/compressed_tensors/transform/factory/base.py

Lines changed: 2 additions & 1 deletion

@@ -17,6 +17,7 @@

 import torch
 import torch.nn.utils.parametrize as P
+from compressed_tensors import InternalModule
 from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
@@ -144,7 +145,7 @@ def output_hook(_, _input, output):
     # to support saving in the frozen state


-class TransformBase(Module, ABC):
+class TransformBase(InternalModule, ABC):
     """
     Represents the application of a transform according to TransformArgs
     """

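Because TransformBase now inherits from InternalModule, every concrete transform carries the marker that the targeting guard checks. A short sketch of the effect (MyTransform is a hypothetical subclass, not code from this commit):

import torch
from compressed_tensors import InternalModule


class MyTransform(InternalModule):
    # hypothetical stand-in for a concrete TransformBase subclass
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


layer = torch.nn.Linear(4, 4)
layer.transform = MyTransform()  # registered as a child module of the layer

# named_modules() still yields the transform; the isinstance marker is what
# lets targeting code skip it.
for name, module in layer.named_modules():
    print(name or "<root>", isinstance(module, InternalModule))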
src/compressed_tensors/utils/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@
 # flake8: noqa

 from .helpers import *
+from .internal import *
 from .offload import *
 from .permutations_24 import *
 from .permute import *
src/compressed_tensors/utils/internal.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+__all__ = ["InternalModule"]
+
+
+class InternalModule(torch.nn.Module):
+    """
+    Abstract base class for modules which are not a part of the model definition.
+    `torch.nn.Module`s which inherit from this class will not be targeted by configs.
+
+    This is typically used to skip applying configs to `Observers` and `Transforms`.
+    """
+
+    pass

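Downstream code can reuse the same marker for its own bookkeeping modules, and apply the filter recommended by the deprecation message directly. A minimal usage sketch (RunningMax and named_user_modules are invented examples, not part of this diff):

import torch
from compressed_tensors import InternalModule


class RunningMax(InternalModule):
    # invented example: records activation maxima without being targetable
    def __init__(self):
        super().__init__()
        self.register_buffer("max_val", torch.tensor(0.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.max_val = torch.maximum(self.max_val, x.detach().abs().max())
        return x


def named_user_modules(model: torch.nn.Module):
    # the deprecation message's suggested replacement for the iter helpers
    for name, module in model.named_modules():
        if not isinstance(module, InternalModule):
            yield name, module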
tests/test_quantization/lifecycle/test_apply.py

Lines changed: 1 addition & 2 deletions

@@ -31,7 +31,6 @@
     expand_target_names,
     is_target,
 )
-from compressed_tensors.quantization.utils import iter_named_leaf_modules
 from tests.testing_utils import requires_accelerate
 from transformers import AutoModelForCausalLM

@@ -98,7 +97,7 @@ def test_target_prioritization(mock_frozen):
     apply_quantization_config(model, config)
     mock_frozen(model)

-    for name, module in iter_named_leaf_modules(model):
+    for name, module in model.named_modules():
         if name == "model.layers.0.mlp.down_proj":
             assert module.quantization_scheme.weights.num_bits == 2
         elif re.match(".*down_proj", name):
