Skip to content

Commit 853ffcf

Browse files
committed
Merge remote-tracking branch 'origin' into kylesayrs/transform_apply
2 parents bbf9533 + 40fa3c5 commit 853ffcf

File tree

10 files changed

+105
-54
lines changed

10 files changed

+105
-54
lines changed

.github/actions/test/action.yml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ inputs:
77
suitename:
88
description: "test suite name"
99
required: true
10+
code_coverage:
11+
description: whether to collect code coverage metrics during test run
12+
type: boolean
13+
default: false
1014
outputs:
1115
status:
1216
description: "final status from test"
@@ -44,9 +48,37 @@ runs:
4448
run: |
4549
source ${{ inputs.venv }}/bin/activate
4650
rm -rf src
51+
52+
if [[ "${ENABLE_COVERAGE}" == "true" ]]; then
53+
echo "::group::Installing code coverage requirements via pip"
54+
pip install bashlex https://github.com/neuralmagic/pytest-nm-releng/archive/v0.4.0.tar.gz
55+
pip install coverage pytest-cov
56+
57+
# Adding Code coverage to the tests
58+
nmre-generate-coverage-flags --package "compressed_tensors" --output-file ".coverage_flags.sh"
59+
source .coverage_flags.sh
60+
echo "::endgroup::"
61+
fi
62+
63+
echo "::group::running tests"
64+
echo "PYTEST_ADDOPTS set to: ${PYTEST_ADDOPTS}"
65+
4766
SUCCESS=0
4867
pytest tests --junitxml=test-results/report.xml -o junit_suite_name="${{ inputs.suitename }}" || SUCCESS=$?
4968
echo "status=${SUCCESS}" >> "$GITHUB_OUTPUT"
69+
echo "::endgroup::"
70+
71+
if [[ "${ENABLE_COVERAGE}" == "true" ]]; then
72+
echo "::group::consolidating coverage reports"
73+
mkdir -p coverage-results
74+
mv .coverage coverage-results/ || echo ".coverage file not found"
75+
mv coverage-html coverage-results/ || echo "coverage-html folder not found"
76+
mv coverage.json coverage-results/ || echo "coverage.json file not found"
77+
echo "::endgroup::"
78+
fi
79+
5080
deactivate
5181
exit ${SUCCESS}
5282
shell: bash
83+
env:
84+
ENABLE_COVERAGE: ${{ inputs.code_coverage || false }}

.github/workflows/test.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ on:
2525
run_id:
2626
description: run id of the BUILD job that generated the assets
2727
type: string
28+
code_coverage:
29+
description: whether to collect code coverage metrics during test run
30+
type: boolean
31+
default: false
2832

2933
# makes workflow manually callable
3034
workflow_dispatch:
@@ -51,6 +55,10 @@ on:
5155
run_id:
5256
description: run id of the BUILD job that generated the assets
5357
type: string
58+
code_coverage:
59+
description: whether to collect code coverage metrics during test run
60+
type: boolean
61+
default: false
5462

5563
jobs:
5664

@@ -124,6 +132,7 @@ jobs:
124132
with:
125133
venv: ${{ steps.create_venv.outputs.penv }}
126134
suitename: test-${{ inputs.python }}-${{ inputs.test_label }}
135+
code_coverage: ${{ inputs.code_coverage }}
127136

