diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 700c1769..38c1d9f5 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -42,10 +42,7 @@
     load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
-from compressed_tensors.quantization.utils import (
-    is_module_quantized,
-    iter_named_leaf_modules,
-)
+from compressed_tensors.quantization.utils import is_module_quantized
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
@@ -747,7 +744,7 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
     return {
         fix_fsdp_module_name(name): module.quantization_scheme
-        for name, module in iter_named_leaf_modules(model)
+        for name, module in model.named_modules()
         if is_module_quantized(module)
     }
 
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index c8dbeced..7afd2aba 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -38,8 +38,6 @@
     KV_CACHE_TARGETS,
     infer_quantization_status,
     is_kv_cache_quant_scheme,
-    iter_named_leaf_modules,
-    iter_named_quantizable_modules,
 )
 from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
@@ -87,7 +85,7 @@ def load_pretrained_quantization_parameters(
     model_path = get_safetensors_folder(model_name_or_path)
     mapping = get_quantization_parameter_to_path_mapping(model_path)
 
-    for name, submodule in iter_named_leaf_modules(model):
+    for name, submodule in model.named_modules():
         if not is_module_quantized(submodule):
             continue
         if submodule.quantization_scheme.input_activations is not None:
@@ -152,11 +150,7 @@
     # list of submodules to ignore
     ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in iter_named_quantizable_modules(
-        model,
-        include_children=True,
-        include_attn=True,
-    ):  # child modules and attention modules
+    for name, submodule in model.named_modules():
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
         if matches := find_name_or_class_matches(name, submodule, config.ignore):
@@ -287,7 +281,7 @@
     """
     return {
         name
-        for name, module in iter_named_leaf_modules(model)
+        for name, module in model.named_modules()
        if is_target(name, module, targets, ignore)
    }
 
@@ -328,6 +322,11 @@ def find_name_or_class_matches(
     2. matches on regex patterns
     3. matches on module names
     """
+    from compressed_tensors import InternalModule
+
+    if isinstance(module, InternalModule):
+        return []
+
     targets = sorted(targets, key=lambda x: ("re:" in x, x))
     if isinstance(targets, Iterable):
         matches = _find_matches(name, targets) + _find_matches(
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 47f43f5f..36ed1982 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -22,9 +22,7 @@
     preset_name_to_scheme,
 )
 from compressed_tensors.quantization.utils import (
-    calculate_compression_ratio,
     is_module_quantized,
-    iter_named_quantizable_modules,
     module_type,
     parse_out_kv_cache_args,
 )
@@ -177,9 +175,7 @@ def from_pretrained(
         quantization_status = None
         ignore = {}
         quantization_type_names = set()
-        for name, submodule in iter_named_quantizable_modules(
-            model, include_children=True, include_attn=True
-        ):
+        for name, submodule in model.named_modules():
             layer_type = module_type(submodule)
             if not is_module_quantized(submodule):
                 if layer_type not in ignore:
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 5d855f75..b6d81009 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -26,6 +26,7 @@
     QuantizationType,
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import deprecated
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 from tqdm import tqdm
@@ -36,7 +37,6 @@
     "is_module_quantized",
     "is_model_quantized",
     "module_type",
-    "calculate_compression_ratio",
     "get_torch_bit_depth",
     "can_quantize",
     "parse_out_kv_cache_args",
@@ -276,12 +276,7 @@ def is_model_quantized(model: Module) -> bool:
     :param model: pytorch model
     :return: True if model is quantized, False otherwise
     """
-
-    for _, submodule in iter_named_leaf_modules(model):
-        if is_module_quantized(submodule):
-            return True
-
-    return False
+    return any(is_module_quantized(submodule) for submodule in model.modules())
 
 
 def module_type(module: Module) -> str:
@@ -294,6 +289,11 @@ def module_type(module: Module) -> str:
     return type(module).__name__
 
 
+@deprecated(
+    message="This function will be removed in a future release. "
+    "Please use `model.named_modules()` and filter by "
+    "compressed_tensors.InternalModule if necessary"
+)
 def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
     """
     Yields modules that do not have any submodules except observers. The observers
@@ -320,6 +320,11 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None
             yield name, submodule
 
 
+@deprecated(
" + "Please use `model.named_modules()` and filter by " + "compressed_tensors.InternalModule if neceessary" +) def iter_named_quantizable_modules( model: Module, include_children: bool = True, @@ -330,7 +335,6 @@ def iter_named_quantizable_modules( Yield name and submodule of - leaf modules, set by include_children - attention modyles, set by include_attn - :param model: model to get leaf modules of :param include_children: flag to get the leaf modules :param inlcude_attn: flag to get the attention modules @@ -397,34 +401,6 @@ def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool: return bit_depth > quant_args.num_bits -def calculate_compression_ratio(model: Module) -> float: - """ - Calculates the quantization compression ratio of a pytorch model, based on the - number of bits needed to represent the total weights in compressed form. Does not - take into account activation quantizatons. - - :param model: pytorch module to calculate compression ratio for - :return: compression ratio of the whole model - """ - total_compressed = 0.0 - total_uncompressed = 0.0 - for name, submodule in tqdm( - iter_named_leaf_modules(model), - desc="Calculating quantization compression ratio", - ): - for parameter in model.parameters(): - uncompressed_bits = get_torch_bit_depth(parameter) - compressed_bits = uncompressed_bits - if is_module_quantized(submodule) and submodule.quantization_scheme.weights: - compressed_bits = submodule.quantization_scheme.weights.num_bits - - num_weights = parameter.numel() - total_compressed += compressed_bits * num_weights - total_uncompressed += uncompressed_bits * num_weights - - return total_uncompressed / total_compressed - - def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool: """ Check whether the QuantizationScheme targets the kv cache. diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 7448c604..62b9ddbd 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -17,6 +17,7 @@ import torch import torch.nn.utils.parametrize as P +from compressed_tensors import InternalModule from compressed_tensors.quantization.lifecycle import is_target # TODO: move to utils from compressed_tensors.registry.registry import RegistryMixin, T from compressed_tensors.transform import ( @@ -144,7 +145,7 @@ def output_hook(_, _input, output): # to support saving in the frozen state -class TransformBase(Module, ABC): +class TransformBase(InternalModule, ABC): """ Represents the application of a transform accord to TransformArgs """ diff --git a/src/compressed_tensors/utils/__init__.py b/src/compressed_tensors/utils/__init__.py index 976d55f7..e0b60557 100644 --- a/src/compressed_tensors/utils/__init__.py +++ b/src/compressed_tensors/utils/__init__.py @@ -14,6 +14,7 @@ # flake8: noqa from .helpers import * +from .internal import * from .offload import * from .permutations_24 import * from .permute import * diff --git a/src/compressed_tensors/utils/internal.py b/src/compressed_tensors/utils/internal.py new file mode 100644 index 00000000..6297e5b0 --- /dev/null +++ b/src/compressed_tensors/utils/internal.py @@ -0,0 +1,29 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+__all__ = ["InternalModule"]
+
+
+class InternalModule(torch.nn.Module):
+    """
+    Abstract base class for modules which are not a part of the model definition.
+    `torch.nn.Module`s which inherit from this class will not be targeted by configs.
+
+    This is typically used to skip applying configs to `Observers` and `Transforms`.
+    """
+
+    pass
diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py
index 9d699651..63a9a588 100644
--- a/tests/test_quantization/lifecycle/test_apply.py
+++ b/tests/test_quantization/lifecycle/test_apply.py
@@ -31,7 +31,6 @@
     expand_target_names,
     is_target,
 )
-from compressed_tensors.quantization.utils import iter_named_leaf_modules
 from tests.testing_utils import requires_accelerate
 from transformers import AutoModelForCausalLM
 
@@ -98,7 +97,7 @@ def test_target_prioritization(mock_frozen):
     apply_quantization_config(model, config)
     mock_frozen(model)
 
-    for name, module in iter_named_leaf_modules(model):
+    for name, module in model.named_modules():
         if name == "model.layers.0.mlp.down_proj":
             assert module.quantization_scheme.weights.num_bits == 2
         elif re.match(".*down_proj", name):
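Usage sketch (illustrative, not part of the patch): the migration path this change establishes is to iterate `model.named_modules()` and skip `InternalModule` instances explicitly, instead of relying on the deprecated `iter_named_leaf_modules` / `iter_named_quantizable_modules` helpers. The `ToyObserver` module and toy model below are hypothetical stand-ins for real observers.

```python
import torch
from compressed_tensors import InternalModule


class ToyObserver(InternalModule):
    """Hypothetical bookkeeping module that configs should never target."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


model = torch.nn.Sequential(torch.nn.Linear(4, 4))
model[0].register_module("observer", ToyObserver())

# Old pattern (now deprecated): iter_named_leaf_modules(model)
# New pattern: walk all modules and skip internal ones explicitly
for name, module in model.named_modules():
    if isinstance(module, InternalModule):
        continue
    print(name, type(module).__name__)  # e.g. apply or inspect schemes here
```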
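Behavior check (hedged sketch): with the early return added to `find_name_or_class_matches`, internal modules are rejected before any target matching runs, even when a regex target would otherwise hit them. The import path below is an assumption based on the function's location in `apply.py`, and the module names are made up.

```python
import torch
from compressed_tensors import InternalModule
from compressed_tensors.quantization.lifecycle.apply import find_name_or_class_matches


class _Internal(InternalModule):
    """Minimal internal module for the check below."""


# An InternalModule subclass never matches, even against a matching regex target
assert find_name_or_class_matches("layers.0.observer", _Internal(), ["re:.*observer"]) == []

# An ordinary module with the same name does match
assert find_name_or_class_matches("layers.0.observer", torch.nn.Identity(), ["re:.*observer"]) != []
```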