From 4571b5cccda666d28c42330735a74265a0b738cb Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 7 Jul 2025 11:56:58 -0400
Subject: [PATCH 1/7] remove iter helper functions

Signed-off-by: Kyle Sayers
---
 .../model_compressors/model_compressor.py | 7 +-
 .../quantization/lifecycle/apply.py | 12 +-
 .../quantization/quant_config.py | 6 +-
 .../quantization/utils/helpers.py | 106 +-----------------
 .../test_quantization/lifecycle/test_apply.py | 3 +-
 5 files changed, 8 insertions(+), 126 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 700c1769..38c1d9f5 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -42,10 +42,7 @@
     load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
-from compressed_tensors.quantization.utils import (
-    is_module_quantized,
-    iter_named_leaf_modules,
-)
+from compressed_tensors.quantization.utils import is_module_quantized
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
@@ -747,7 +744,7 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
     return {
         fix_fsdp_module_name(name): module.quantization_scheme
-        for name, module in iter_named_leaf_modules(model)
+        for name, module in model.named_modules()
         if is_module_quantized(module)
     }
 
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index c8dbeced..64794870 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -38,8 +38,6 @@
     KV_CACHE_TARGETS,
     infer_quantization_status,
     is_kv_cache_quant_scheme,
-    iter_named_leaf_modules,
-    iter_named_quantizable_modules,
 )
 from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