128137
- name: summary
129138
uses: neuralmagic/nm-actions/actions/summary-test@v1.13.0
@@ -146,3 +155,11 @@ jobs:
146155
name: report-${{ inputs.test_label }}.xml
147156
path: test-results/report.xml
148157
retention-days: 5
158+
159+
- name: upload coverage report
160+
uses: actions/upload-artifact@v4
161+
if: (success() || failure()) && inputs.code_coverage
162+
with:
163+
name: coverage-results
164+
path: coverage-results/*
165+
retention-days: 5

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,7 @@
4242
load_pretrained_quantization_parameters,
4343
)
4444
from compressed_tensors.quantization.lifecycle import expand_target_names
45-
from compressed_tensors.quantization.utils import (
46-
is_module_quantized,
47-
iter_named_leaf_modules,
48-
)
45+
from compressed_tensors.quantization.utils import is_module_quantized
4946
from compressed_tensors.utils import (
5047
align_module_device,
5148
delete_offload_parameter,
@@ -747,7 +744,7 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
747744
"""
748745
return {
749746
fix_fsdp_module_name(name): module.quantization_scheme
750-
for name, module in iter_named_leaf_modules(model)
747+
for name, module in model.named_modules()
751748
if is_module_quantized(module)
752749
}
753750

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@
3838
KV_CACHE_TARGETS,
3939
infer_quantization_status,
4040
is_kv_cache_quant_scheme,
41-
iter_named_leaf_modules,
42-
iter_named_quantizable_modules,
4341
)
4442
from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
4543
from compressed_tensors.utils.offload import update_parameter_data
@@ -87,7 +85,7 @@ def load_pretrained_quantization_parameters(
8785
model_path = get_safetensors_folder(model_name_or_path)
8886
mapping = get_quantization_parameter_to_path_mapping(model_path)
8987

90-
for name, submodule in iter_named_leaf_modules(model):
88+
for name, submodule in model.named_modules():
9189
if not is_module_quantized(submodule):
9290
continue
9391
if submodule.quantization_scheme.input_activations is not None:
@@ -152,7 +150,7 @@ def apply_quantization_config(
152150
# list of submodules to ignore
153151
ignored_submodules = defaultdict(list)
154152
# mark appropriate layers for quantization by setting their quantization schemes
155-
for name, submodule in model.named_modules(): # child modules and attention modules
153+
for name, submodule in model.named_modules():
156154
# potentially fix module name to remove FSDP wrapper prefix
157155
name = fix_fsdp_module_name(name)
158156
if matches := find_name_or_class_matches(name, submodule, config.ignore):
@@ -283,7 +281,7 @@ def expand_target_names(
283281
"""
284282
return {
285283
name
286-
for name, module in iter_named_leaf_modules(model)
284+
for name, module in model.named_modules()
287285
if is_target(name, module, targets, ignore)
288286
}
289287

@@ -324,6 +322,11 @@ def find_name_or_class_matches(
324322
2. matches on regex patterns
325323
3. matches on module names
326324
"""
325+
from compressed_tensors import InternalModule
326+
327+
if isinstance(module, InternalModule):
328+
return []
329+
327330
targets = sorted(targets, key=lambda x: ("re:" in x, x))
328331
if isinstance(targets, Iterable):
329332
matches = _find_matches(name, targets) + _find_matches(

src/compressed_tensors/quantization/quant_config.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222
preset_name_to_scheme,
2323
)
2424
from compressed_tensors.quantization.utils import (
25-
calculate_compression_ratio,
2625
is_module_quantized,
27-
iter_named_quantizable_modules,
2826
module_type,
2927
parse_out_kv_cache_args,
3028
)
@@ -177,9 +175,7 @@ def from_pretrained(
177175
quantization_status = None
178176
ignore = {}
179177
quantization_type_names = set()
180-
for name, submodule in iter_named_quantizable_modules(
181-
model, include_children=True, include_attn=True
182-
):
178+
for name, submodule in model.named_modules():
183179
layer_type = module_type(submodule)
184180
if not is_module_quantized(submodule):
185181
if layer_type not in ignore:

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
QuantizationType,
2727
)
2828
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
29+
from compressed_tensors.utils import deprecated
2930
from torch import FloatTensor, IntTensor, Tensor
3031
from torch.nn import Module
3132
from tqdm import tqdm
@@ -36,7 +37,6 @@
3637
"is_module_quantized",
3738
"is_model_quantized",
3839
"module_type",
39-
"calculate_compression_ratio",
4040
"get_torch_bit_depth",
4141
"can_quantize",
4242
"parse_out_kv_cache_args",
@@ -276,12 +276,7 @@ def is_model_quantized(model: Module) -> bool:
276276
:param model: pytorch model
277277
:return: True if model is quantized, False otherwise
278278
"""
279-
280-
for _, submodule in iter_named_leaf_modules(model):
281-
if is_module_quantized(submodule):
282-
return True
283-
284-
return False
279+
return any(is_module_quantized(submodule) for submodule in model.modules())
285280

286281

287282
def module_type(module: Module) -> str:
@@ -294,6 +289,11 @@ def module_type(module: Module) -> str:
294289
return type(module).__name__
295290

296291

292+
@deprecated(
293+
message="This function will be removed in a future release. "
294+
"Please use `model.named_modules()` and filter by "
295+
"compressed_tensors.InternalModule if neceessary"
296+
)
297297
def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
298298
"""
299299
Yields modules that do not have any submodules except observers. The observers
@@ -320,6 +320,11 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None
320320
yield name, submodule
321321

322322

323+
@deprecated(
324+
message="This function will be removed in a future release. "
325+
"Please use `model.named_modules()` and filter by "
326+
"compressed_tensors.InternalModule if neceessary"
327+
)
323328
def iter_named_quantizable_modules(
324329
model: Module,
325330
include_children: bool = True,
@@ -330,7 +335,6 @@ def iter_named_quantizable_modules(
330335
Yield name and submodule of
331336
- leaf modules, set by include_children
332337
- attention modyles, set by include_attn
333-
334338
:param model: model to get leaf modules of
335339
:param include_children: flag to get the leaf modules
336340
:param inlcude_attn: flag to get the attention modules
@@ -397,34 +401,6 @@ def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool:
397401
return bit_depth > quant_args.num_bits
398402

399403

400-
def calculate_compression_ratio(model: Module) -> float:
401-
"""
402-
Calculates the quantization compression ratio of a pytorch model, based on the
403-
number of bits needed to represent the total weights in compressed form. Does not
404-
take into account activation quantizatons.
405-
406-
:param model: pytorch module to calculate compression ratio for
407-
:return: compression ratio of the whole model
408-
"""
409-
total_compressed = 0.0
410-
total_uncompressed = 0.0
411-
for name, submodule in tqdm(
412-
iter_named_leaf_modules(model),
413-
desc="Calculating quantization compression ratio",
414-
):
415-
for parameter in model.parameters():
416-
uncompressed_bits = get_torch_bit_depth(parameter)
417-
compressed_bits = uncompressed_bits
418-
if is_module_quantized(submodule) and submodule.quantization_scheme.weights:
419-
compressed_bits = submodule.quantization_scheme.weights.num_bits
420-
421-
num_weights = parameter.numel()
422-
total_compressed += compressed_bits * num_weights
423-
total_uncompressed += uncompressed_bits * num_weights
424-
425-
return total_uncompressed / total_compressed
426-
427-
428404
def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
429405
"""
430406
Check whether the QuantizationScheme targets the kv cache.

src/compressed_tensors/transform/factory/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import torch
1919
import torch.nn.utils.parametrize as P
20+
from compressed_tensors import InternalModule
2021
from compressed_tensors.quantization.lifecycle import is_target # TODO: move to utils
2122
from compressed_tensors.registry.registry import RegistryMixin, T
2223
from compressed_tensors.transform import (
@@ -146,7 +147,7 @@ def output_hook(_, _input, output):
146147
raise NotImplementedError()
147148

148149

149-
class TransformBase(Module, ABC):
150+
class TransformBase(InternalModule, ABC):
150151
"""
151152
Represents the application of a transform accord to TransformArgs
152153
"""

src/compressed_tensors/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# flake8: noqa
1515

1616
from .helpers import *
17+
from .internal import *
1718
from .offload import *
1819
from .permutations_24 import *
1920
from .permute import *
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
17+
18+
__all__ = ["InternalModule"]
19+
20+
21+
class InternalModule(torch.nn.Module):
22+
"""
23+
Abstract base class for modules which are not a part of the the model definition.
24+
`torch.nn.Module`s which inherit from this class will not be targeted by configs
25+
26+
This is typically used to skip apply configs to `Observers` and `Transforms`
27+
"""
28+
29+
pass

tests/test_quantization/lifecycle/test_apply.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
expand_target_names,
3232
is_target,
3333
)
34-
from compressed_tensors.quantization.utils import iter_named_leaf_modules
3534
from tests.testing_utils import requires_accelerate
3635
from transformers import AutoModelForCausalLM
3736

@@ -98,7 +97,7 @@ def test_target_prioritization(mock_frozen):
9897
apply_quantization_config(model, config)
9998
mock_frozen(model)
10099

101-
for name, module in iter_named_leaf_modules(model):
100+
for name, module in model.named_modules():
102101
if name == "model.layers.0.mlp.down_proj":
103102
assert module.quantization_scheme.weights.num_bits == 2
104103
elif re.match(".*down_proj", name):

0 commit comments

Comments
 (0)