@@ -87,7 +85,7 @@ def load_pretrained_quantization_parameters(
     model_path = get_safetensors_folder(model_name_or_path)
     mapping = get_quantization_parameter_to_path_mapping(model_path)
 
-    for name, submodule in iter_named_leaf_modules(model):
+    for name, submodule in model.named_modules():
         if not is_module_quantized(submodule):
             continue
         if submodule.quantization_scheme.input_activations is not None:
@@ -152,11 +150,7 @@
     # list of submodules to ignore
     ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in iter_named_quantizable_modules(
-        model,
-        include_children=True,
-        include_attn=True,
-    ):  # child modules and attention modules
+    for name, submodule in model.named_modules():
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
         if matches := find_name_or_class_matches(name, submodule, config.ignore):
@@ -287,7 +281,7 @@
     """
     return {
         name
-        for name, module in iter_named_leaf_modules(model)
+        for name, module in model.named_modules()
         if is_target(name, module, targets, ignore)
     }
 
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 47f43f5f..36ed1982 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -22,9 +22,7 @@
     preset_name_to_scheme,
 )
 from compressed_tensors.quantization.utils import (
-    calculate_compression_ratio,
     is_module_quantized,
-    iter_named_quantizable_modules,
     module_type,
     parse_out_kv_cache_args,
 )
@@ -177,9 +175,7 @@ def from_pretrained(
         quantization_status = None
         ignore = {}
         quantization_type_names = set()
-        for name, submodule in iter_named_quantizable_modules(
-            model, include_children=True, include_attn=True
-        ):
+        for name, submodule in model.named_modules():
             layer_type = module_type(submodule)
             if not is_module_quantized(submodule):
                 if layer_type not in ignore:
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 5d855f75..8d9a6a86 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -36,14 +36,11 @@
     "is_module_quantized",
     "is_model_quantized",
     "module_type",
-    "calculate_compression_ratio",
     "get_torch_bit_depth",
     "can_quantize",
     "parse_out_kv_cache_args",
     "KV_CACHE_TARGETS",
     "is_kv_cache_quant_scheme",
-    "iter_named_leaf_modules",
-    "iter_named_quantizable_modules",
     "compute_dynamic_scales_and_zp",
     "calculate_range",
     "calculate_qparams",
@@ -276,12 +273,7 @@ def is_model_quantized(model: Module) -> bool:
     :param model: pytorch model
     :return: True if model is quantized, False otherwise
     """
-
-    for _, submodule in iter_named_leaf_modules(model):
-        if is_module_quantized(submodule):
-            return True
-
-    return False
+    return any(is_module_quantized(submodule) for submodule in model.modules())
 
 
 def module_type(module: Module) -> str:
@@ -294,74 +286,6 @@ def module_type(module: Module) -> str:
     return type(module).__name__
 
 
-def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
-    """
-    Yields modules that do not have any submodules except observers. The observers
-    themselves are not yielded
-    :param model: model to get leaf modules of
-    :returns: generator tuple of (name, leaf_submodule)
-    """
-    for name, submodule in model.named_modules():
-        children = list(submodule.children())
-        # TODO: verify if an observer would ever be attached in this case/remove check
-        if len(children) == 0 and "observer" in name:
-            yield name, submodule
-        else:
-            if len(children) > 0:
-                named_children, children = zip(*list(submodule.named_children()))
-                has_non_observer_children = False
-                for i in range(len(children)):
-                    child_name = named_children[i]
-
-                    if "observer" not in child_name:
-                        has_non_observer_children = True
-
-                if not has_non_observer_children:
-                    yield name, submodule
-
-
-def iter_named_quantizable_modules(
-    model: Module,
-    include_children: bool = True,
-    include_attn: bool = False,
-    include_mlp: bool = False,
-) -> Generator[Tuple[str, Module], None, None]:
-    """
-    Yield name and submodule of
-    - leaf modules, set by include_children
-    - attention modyles, set by include_attn
-
-    :param model: model to get leaf modules of
-    :param include_children: flag to get the leaf modules
-    :param inlcude_attn: flag to get the attention modules
-    :returns: generator tuple of (name, submodule)
-    """
-    for name, submodule in model.named_modules():
-        # TODO: verify if an observer would ever be attached in this case/remove check
-        if include_children:
-            children = list(submodule.children())
-            if len(children) == 0 and "observer" not in name:
-                yield name, submodule
-            else:
-                if len(children) > 0:
-                    named_children, children = zip(*list(submodule.named_children()))
-                    has_non_observer_children = False
-                    for i in range(len(children)):
-                        child_name = named_children[i]
-
-                        if "observer" not in child_name:
-                            has_non_observer_children = True
-
-                    if not has_non_observer_children:
-                        yield name, submodule
-        if include_attn:
-            if name.endswith("self_attn"):
-                yield name, submodule
-        if include_mlp:
-            if name.endswith("mlp"):
-                yield name, submodule
-
-
 def get_torch_bit_depth(value: torch.Tensor) -> int:
     """
     Determine the number of bits used to represent the dtype of a tensor
@@ -397,34 +321,6 @@ def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool:
     return bit_depth > quant_args.num_bits
 
 
-def calculate_compression_ratio(model: Module) -> float:
-    """
-    Calculates the quantization compression ratio of a pytorch model, based on the
-    number of bits needed to represent the total weights in compressed form. Does not
-    take into account activation quantizatons.
-
-    :param model: pytorch module to calculate compression ratio for
-    :return: compression ratio of the whole model
-    """
-    total_compressed = 0.0
-    total_uncompressed = 0.0
-    for name, submodule in tqdm(
-        iter_named_leaf_modules(model),
-        desc="Calculating quantization compression ratio",
-    ):
-        for parameter in model.parameters():
-            uncompressed_bits = get_torch_bit_depth(parameter)
-            compressed_bits = uncompressed_bits
-            if is_module_quantized(submodule) and submodule.quantization_scheme.weights:
-                compressed_bits = submodule.quantization_scheme.weights.num_bits
-
-            num_weights = parameter.numel()
-            total_compressed += compressed_bits * num_weights
-            total_uncompressed += uncompressed_bits * num_weights
-
-    return total_uncompressed / total_compressed
-
-
 def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
     """
     Check whether the QuantizationScheme targets the kv cache.
diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py
index 9d699651..63a9a588 100644
--- a/tests/test_quantization/lifecycle/test_apply.py
+++ b/tests/test_quantization/lifecycle/test_apply.py
@@ -31,7 +31,6 @@
     expand_target_names,
     is_target,
 )
-from compressed_tensors.quantization.utils import iter_named_leaf_modules
 from tests.testing_utils import requires_accelerate
 from transformers import AutoModelForCausalLM
 
@@ -98,7 +97,7 @@ def test_target_prioritization(mock_frozen):
     apply_quantization_config(model, config)
     mock_frozen(model)
 
-    for name, module in iter_named_leaf_modules(model):
+    for name, module in model.named_modules():
         if name == "model.layers.0.mlp.down_proj":
             assert module.quantization_scheme.weights.num_bits == 2
         elif re.match(".*down_proj", name):

From 10bd9d5580aaa4baaff6641e8d6324123023d966 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 7 Jul 2025 12:28:54 -0400
Subject: [PATCH 2/7] use internal module

Signed-off-by: Kyle Sayers
---
 .../quantization/lifecycle/apply.py | 4 +++
 .../transform/factory/base.py | 3 ++-
 src/compressed_tensors/utils/__init__.py | 1 +
 src/compressed_tensors/utils/internal.py | 26 +++++++++++++++++++
 4 files changed, 33 insertions(+), 1 deletion(-)
 create mode 100644 src/compressed_tensors/utils/internal.py

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 64794870..f62b37a1 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -21,6 +21,7 @@
 from typing import Set, Union
 
 import torch
+from compressed_tensors import InternalModule
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.lifecycle.compressed import (
     compress_quantized_weights,
@@ -322,6 +323,9 @@
     2. matches on regex patterns
     3. matches on module names
     """
+    if isinstance(module, InternalModule):
+        return []
+
     targets = sorted(targets, key=lambda x: ("re:" in x, x))
     if isinstance(targets, Iterable):
         matches = _find_matches(name, targets) + _find_matches(
diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py
index 7448c604..62b9ddbd 100644
--- a/src/compressed_tensors/transform/factory/base.py
+++ b/src/compressed_tensors/transform/factory/base.py
@@ -17,6 +17,7 @@
 
 import torch
 import torch.nn.utils.parametrize as P
+from compressed_tensors import InternalModule
 from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
@@ -144,7 +145,7 @@ def output_hook(_, _input, output):
 
 
 # to support saving in the frozen state
-class TransformBase(Module, ABC):
+class TransformBase(InternalModule, ABC):
     """
     Represents the application of a transform accord to TransformArgs
     """
diff --git a/src/compressed_tensors/utils/__init__.py b/src/compressed_tensors/utils/__init__.py
index 976d55f7..e0b60557 100644
--- a/src/compressed_tensors/utils/__init__.py
+++ b/src/compressed_tensors/utils/__init__.py
@@ -14,6 +14,7 @@
 # flake8: noqa
 
 from .helpers import *
+from .internal import *
 from .offload import *
 from .permutations_24 import *
 from .permute import *
diff --git a/src/compressed_tensors/utils/internal.py b/src/compressed_tensors/utils/internal.py
new file mode 100644
index 00000000..41898377
--- /dev/null
+++ b/src/compressed_tensors/utils/internal.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+__all__ = ["InternalModule"]
+
+
+class InternalModule(torch.nn.Module):
+    """
+    Abstract base class for modules which are not a part of the the model definition.
+    `torch.nn.Module`s which inherit from this class will not be targeted by configs
+    """
+    pass

From 44b678aee71964e94cd0e890a9195c307d64343e Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 7 Jul 2025 12:29:48 -0400
Subject: [PATCH 3/7] add docstring

Signed-off-by: Kyle Sayers
---
 src/compressed_tensors/utils/internal.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/compressed_tensors/utils/internal.py b/src/compressed_tensors/utils/internal.py
index 41898377..6297e5b0 100644
--- a/src/compressed_tensors/utils/internal.py
+++ b/src/compressed_tensors/utils/internal.py
@@ -22,5 +22,8 @@ class InternalModule(torch.nn.Module):
     """
     Abstract base class for modules which are not a part of the the model definition.
     `torch.nn.Module`s which inherit from this class will not be targeted by configs
+
+    This is typically used to skip apply configs to `Observers` and `Transforms`
     """
+
     pass

From 591abc7a763817aba8946182b1e49d48db17852f Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 7 Jul 2025 12:59:22 -0400
Subject: [PATCH 4/7] fix import cycle

Signed-off-by: Kyle Sayers
---
 src/compressed_tensors/quantization/lifecycle/apply.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index f62b37a1..7afd2aba 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -21,7 +21,6 @@
 from typing import Set, Union
 
 import torch
-from compressed_tensors import InternalModule
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.lifecycle.compressed import (
     compress_quantized_weights,
@@ -322,6 +322,8 @@
     2. matches on regex patterns
     3. matches on module names
     """
+    from compressed_tensors import InternalModule
+
     if isinstance(module, InternalModule):
         return []
 

From b348966b7f37cc741f26b0d9cc02ca14c60b8555 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 7 Jul 2025 13:11:15 -0400
Subject: [PATCH 5/7] keep as deprecated

Signed-off-by: Kyle Sayers
---
 .../quantization/utils/helpers.py | 80 +++++++++++++++++++
 1 file changed, 80 insertions(+)

diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 8d9a6a86..b6d81009 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -26,6 +26,7 @@
     QuantizationType,
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import deprecated
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 from tqdm import tqdm
@@ -41,6 +42,8 @@
     "parse_out_kv_cache_args",
     "KV_CACHE_TARGETS",
     "is_kv_cache_quant_scheme",
+    "iter_named_leaf_modules",
+    "iter_named_quantizable_modules",
     "compute_dynamic_scales_and_zp",
     "calculate_range",
     "calculate_qparams",
@@ -286,6 +289,83 @@ def module_type(module: Module) -> str:
     return type(module).__name__
 
 
+@deprecated(
+    message="This function will be removed in a future release. "
+    "Please use `model.named_modules()` and filter by "
+    "compressed_tensors.InternalModule if neceessary"
+)
+def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
+    """
+    Yields modules that do not have any submodules except observers. The observers
+    themselves are not yielded
+    :param model: model to get leaf modules of
+    :returns: generator tuple of (name, leaf_submodule)
+    """
+    for name, submodule in model.named_modules():
+        children = list(submodule.children())
+        # TODO: verify if an observer would ever be attached in this case/remove check
+        if len(children) == 0 and "observer" in name:
+            yield name, submodule
+        else:
+            if len(children) > 0:
+                named_children, children = zip(*list(submodule.named_children()))
+                has_non_observer_children = False
+                for i in range(len(children)):
+                    child_name = named_children[i]
+
+                    if "observer" not in child_name:
+                        has_non_observer_children = True
+
+                if not has_non_observer_children:
+                    yield name, submodule
+
+
+@deprecated(
+    message="This function will be removed in a future release. "
+    "Please use `model.named_modules()` and filter by "
+    "compressed_tensors.InternalModule if neceessary"
+)
+def iter_named_quantizable_modules(
+    model: Module,
+    include_children: bool = True,
+    include_attn: bool = False,
+    include_mlp: bool = False,
+) -> Generator[Tuple[str, Module], None, None]:
+    """
+    Yield name and submodule of
+    - leaf modules, set by include_children
+    - attention modyles, set by include_attn
+    :param model: model to get leaf modules of
+    :param include_children: flag to get the leaf modules
+    :param inlcude_attn: flag to get the attention modules
+    :returns: generator tuple of (name, submodule)
+    """
+    for name, submodule in model.named_modules():
+        # TODO: verify if an observer would ever be attached in this case/remove check
+        if include_children:
+            children = list(submodule.children())
+            if len(children) == 0 and "observer" not in name:
+                yield name, submodule
+            else:
+                if len(children) > 0:
+                    named_children, children = zip(*list(submodule.named_children()))
+                    has_non_observer_children = False
+                    for i in range(len(children)):
+                        child_name = named_children[i]
+
+                        if "observer" not in child_name:
+                            has_non_observer_children = True
+
+                    if not has_non_observer_children:
+                        yield name, submodule
+        if include_attn:
+            if name.endswith("self_attn"):
+                yield name, submodule
+        if include_mlp:
+            if name.endswith("mlp"):
+                yield name, submodule
+
+
 def get_torch_bit_depth(value: torch.Tensor) -> int:
     """
     Determine the number of bits used to represent the dtype of a tensor

From 9b23a62cad047f16e6ff3a3cca96870a878c087e Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Tue, 8 Jul 2025 10:30:00 -0400
Subject: [PATCH 6/7] rename to Untargetable

Signed-off-by: Kyle Sayers
---
 src/compressed_tensors/quantization/lifecycle/apply.py | 4 ++--
 src/compressed_tensors/quantization/utils/helpers.py | 4 ++--
 src/compressed_tensors/transform/factory/base.py | 4 ++--
 src/compressed_tensors/utils/__init__.py | 2 +-
 src/compressed_tensors/utils/{internal.py => untargetable.py} | 4 ++--
 5 files changed, 9 insertions(+), 9 deletions(-)
 rename src/compressed_tensors/utils/{internal.py => untargetable.py} (92%)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 7afd2aba..11b1d1a8 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -322,9 +322,9 @@
     2. matches on regex patterns
     3. matches on module names
     """
-    from compressed_tensors import InternalModule
+    from compressed_tensors import UntargetableModule
 
-    if isinstance(module, InternalModule):
+    if isinstance(module, UntargetableModule):
         return []
 
     targets = sorted(targets, key=lambda x: ("re:" in x, x))
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index b6d81009..6da40d3e 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -292,7 +292,7 @@ def module_type(module: Module) -> str:
 @deprecated(
     message="This function will be removed in a future release. "
     "Please use `model.named_modules()` and filter by "
-    "compressed_tensors.InternalModule if neceessary"
+    "compressed_tensors.UntargetableModule if neceessary"
 )
 def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
     """
@@ -323,7 +323,7 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None
 @deprecated(
     message="This function will be removed in a future release. "
     "Please use `model.named_modules()` and filter by "
-    "compressed_tensors.InternalModule if neceessary"
+    "compressed_tensors.UntargetableModule if neceessary"
 )
 def iter_named_quantizable_modules(
     model: Module,
diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py
index 62b9ddbd..0898db21 100644
--- a/src/compressed_tensors/transform/factory/base.py
+++ b/src/compressed_tensors/transform/factory/base.py
@@ -17,7 +17,7 @@
 
 import torch
 import torch.nn.utils.parametrize as P
-from compressed_tensors import InternalModule
+from compressed_tensors import UntargetableModule
 from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
@@ -145,7 +145,7 @@ def output_hook(_, _input, output):
 
 
 # to support saving in the frozen state
-class TransformBase(InternalModule, ABC):
+class TransformBase(UntargetableModule, ABC):
     """
     Represents the application of a transform accord to TransformArgs
     """
diff --git a/src/compressed_tensors/utils/__init__.py b/src/compressed_tensors/utils/__init__.py
index e0b60557..ee9f1bce 100644
--- a/src/compressed_tensors/utils/__init__.py
+++ b/src/compressed_tensors/utils/__init__.py
@@ -14,9 +14,9 @@
 # flake8: noqa
 
 from .helpers import *
-from .internal import *
 from .offload import *
 from .permutations_24 import *
 from .permute import *
 from .safetensors_load import *
 from .semi_structured_conversions import *
+from .untargetable import *
diff --git a/src/compressed_tensors/utils/internal.py b/src/compressed_tensors/utils/untargetable.py
similarity index 92%
rename from src/compressed_tensors/utils/internal.py
rename to src/compressed_tensors/utils/untargetable.py
index 6297e5b0..67b344da 100644
--- a/src/compressed_tensors/utils/internal.py
+++ b/src/compressed_tensors/utils/untargetable.py
@@ -15,10 +15,10 @@
 import torch
 
 
-__all__ = ["InternalModule"]
+__all__ = ["UntargetableModule"]
 
 
-class InternalModule(torch.nn.Module):
+class UntargetableModule(torch.nn.Module):
     """
     Abstract base class for modules which are not a part of the the model definition.
     `torch.nn.Module`s which inherit from this class will not be targeted by configs

From d9007bb804b2cb7b0cc6f77cf64c4040c88dd7f6 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Tue, 8 Jul 2025 13:35:30 -0400
Subject: [PATCH 7/7] Revert "rename to Untargetable"

This reverts commit 9b23a62cad047f16e6ff3a3cca96870a878c087e.
---
 src/compressed_tensors/quantization/lifecycle/apply.py | 4 ++--
 src/compressed_tensors/quantization/utils/helpers.py | 4 ++--
 src/compressed_tensors/transform/factory/base.py | 4 ++--
 src/compressed_tensors/utils/__init__.py | 2 +-
 src/compressed_tensors/utils/{untargetable.py => internal.py} | 4 ++--
 5 files changed, 9 insertions(+), 9 deletions(-)
 rename src/compressed_tensors/utils/{untargetable.py => internal.py} (92%)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 11b1d1a8..7afd2aba 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -322,9 +322,9 @@
     2. matches on regex patterns
     3. matches on module names
     """
-    from compressed_tensors import UntargetableModule
+    from compressed_tensors import InternalModule
 
-    if isinstance(module, UntargetableModule):
+    if isinstance(module, InternalModule):
         return []
 
     targets = sorted(targets, key=lambda x: ("re:" in x, x))
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 6da40d3e..b6d81009 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -292,7 +292,7 @@ def module_type(module: Module) -> str:
 @deprecated(
     message="This function will be removed in a future release. "
     "Please use `model.named_modules()` and filter by "
-    "compressed_tensors.UntargetableModule if neceessary"
+    "compressed_tensors.InternalModule if neceessary"
 )
 def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
     """
@@ -323,7 +323,7 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None
 @deprecated(
     message="This function will be removed in a future release. "
     "Please use `model.named_modules()` and filter by "
-    "compressed_tensors.UntargetableModule if neceessary"
+    "compressed_tensors.InternalModule if neceessary"
 )
 def iter_named_quantizable_modules(
     model: Module,
diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py
index 0898db21..62b9ddbd 100644
--- a/src/compressed_tensors/transform/factory/base.py
+++ b/src/compressed_tensors/transform/factory/base.py
@@ -17,7 +17,7 @@
 
 import torch
 import torch.nn.utils.parametrize as P
-from compressed_tensors import UntargetableModule
+from compressed_tensors import InternalModule
 from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
@@ -145,7 +145,7 @@ def output_hook(_, _input, output):
 
 
 # to support saving in the frozen state
-class TransformBase(UntargetableModule, ABC):
+class TransformBase(InternalModule, ABC):
     """
     Represents the application of a transform accord to TransformArgs
     """
diff --git a/src/compressed_tensors/utils/__init__.py b/src/compressed_tensors/utils/__init__.py
index ee9f1bce..e0b60557 100644
--- a/src/compressed_tensors/utils/__init__.py
+++ b/src/compressed_tensors/utils/__init__.py
@@ -14,9 +14,9 @@
 # flake8: noqa
 
 from .helpers import *
+from .internal import *
 from .offload import *
 from .permutations_24 import *
 from .permute import *
 from .safetensors_load import *
 from .semi_structured_conversions import *
-from .untargetable import *
diff --git a/src/compressed_tensors/utils/untargetable.py b/src/compressed_tensors/utils/internal.py
similarity index 92%
rename from src/compressed_tensors/utils/untargetable.py
rename to src/compressed_tensors/utils/internal.py
index 67b344da..6297e5b0 100644
--- a/src/compressed_tensors/utils/untargetable.py
+++ b/src/compressed_tensors/utils/internal.py
@@ -15,10 +15,10 @@
 import torch
 
 
-__all__ = ["UntargetableModule"]
+__all__ = ["InternalModule"]
 
 
-class UntargetableModule(torch.nn.Module):
+class InternalModule(torch.nn.Module):
     """
     Abstract base class for modules which are not a part of the the model definition.
     `torch.nn.Module`s which inherit from this class will not be targeted by configs